#!/usr/bin/env python
""" Plot found and missed injections.
"""
import h5py, numpy, logging, os.path, argparse, sys
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plot
import pycbc.results.followup, pycbc.pnutils, pycbc.results, pycbc.version
import pycbc.pnutils
from pycbc import init_logging

# ---------------------------------------------------------------------------
# Command line interface.  Options are registered in the order in which they
# are listed by --help.
# ---------------------------------------------------------------------------
axis_choices = ['mchirp', 'effective_spin', 'time', 'mtotal', 'mass_ratio']
distance_choices = ['decisive_distance', 'dec_chirp_distance',
                    'chirp_distance', 'comb_optimal_snr',
                    'decisive_optimal_snr', 'redshift']

parser = argparse.ArgumentParser(description=__doc__)

# Input file and choice of x-axis quantity
parser.add_argument(
    '--injection-file', required=True,
    help="The hdf injection file to plot")
parser.add_argument('--axis-type', choices=axis_choices, default='mchirp')
parser.add_argument('--log-x', action='store_true')

# Choice and range of y-axis quantity
parser.add_argument(
    '--distance-type', choices=distance_choices,
    default='decisive_optimal_snr',
    help="Variable related to injected distance. Decisive distance and dec "
         "chirp distance only available for 2-ifo search")
parser.add_argument(
    '--plot-all-distance', action='store_true',
    help="Plot all values of distance or SNR. If not given, the plot will "
         "be truncated below 1")

# Appearance
parser.add_argument(
    '--colormap', default='cividis_r',
    help="Type of colormap to be used for the plots.")
parser.add_argument('--verbose', action='count')
parser.add_argument('--log-distance', action='store_true')
parser.add_argument('--dynamic', action='store_true')
parser.add_argument(
    '--gradient-far', action='store_true',
    help="Show far of found injections as a gradient")
parser.add_argument('--output-file', required=True)
parser.add_argument(
    '--far-type', choices=('inclusive', 'exclusive'), default='inclusive',
    help="Type of far to plot for the color. Choices are 'inclusive' or "
         "'exclusive'. Default = 'inclusive'")
parser.add_argument(
    '--missed-on-top', action='store_true',
    help="Plot missed injections on top of found ones and high FAR on top "
         "of low FAR")
parser.add_argument(
    '--version', action='version', version=pycbc.version.git_verbose_msg)

args = parser.parse_args()

# --verbose is a counter (None when absent); it is passed straight through.
init_logging(args.verbose)

logging.info('Read in the data')
f = h5py.File(args.injection_file, 'r')
injections = f['injections']

time = injections['end_time'][:]

# Index arrays selecting the found / missed entries of the injection arrays
found = f['found_after_vetoes/injection_index'][:]
missed = f['missed/after_vetoes'][:]

# Dataset holding the IFAR of the found injections, and the matching word
# used in the colour bar label, for each --far-type choice.
ifar_dataset, far_title = {
    'inclusive': ('ifar', 'Inclusive'),
    'exclusive': ('ifar_exc', 'Exclusive'),
}[args.far_type]
ifar_found = f['found_after_vetoes/' + ifar_dataset][:]

s1z = injections['spin1z'][:]
s2z = injections['spin2z'][:]
dist = injections['distance'][:]
m1 = injections['mass1'][:]
m2 = injections['mass2'][:]

# Candidate x-axis quantities, keyed by the --axis-type choices.
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
mtotal = m1 + m2
vals = {
    'mchirp': mchirp,
    # Mass-weighted aligned spin, written out by hand here rather than
    # taken from pnutils.
    'effective_spin': (m1 * s1z + m2 * s2z) / mtotal,
    'time': time,
    'mtotal': mtotal,
    # convention: mass ratio q >= 1
    'mass_ratio': numpy.maximum(m1 / m2, m2 / m1),
}

# Candidate y-axis quantities, keyed by the --distance-type choices.  Entries
# are only added when the input file contains what is needed to compute them.
dvals = {}

# Detectors used in the search: read from the file attributes when present,
# otherwise fall back to the two-detector H1-L1 default.
if 'ifos' in f.attrs:
    ifos = f.attrs['ifos'].split(' ')
else:
    logging.warning("Ifos not found in input file, assuming H-L")
    ifos = ['H1', 'L1']

# NOTE: Effective distance is hardcoded to these values. It's also normally
#       meaningless for precessing injections.
# The dataset name is built from the lower-cased first letter of the ifo,
# e.g. 'H1' -> 'eff_dist_h'.
eff_dist_map = {ifo: 'eff_dist_%s' % ifo[0].lower() for ifo in ifos}

try:
    # One row per injection, one column per detector
    eff_dists = numpy.array([f['injections/{}'.format(eff_dist_map[ifo])][:]
                             for ifo in ifos]).T
