#!/usr/bin/python
# Copyright (C) 2015 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import h5py
import numpy
import sys
import lal
import matplotlib as mpl; mpl.use('Agg')
import pylab
import pycbc.results
import pycbc.version
import copy
from pycbc.events import veto
from ligo import segments

def calculate_time_slide_duration(pifo_segments, fifo_segments, offset=0):
    """Return the coincident time between two segmentlists for one slide.

    ``fifo_segments`` is shifted by ``offset`` for the comparison and is
    shifted back by the same amount before returning. ``coalesce()`` is
    called on both inputs.

    Parameters
    ----------
    pifo_segments : segmentlist
        Segments that are left unshifted.
    fifo_segments : segmentlist
        Segments that are shifted by ``offset`` before intersecting.
    offset : number, optional
        Amount added to ``fifo_segments`` for this slide. Default 0.

    Returns
    -------
    The ``abs()`` of the coalesced intersection of the two lists.
    """
    unshifted = pifo_segments.coalesce()
    slid = fifo_segments.coalesce().shift(offset)

    # total time covered by both lists at this offset
    coincident = unshifted & slid
    duration = abs(coincident.coalesce())

    # put the shifted segmentlist back where it started
    fifo_segments.shift(-offset)

    return duration


# Parse the command line.
# NOTE: adjacent string literals are joined with no separator, so every
# fragment of a multi-line description/help string must end in a space.
parser = argparse.ArgumentParser(usage='pycbc_page_ifar [--options]',
                    description='Plots a cumulative histogram of IFAR for '
                          'coincident foreground triggers and a subset of '
                          'the coincident time slide triggers.')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-file', type=str, required=True,
                    help='Path to coincident trigger HDF file.')
parser.add_argument('--output-file', type=str, required=True,
                    help='Path to output plot.')
parser.add_argument('--decimation-factor', type=int, required=True,
                    help='Decimation factor used in background estimation. '
                          'Decimation factor means that every nth time '
                          'slide kept all of its coincident triggers in '
                          'the HDF file.')
parser.add_argument('--all-decimated-background-min-ifar', type=float,
                    default=0.01, metavar='VAL',
                    help='Threshold the IFARs of decimated background events '
                          'before combining them and rescaling to zerolag '
                          '(green line). Plotting all decimated background '
                          'can be memory intensive and is not necessary for '
                          'checking the background estimate at interesting '
                          'IFAR levels. Slide coincs with IFARs below VAL '
                          'will be cut, the green line will then be valid for '
                          'zerolag down to VAL*(zerolag coinc time/total '
                          'decimated slide time). Recommended value equal '
                          'to default value 0.01yr.')
parser.add_argument('--use-hierarchical-level', type=int, default=None,
                    help='Indicate which inclusive background and FARs of '
                         'foreground triggers to plot if there were any '
                         'hierarchical removals done. Choosing None plots '
                         'the inclusive backgrounds prior to any '
                         'hierarchical removals with the updated FARs for '
                         'foreground triggers after hierarchical removal(s). '
                         'Choosing 0 means plotting inclusive background '
                         'from prior to any hierarchical removals with FARs '
                         'for foreground triggers prior to hierarchical '
                         'removal. Choosing 1 means plotting the inclusive '
                         'background after doing 1 hierarchical removal, and '
                         'includes updated FARs from after 1 hierarchical '
                         'removal. [default=None]')
parser.add_argument('--open-box', action='store_true', default=False,
                    help='Show the foreground triggers on the output plot. ')
opts = parser.parse_args()

# open the coincident trigger file read-only
fp = h5py.File(opts.trigger_file, 'r')

# a file without an 'ifos' attribute is handled as a legacy statmap file
legacy_smap_file = 'ifos' not in fp.attrs

# number of hierarchical removals recorded in the file; a file without the
# attribute is treated as having had none
try:
    h_iterations = fp.attrs['hierarchical_removal_iterations']
except KeyError:
    h_iterations = 0

# hierarchical level to plot; when not given on the command line, use the
# number of removals recorded in the file
h_inc_back_num = opts.use_hierarchical_level
if h_inc_back_num is None:
    h_inc_back_num = h_iterations

if h_inc_back_num > h_iterations:
    # The requested level was never produced: write a placeholder figure
    # that only carries an explanatory message.
    null_message = ("No more foreground events louder than all background\n"
                    "at this removal level.\n"
                    "Attempted to show %d removal(s),\n"
                    "but only %d removal(s) done."
                    % (h_inc_back_num, h_iterations))

    null_fig = pylab.figure()
    null_ax = null_fig.add_subplot(111)
    null_ax.set_xlim(0, 1)
    null_ax.set_ylim(0, 1)
    null_ax.text(0.5, 0.5, null_message, verticalalignment='center',
                 horizontalalignment='center')

    pycbc.results.save_fig_with_metadata(null_fig, opts.output_file,
                                         title='Cumulative Number vs. IFAR',
                                         caption=null_message,
                                         cmd=' '.join(sys.argv))

    # Successful exit; none of the plotting code below is run.
    sys.exit(0)

