#!/usr/bin/python
""" Make SNR vs rate of triggers histogram for foreground coincident events.
    Also has the ability to plot inclusive backgrounds from different stages
    of hierarchical removal.
"""
import argparse, h5py, numpy, logging, sys, lal
import matplotlib
matplotlib.use('Agg')
import pylab, pycbc.results, pycbc.version
from scipy.special import erf, erfinv

def sigma_from_p(p):
    """Convert a one-sided Gaussian tail probability into a number of sigmas.

    Parameters
    ----------
    p : float or array-like
        Upper-tail probability of a unit normal distribution.

    Returns
    -------
    float or numpy.ndarray
        The value ``s`` such that the unit normal has probability ``p`` of
        exceeding ``s``; computed as ``-erfinv(1 - 2 * (1 - p)) * sqrt(2)``.
    """
    complement = 1 - p
    erf_argument = 1 - complement * 2
    return -erfinv(erf_argument) * 2**0.5

def p_from_sigma(sig):
    """Return the unit-normal cumulative probability at ``sig`` sigmas.

    Parameters
    ----------
    sig : float or array-like
        Number of standard deviations.

    Returns
    -------
    float or numpy.ndarray
        ``1 - (1 - erf(sig / sqrt(2))) / 2``, i.e. one minus the one-sided
        tail probability above ``sig``.
    """
    two_sided_tail = 1 - erf(sig / 2**0.5)
    return 1 - two_sided_tail / 2


# Command-line interface. The input and output files are enforced as required
# so that a missing option is reported by argparse instead of failing later
# when the file is opened or the figure is saved.
parser = argparse.ArgumentParser(description=__doc__)
# General required options
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-file', required=True,
                    help='Input HDF5 file holding the foreground and '
                         'background triggers to plot.')
parser.add_argument('--verbose', action='count',
                    help='Print logging messages while running.')
parser.add_argument('--output-file', required=True,
                    help='Path of the plot to write.')
parser.add_argument('--bin-size', type=float,
                    help='Width of the histogram bins in ranking statistic. '
                         'If not given, one hundredth of the plotted '
                         'statistic range is used.')
parser.add_argument('--x-min', type=float,
                    help='Lower limit of the x axis. If not given, the floor '
                         'of the peak of the exclusive background histogram '
                         'is used.')
parser.add_argument('--trials-factor', type=int, default=1,
                    help='Trials factor applied when converting false alarm '
                         'probabilities to significance bands. [default=1]')
parser.add_argument('--use-hierarchical-level', type=int, default=None,
                    help='Indicate which inclusive background to plot '
                         'with the foreground triggers if there were '
                         'any hierarchical removals done. Choosing 0 '
                         'indicates plotting the inclusive background prior '
                         'to any hierarchical removals. Choosing 1 indicates '
                         'the inclusive background after the loudest '
                         'foreground trigger was removed. Choosing another '
                         'integer gives the inclusive background after that '
                         'integer number of hierarchical removals. If not '
                         'provided, default to the background from after all '
                         'hierarchical removals performed. [default=None]')
parser.add_argument('--closed-box', action='store_true',
                    help='Make a closed box version that excludes '
                         'foreground triggers')
args = parser.parse_args()

# Configure logging BEFORE the first log call. A module-level logging.info()
# on an unconfigured root logger installs a default handler, after which a
# later logging.basicConfig() is silently ignored and --verbose would have
# no effect.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

logging.info('Read in the data')
f = h5py.File(args.trigger_file, 'r')

# Determine which inclusive background to plot.
h_inc_back_num = args.use_hierarchical_level

# Files without the attribute are treated as having no hierarchical removals.
try:
    h_iterations = f.attrs['hierarchical_removal_iterations']
except KeyError:
    h_iterations = 0

# Default to the background after all removals that were performed.
if h_inc_back_num is None:
    h_inc_back_num = h_iterations

if h_inc_back_num > h_iterations:
    # Produce a null plot saying no hierarchical removals can be plotted
    fig = pylab.figure()
    ax = fig.add_subplot(111)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    output_message = "No more foreground events louder than all background\n" \
                     "at this removal level.\n" \
                     "Attempted to show " + str(h_inc_back_num) + " removal(s),\n" \
                     "but only " + str(h_iterations) + " removal(s) done."

    ax.text(0.5, 0.5, output_message, horizontalalignment='center',
            verticalalignment='center')

    if 'name' in f.attrs:
        null_title = "%s bin, Count vs Rank" % f.attrs['name']
    else:
        null_title = "Count vs Rank"

    pycbc.results.save_fig_with_metadata(fig, args.output_file,
        title=null_title,
        caption=output_message,
        cmd=' '.join(sys.argv))

    # Release the input file, then exit successfully and bypass the rest of
    # the plotting code.
    f.close()
    sys.exit(0)

# Load the sorted foreground statistic values. fstat is left as None for a
# closed-box plot, when the file has no 'foreground/stat' dataset, or when
# that dataset is empty; later code tests `fstat is not None`.
fstat = None
if not args.closed_box:
    try:
        fstat = f['foreground/stat'][:]
    except KeyError:
        # No foreground statistic stored in this file.
        fstat = None
    else:
        fstat.sort()
    # Guard the length check: fstat is None when the dataset was missing.
    if fstat is not None and len(fstat) == 0:
        fstat = None