except KeyError:
    # If the ifo isn't in the effective distance columns you can't get this.
    # But you can still use other values.
    eff_dists = None

# A second-smallest value only exists with at least two detector columns;
# indexing column 1 of a single-column array would raise IndexError and abort
# the script even when a different distance type was requested.
if eff_dists is not None and eff_dists.ndim == 2 and eff_dists.shape[1] > 1:
    # "Decisive" distance is the second smallest effective distance
    dvals['decisive_distance'] = numpy.sort(eff_dists)[:, 1]
    dvals['dec_chirp_distance'] = \
        pycbc.pnutils.chirp_distance(dvals['decisive_distance'],
                                     vals['mchirp'])
else:
    logging.info("Decisive distance not available: effective distances for "
                 "at least two detectors are required")

dvals['chirp_distance'] = pycbc.pnutils.chirp_distance(dist, vals['mchirp'])
if args.distance_type == 'redshift':
    # Only read when requested, so files without a redshift column still work
    dvals['redshift'] = f['injections/redshift'][:]

if 'snr' in args.distance_type:  # only evaluate SNRs if needed
    if 'optimal_snr_1' in f['injections']:  # old 2-ifo search behaviour
        opt_snr_1 = f['injections/optimal_snr_1'][:]
        opt_snr_2 = f['injections/optimal_snr_2'][:]
        dvals['comb_optimal_snr'] = \
            numpy.sqrt(opt_snr_1 ** 2. + opt_snr_2 ** 2.)
        # With two detectors the decisive SNR is simply the smaller one
        dvals['decisive_optimal_snr'] = \
            numpy.where(opt_snr_1 < opt_snr_2, opt_snr_1, opt_snr_2)
    else:  # New workflow code syntax
        # Squared optimal SNRs: one row per detector, one column per
        # injection.  Reducing along axis 0 replaces a Python-level loop
        # over every injection.
        opt_snrsq_arr = numpy.array(
            [f['injections/optimal_snr_%s' % ifo][:] ** 2. for ifo in ifos])
        # Combined optimal SNR is the quadrature sum over detectors
        dvals['comb_optimal_snr'] = numpy.sqrt(opt_snrsq_arr.sum(axis=0))
        # Decisive optimal SNR is the 2nd largest optimal SNR.  It does not
        # exist for a single detector; skipping it there keeps the combined
        # SNR usable instead of failing with an IndexError.
        if opt_snrsq_arr.shape[0] > 1:
            dvals['decisive_optimal_snr'] = \
                numpy.sqrt(numpy.sort(opt_snrsq_arr, axis=0)[-2])

if args.distance_type not in dvals:
    raise ValueError("Distance type '%s' cannot be computed from %s for "
                     "detectors %s" % (args.distance_type,
                                       args.injection_file, ' '.join(ifos)))

# y-axis values of the found and missed injections
fdvals = dvals[args.distance_type][found]
mdvals = dvals[args.distance_type][missed]

# Axis labels for every quantity selectable with --axis-type or
# --distance-type.
labels = dict(
    mchirp='Chirp Mass',
    mtotal='Total Mass',
    mass_ratio='Mass Ratio',
    decisive_distance='Injected Decisive Distance (Mpc)',
    dec_chirp_distance='Injected Decisive Chirp Distance (Mpc)',
    chirp_distance='Injected Chirp Distance (Mpc)',
    comb_optimal_snr='Combined Optimal SNR',
    decisive_optimal_snr='Decisive Optimal SNR',
    redshift='Redshift',
    time='Time (s)',
    effective_spin='Weighted Aligned Spin',
)

# The population drawn on top is named first in the title.
fig_title = ('Missed and Found Injections' if args.missed_on_top
             else 'Found and Missed Injections')

fig = plot.figure()
# Drawing order: the population that should appear on top gets zorder 1
# (True), the other one zorder 0 (False).
zmissed = args.missed_on_top
zfound = not args.missed_on_top

mpoints = plot.scatter(vals[args.axis_type][missed], mdvals, s=16,
                       linewidth=0.5, marker='x', color='red',
                       label='missed', zorder=zmissed)

fvals = vals[args.axis_type][found]

# Found injections are drawn in order of increasing IFAR, so that the most
# significant ones end up on top; the order is reversed with --missed-on-top
# so that high FAR lands on top of low FAR instead.
ifsort = numpy.argsort(ifar_found)
if args.missed_on_top:
    ifsort = ifsort[::-1]

fvals = fvals[ifsort]
fdval = fdvals[ifsort]
ifsorted = ifar_found[ifsort]