# Foreground IFARs and their cumulative counts are only read when the box is
# open.
if opts.open_box:

    use_level = (h_inc_back_num >= 0 and h_iterations is not None
                 and h_iterations != 0)

    if not use_level:
        # no hierarchical removal recorded: use the top-level foreground
        fore_ifar = fp['foreground/ifar'][:]
    else:
        # foreground IFARs as computed at the requested removal level
        fore_ifar = fp['foreground_h%s/ifar' % h_inc_back_num][:]

        # The h_inc_back_num largest IFARs in the top-level foreground are
        # kept separately so the removed triggers can be drawn with their
        # own marker and cumulative count.
        all_fore_ifar = fp['foreground/ifar'][:]
        first_removed = len(all_fore_ifar) - h_inc_back_num
        h_rm_ifar = sorted(all_fore_ifar)[first_removed:]
        h_rm_cumnum = numpy.arange(len(h_rm_ifar), 0, -1)

    # ascending IFAR paired with a descending count: N, N-1, ..., 1
    fore_ifar.sort()
    fore_cumnum = numpy.arange(len(fore_ifar), 0, -1)

# Expected cumulative number of background events at each IFAR value:
# foreground time divided by IFAR, with lal.YRJUL_SI converting between
# the seconds of the foreground time and the years of the IFAR axis.
expected_ifar = numpy.logspace(-8, 2, num=100, endpoint=True, base=10.0)
expected_cumnum = fp.attrs['foreground_time'] / expected_ifar / lal.YRJUL_SI

# Pick the HDF group holding the background timeslide IDs and IFAR values.
if opts.open_box:
    # Open box: use the inclusive background of the requested hierarchical
    # level when removals were recorded in the file. Otherwise (no removals
    # done, or a file without the attribute) fall back to the top-level
    # group. A request for a level beyond what was done has already
    # produced a placeholder plot and exited earlier in the script.
    level_available = (h_inc_back_num >= 0 and h_iterations is not None
                       and h_iterations != 0)
    if level_available:
        bkg_group = 'background_h%s' % h_inc_back_num
    else:
        bkg_group = 'background'
else:
    # Closed box: always use the background from before any hierarchical
    # removal, so the result does not depend on judgments about whether a
    # removal was done. The None test comes first so it actually guards
    # the comparison.
    if h_iterations is not None and h_iterations > 0:
        bkg_group = 'background_h0'
    else:
        bkg_group = 'background'

back_tsid = fp[bkg_group + '/timeslide_id'][:]
back_ifar = fp[bkg_group + '/ifar'][:]

# make figure
fig = pylab.figure(1)

# Load the analysed segments and work out the range of time slide IDs to
# loop over. Slide IDs are in units of the slide interval stored in the file.
interval = fp.attrs['timeslide_interval']
if legacy_smap_file:
    # Legacy layout: one segment group per detector, named by the
    # 'detector_1' and 'detector_2' attributes.
    pifo, fifo = fp.attrs['detector_1'], fp.attrs['detector_2']
    p_starts = fp['segments'][pifo]['start'][:]
    p_ends = fp['segments'][pifo]['end'][:]
    pifo_segments = veto.start_end_to_segments(p_starts, p_ends)
    # The second detector's segments come from its own group. Reading the
    # first detector's group again would make both lists identical and the
    # slide durations wrong.
    f_starts = fp['segments'][fifo]['start'][:]
    f_ends = fp['segments'][fifo]['end'][:]
    fifo_segments = veto.start_end_to_segments(f_starts, f_ends)
    # range of offsets between the extremes of the two segment lists
    min_tsid = (p_starts.min() - f_ends.max()) / interval
    max_tsid = (p_ends.max() - f_starts.min()) / interval
else:
    # Current layout: a single segment group keyed by the 'ifos' attribute
    # with its spaces removed. An independent copy of the list is slid
    # against the original.
    pifo, fifo = fp.attrs['pivot'], fp.attrs['fixed']
    ifo_joined = fp.attrs['ifos'].replace(' ','')
    p_starts = fp['segments'][ifo_joined]['start'][:]
    p_ends = fp['segments'][ifo_joined]['end'][:]
    pifo_segments = veto.start_end_to_segments(p_starts, p_ends)
    fifo_segments = copy.deepcopy(pifo_segments)
    min_tsid = (p_starts.min() - p_ends.max()) / interval
    max_tsid = -min_tsid

# Convert the slide ID range to units of the decimation factor, then build
# the slide IDs as integer multiples of that factor.
min_tsid = numpy.rint(min_tsid) // opts.decimation_factor
max_tsid = numpy.rint(max_tsid) // opts.decimation_factor

tsids = numpy.arange(min_tsid, max_tsid, 1).astype(numpy.int64) * \
                     opts.decimation_factor

# running totals over all decimated time slides
allbkg_dur = 0
allbkg_ifars = []
empty_slide_count = 0
plotted_slide_count = 0

# read once; every slide is rescaled by this value
fore_time = fp.attrs['foreground_time']