# Pick the inclusive background group and the matching foreground time
# attribute: the un-suffixed entries for level 0, the '_h<N>' ones otherwise.
if h_inc_back_num == 0:
    inc_group = 'background'
    inc_fg_time_attr = 'foreground_time'
else:
    inc_group = 'background_h%s' % h_inc_back_num
    inc_fg_time_attr = 'foreground_time_h%s' % h_inc_back_num

bstat = f[inc_group + '/stat'][:]
# False alarm probability computed from each trigger's ifar and the
# foreground time, with lal.YRJUL_SI as the conversion factor.
fap = 1 - numpy.exp(- f.attrs[inc_fg_time_attr] / f[inc_group + '/ifar'][:] / lal.YRJUL_SI)
dec = f[inc_group + '/decimation_factor'][:]

# Order everything by increasing ranking statistic.
inc_order = bstat.argsort()
dec = dec[inc_order]
fap = fap[inc_order]
bstat = bstat[inc_order]
logging.info('Found %s background (inclusive zerolag) triggers' % len(bstat))

# Exclusive background, likewise ordered by increasing statistic.
bstat_exc = f['background_exc/stat'][:]
dec_exc = f['background_exc/decimation_factor'][:]
exc_order = bstat_exc.argsort()
bstat_exc = bstat_exc[exc_order]
dec_exc = dec_exc[exc_order]

logging.info('Found %s background (exclusive zerolag) triggers' % len(bstat_exc))

fig = pylab.figure()

# Choose which statistic arrays define the plotted range: foreground plus
# inclusive background when foreground is shown, the exclusive background
# for a closed box, and the inclusive background otherwise.
if fstat is not None:
    range_arrays = (fstat, bstat)
elif args.closed_box:
    range_arrays = (bstat_exc,)
else:
    range_arrays = (bstat,)

minimum = min(values.min() for values in range_arrays)
maximum = max(values.max() for values in range_arrays)

# Fall back to one hundredth of the range when no bin size was given.
bin_size = args.bin_size or (maximum - minimum) / 100.

bins = numpy.arange(minimum, maximum + bin_size, bin_size)

# Drawing options shared by both background histograms.
step_style = dict(bins=bins, histtype='step', linewidth=2, log=True)

# Exclusive background: each trigger is weighted by its decimation factor
# divided by the exclusive background time, scaled by lal.YRJUL_SI.
exc_binweights = dec_exc / f.attrs['background_time_exc'] * lal.YRJUL_SI
exc_binvals = pylab.hist(bstat_exc, color='grey',
                         label= 'Background Uncorrelated with Foreground',
                         weights=exc_binweights, **step_style)

# Left edge of the most populated exclusive-background bin.
exc_counts = exc_binvals[0]
histpeak = bins[exc_counts.argmax()]

# Inclusive ("full") background, omitted from closed-box plots. The
# background time attribute carries an '_h<N>' suffix for levels above 0.
if not args.closed_box:
    if h_inc_back_num == 0:
        inc_bg_time_attr = 'background_time'
    else:
        inc_bg_time_attr = 'background_time_h%s' % h_inc_back_num
    pylab.hist(bstat, color='black',
               label='Full Background',
               weights=dec / f.attrs[inc_bg_time_attr] * lal.YRJUL_SI,
               **step_style)

# Overlay the foreground as per-bin rates drawn with horizontal error bars.
if fstat is not None and not args.closed_box:
    le, re = bins[:-1], bins[1:]
    bin_centres = le + bin_size / 2
    half_width = bin_size / 2

    # We need to plot a histogram "errorbar" with the hierarchically
    # removed foreground triggers in purple.
    if h_iterations > 0:
        # Since there is only one background bin, the hierarchically removed
        # triggers are simply the h_inc_back_num loudest foreground values.
        # One sort and a split replaces repeatedly deleting the maximum,
        # which copied the array on every pass; the split index is clamped
        # so that asking for more removals than there are triggers cannot
        # fail.
        ordered = numpy.sort(fstat)
        split = max(len(ordered) - h_inc_back_num, 0)
        fstat = ordered[:split]
        fstat_h_rm = ordered[split:].astype(float)

        # Write "histogram" information: both arrays are sorted, so the
        # number of values in each bin is a difference of insertion points.
        left_h_rm = numpy.searchsorted(fstat_h_rm, le)
        right_h_rm = numpy.searchsorted(fstat_h_rm, re)

        # Use the top level foreground time for the background before
        # h-removal, otherwise the foreground time after h-removal.
        if h_inc_back_num == 0:
            rm_fg_time = f.attrs['foreground_time']
        else:
            rm_fg_time = f.attrs['foreground_time_h%s' % h_inc_back_num]
        count_h_rm = (right_h_rm - left_h_rm) / rm_fg_time * lal.YRJUL_SI

        pylab.errorbar(bin_centres, count_h_rm,
                       xerr=half_width,
                       label='Hierarchically Removed Foreground', mec='none',
                       fmt='s', ms=1, capthick=0, elinewidth=4,
                       color='#b66dff')

    left = numpy.searchsorted(fstat, le)
    right = numpy.searchsorted(fstat, re)
    # NOTE(review): the remaining foreground is always normalised by the top
    # level 'foreground_time', even when h_inc_back_num > 0 -- confirm that
    # this is intended.
    count = (right - left) / f.attrs['foreground_time'] * lal.YRJUL_SI
    pylab.errorbar(bin_centres, count, xerr=half_width,
                   label='Foreground', mec='none', fmt='o', ms=1, capthick=0,
                   elinewidth=4,  color='#ff6600')