if not args.gradient_far:
    # Three discrete colour levels: 1 for IFAR <= 100, 0.5 for
    # 100 < IFAR <= 1000 and 0 for IFAR > 1000.
    color = numpy.ones(len(found))
    color[ifsorted > 100] = 0.5
    color[ifsorted > 1000] = 0

    # Pin the normalisation to the full [0, 1] range.  An autoscaled
    # Normalize() stretches to whichever levels happen to be present, so a
    # plot with e.g. no IFAR > 1000 injections would be drawn with the
    # wrong colours relative to the caption.
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
    # The colour names below describe the default 'cividis_r' colormap.
    caption = (fig_title + ": Red x's are missed injections. "
              "Blue circles are found with IFAR < 100 years, gray are < "
              "1000 years, and yellow are found with IFAR >=1000 years. ")
else:
    # Continuous colouring by false alarm rate on a logarithmic scale
    color = 1.0 / ifsorted
    if len(color) < 2:
        # Fewer than two found injections: no colour array is passed
        color = None

    norm = matplotlib.colors.LogNorm()
    caption = (fig_title + ": Red x's are missed injections. "
               "Circles are found injections. The color indicates the value "
               "of the false alarm rate." )

points = plot.scatter(fvals, fdval, c=color, linewidth=0, s=16, norm=norm,
                      marker='o', label='found', zorder=zfound,
                      cmap=args.colormap)
if args.gradient_far:
    # Add a colour bar for the false alarm rate, with ticks at whole decades.
    try:
        c = plot.colorbar()
        c.set_label('False Alarm Rate $(yr^{-1})$, %s' % far_title)

        # Whole decades lying inside the range of plotted FAR values.
        # numpy.log10 raises TypeError when color is None.
        log_far = numpy.log10(color)
        min_tick = numpy.ceil(log_far.min())
        max_tick = numpy.floor(log_far.max())
        # Aim for roughly 5 ticks, but never let the step drop to zero: a
        # FAR range of fewer than 5 decades would otherwise make
        # numpy.arange fail.
        tick_step = max(1, numpy.floor((max_tick - min_tick) / 5))
        ticks = numpy.arange(min_tick, max_tick, tick_step)
        # With no whole decade available keep matplotlib's default ticks
        if len(ticks):
            c.set_ticks(numpy.power(10, ticks))
    except (TypeError, ZeroDivisionError):
        # Can't set up the colorbar ticks without a colour array, which is
        # the case when there are fewer than two found injections.
        # Anything else is a genuine error.
        if color is not None:
            raise

# Finish the caption with a note on the drawing order.
caption += ("Missed injections are shown on top of found injections."
            if args.missed_on_top else
            "Found injections are shown on top of missed injections.")

ax = plot.gca()
plot.xlabel(labels[args.axis_type])
plot.ylabel(labels[args.distance_type])
plot.grid()

if args.log_x:
    # log x axis may fail for some choices, eg effective spin
    ax.set_xscale('log')
    # Pad the x range around all plotted points (missed and found)
    tmpxvals = list(vals[args.axis_type][missed])
    tmpxvals.extend(vals[args.axis_type][found])
    plot.xlim(0.7 * min(tmpxvals), 1.4 * max(tmpxvals))

if args.log_distance:
    ax.set_yscale('log')

# Padded y range covering all plotted points
tmpyvals = list(fdvals)
tmpyvals.extend(mdvals)
ymax = 1.2 * max(tmpyvals)
ymin = 0.9 * min(tmpyvals)

if args.plot_all_distance:
    # note: ymin=0 will clash with args.log_distance
    # in that case it *should* throw an error!
    ylower = ymin
elif args.distance_type == 'redshift':
    # default y limit: min redshift 0
    ylower = 0
else:
    # arbitrary limit of 1 distance-unit or snr-unit
    ylower = 1
plot.ylim(ylower, ymax)

# Extra keyword arguments passed on with the figure; PNG output gets a
# higher resolution.
fig_kwds = {'dpi': 200} if '.png' in args.output_file else {}

if '.html' in args.output_file:
    # Interactive HTML output via mpld3: shrink the axes, then attach a
    # mouse-position readout and a clickable legend for the two populations.
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    import mpld3, mpld3.plugins, mpld3.utils
    mouse_position = mpld3.plugins.MousePosition(fmt='.5g')
    legend = mpld3.plugins.InteractiveLegendPlugin(
        [mpoints, points], ['missed', 'found'], alpha_unsel=0.1)
    mpld3.plugins.connect(fig, mouse_position, legend)

title = '{}: {} vs {}'.format(fig_title, args.axis_type, args.distance_type)
# Record the full command line with the saved figure
cmd = ' '.join(sys.argv)
pycbc.results.save_fig_with_metadata(fig, args.output_file, title=title,
                                     caption=caption, cmd=cmd,
                                     fig_kwds=fig_kwds)