for slide_id in tsids:
    # an ID of zero corresponds to a zero offset, i.e. no slide at all
    if slide_id == 0:
        continue

    # coincident time of this slide; slides with no overlap are skipped
    slide_dur = calculate_time_slide_duration(pifo_segments, fifo_segments,
                                              offset=slide_id * interval)
    if slide_dur == 0:
        continue

    plotted_slide_count += 1
    allbkg_dur = allbkg_dur + slide_dur

    # IFARs of the triggers belonging to this time slide
    slide_ifar = back_ifar[back_tsid == slide_id]

    # only triggers above the IFAR cut go into the combined background
    above_cut = slide_ifar > opts.all_decimated_background_min_ifar
    allbkg_ifars.extend(slide_ifar[above_cut])

    # The analysed time of a slide differs from that of the foreground, so
    # rescale its IFARs by the ratio of the two before plotting.
    rescaled_ifar = slide_ifar * (fore_time / slide_dur)
    rescaled_ifar.sort()
    n_events = len(rescaled_ifar)
    rescaled_cumnum = numpy.arange(n_events, 0, -1)

    # a line for several events, a dot for one, nothing for none
    if n_events == 0:
        empty_slide_count += 1
    elif n_events == 1:
        pylab.plot(rescaled_ifar, rescaled_cumnum, color='gray', marker='.',
                   alpha=0.4)
    else:
        pylab.loglog(rescaled_ifar, rescaled_cumnum, color='gray', alpha=0.4)

# combine all slides and rescale them to the foreground time (green line)
allbkg_ifars = numpy.array(allbkg_ifars) * (fore_time / allbkg_dur)
allbkg_ifars.sort()
allbkg_cumnum = numpy.arange(len(allbkg_ifars), 0, -1)
pylab.loglog(allbkg_ifars, allbkg_cumnum, linewidth=1.5, color='green',
             label="All decimated background")

# dashed line: the expected cumulative background
pylab.loglog(expected_ifar, expected_cumnum, color='black', linestyle='--',
             linewidth=2, label='Expected Background')

# shaded bands: one and two times sqrt(N) either side of the expectation
root_n = numpy.sqrt(expected_cumnum)
for n_root, band_alpha, band_label in ((1, 0.4, '$N^{1/2}$ Errors'),
                                       (2, 0.2, '$2N^{1/2}$ Errors')):
    error_plus = expected_cumnum + n_root * root_n
    error_minus = expected_cumnum - n_root * root_n
    # replace non-positive lower edges with a small positive value
    # (presumably so the band can be drawn on the logarithmic axis)
    error_minus = numpy.where(error_minus <= 0, 1e-5, error_minus)
    pylab.fill_between(expected_ifar, error_minus, error_plus,
                       facecolor='y', alpha=band_alpha, label=band_label)

# overlay the foreground events; they are only drawn when the box is open
if opts.open_box:
    pylab.loglog(fore_ifar, fore_cumnum, marker='^', color='blue',
                 linestyle='None', label='Foreground')

    # triggers taken out by hierarchical removal get their own marker
    if h_inc_back_num > 0:
        pylab.loglog(h_rm_ifar, h_rm_cumnum, marker='v', color='#b66dff',
                     linestyle='None',
                     label='Hierarchically Removed Foreground')

# Axis limits: with more than 100 foreground triggers frame the plot around
# the foreground; otherwise (or for a closed box) frame it around the
# combined background (the green line), provided it is not empty.
zoom_on_foreground = opts.open_box and len(fore_cumnum) > 100
if zoom_on_foreground:
    pylab.ylim(0.8, 1.1 * len(fore_cumnum))
    pylab.xlim(0.9 * min(fore_ifar))
elif len(allbkg_cumnum) > 0:
    pylab.ylim(0.8, 1.1 * len(allbkg_cumnum))
    pylab.xlim(0.9 * min(allbkg_ifars))

# labels, grid and legend
pylab.xlabel('Inverse False Alarm Rate (yr)')
pylab.ylabel('Cumulative Number')
pylab.grid()
pylab.legend(loc='upper right', fontsize=9)

# Build the caption one sentence at a time, then fill in the slide counts
# gathered while plotting.
caption_sentences = (
    'This is a cumulative histogram of triggers.',
    'The blue triangles represent coincident foreground triggers.',
    'The dashed line represents the expected background given the '
    'analysis time.',
    'The shaded regions represent counting errors.',
    'The gray lines are time slides treated as zero lag, here there are '
    '%d time slides plotted.',
    'Gray dots are time slides with only one event.',
    '%d of the plotted slides have zero events.',
    'The green line represents all decimated time slides rescaled to the '
    'analysis time.',
)
caption = ' '.join(caption_sentences) % (plotted_slide_count,
                                         empty_slide_count)

# write the figure together with its title, caption and command line
pycbc.results.save_fig_with_metadata(fig, opts.output_file,
                                     title='Cumulative Number vs. IFAR',
                                     caption=caption,
                                     cmd=' '.join(sys.argv))