# Axis labels.
x_label = 'Coincident ranking statistic (bin size = %.2f)' % bin_size
pylab.xlabel(x_label)
pylab.ylabel('Trigger Rate (yr$^{-1})$')

# Start the x axis at the requested value, or at the floor of the peak of
# the exclusive background histogram when none was given.
x_lower = numpy.floor(histpeak) if args.x_min is None else args.x_min
pylab.xlim(xmin=x_lower)

# Lower y limit: half a count over the exclusive background time.
rate_floor = 0.5 / f.attrs['background_time_exc'] * lal.YRJUL_SI
pylab.ylim(ymin=rate_floor)

pylab.grid()
leg = pylab.legend(fontsize=9)

# Significance of the smallest background false alarm probability after
# applying the trials factor; this closes the last shaded band.
end = sigma_from_p(fap.min() * args.trials_factor)

# Shade bands between consecutive significance levels (open box only).
if not args.closed_box:
    sigmas = [1, 2, 3, 4, end]

    # bstat is sorted ascending; the reversed views are searched with
    # searchsorted, which assumes fap falls as the statistic rises so that
    # reversed fap is ascending -- TODO confirm.
    fap_rev = fap[::-1]
    bstat_rev = bstat[::-1]

    for sig, next_sig in zip(sigmas[:-1], sigmas[1:]):
        # False alarm probability whose trials-corrected significance is
        # `sig`. The tail probability (1 - cumulative) is what gets divided
        # by the trials factor, mirroring how `end` is computed above; the
        # parentheses matter because `/` binds tighter than `-`.
        p1 = (1 - p_from_sigma(sig)) / args.trials_factor

        # Same for the next sigma curve
        p2 = (1 - p_from_sigma(next_sig)) / args.trials_factor

        # Search the fap data for these p values
        x1 = numpy.searchsorted(fap_rev, p1)
        x2 = numpy.searchsorted(fap_rev, p2)

        # Both levels land on the same trigger: nothing to shade.
        if x1 == x2:
            continue

        ymin, ymax = pylab.gca().get_ylim()
        try:
            x = [bstat_rev[x1], bstat_rev[x2]]
        except IndexError:
            # The level lies beyond the available background triggers.
            break
        pylab.fill_between(x, ymin, ymax, zorder=-1,
                           color=pylab.cm.Blues(next_sig / 8.0))

        # The final band is labelled with its fractional significance.
        if next_sig == end:
            next_sig = '%.1f' % next_sig

        pylab.text(bstat_rev[x2] - .1, ymax, r"$%s \sigma$" % next_sig,
                   fontsize=10, horizontalalignment='center',
                   verticalalignment='bottom')

# Add a right-hand axis showing the same data as a number per experiment:
# the rate limits of the left axis scaled by foreground time / lal.YRJUL_SI.
ax1 = pylab.gca()
ax2 = ax1.twinx()

# The foreground time attribute carries an '_h<N>' suffix for levels above 0.
if h_inc_back_num == 0:
    twin_fg_time_attr = 'foreground_time'
else:
    twin_fg_time_attr = 'foreground_time_h%s' % h_inc_back_num
fac = f.attrs[twin_fg_time_attr] / lal.YRJUL_SI

rate_low, rate_high = ax1.get_ylim()
ymin = rate_low * fac
ymax = rate_high * fac

ax2.set_ylim(ymin=ymin, ymax=ymax)
ax2.set_yscale('log')
ax2.set_ylabel('Number per Experiment')

# Build the title and caption: start from the generic text and specialise
# it when the file names a bin ('name') or, failing that, its detectors
# ('ifos').
title = "Count vs Rank"
caption = "Histogram of the FAR vs the ranking statistic in the search."

if 'name' in f.attrs:
    title = "%s bin, Count vs Rank" % f.attrs['name']
    caption = "Histogram of FAR vs the ranking statistic in the search."
elif 'ifos' in f.attrs:
    title = "%s coincidences, Count vs Rank" % f.attrs['ifos']
    caption = ("Histogram of the FAR vs the ranking statistic "
               "for %s coincidences only" % f.attrs['ifos'])

# Save the figure together with the command line that produced it.
command_line = ' '.join(sys.argv)
pycbc.results.save_fig_with_metadata(fig, args.output_file,
                                     title=title, caption=caption,
                                     cmd=command_line)
