#!/usr/bin/env python
"""
The program combines coincident output files generated
by pycbc_multiifo_coinc_findtrigs to generate a mapping between SNR and FAP, along
with producing the combined foreground and background triggers. It also has
the capability of doing hierarchical removal of foreground triggers that are
louder than all of the background triggers. We use this to properly assess
the FANs of any other gravitational waves in the dataset.
"""
import argparse, h5py, itertools
import lal, logging, numpy
from pycbc.events import veto, coinc
import pycbc.version, pycbc.pnutils, pycbc.io
import sys
import pycbc.conversions as conv

class fw(object):
    """Thin write-mode wrapper around an HDF5 file.

    Assigning to ``obj[name]`` creates a gzip-compressed, shuffled dataset the
    first time ``name`` is used and overwrites the existing dataset's
    contents on later assignments. Reading ``obj[name]`` is forwarded to the
    underlying file object. File attributes are exposed as ``obj.attrs``.
    """

    def __init__(self, name):
        # The file is opened in 'w' mode; keep a shortcut to its attributes.
        self.f = h5py.File(name, 'w')
        self.attrs = self.f.attrs

    def __setitem__(self, name, data):
        if name in self.f:
            # Dataset already present: replace its values in place.
            self.f[name][:] = data
            return

        # First write under this name: create the dataset, capped at the
        # shape of the data it is created with.
        self.f.create_dataset(name, data=data, compression="gzip",
                              compression_opts=9, shuffle=True,
                              maxshape=data.shape)

    def __getitem__(self, *args):
        # Delegate lookups straight to the underlying file object.
        return self.f.__getitem__(*args)

# Command-line interface. The module docstring doubles as the --help
# description.
parser = argparse.ArgumentParser(description=__doc__)
# General required options
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
# The input files and the output file are used unconditionally further down,
# so let argparse reject a missing value up front with a clear message.
parser.add_argument('--coinc-files', nargs='+', required=True,
                    help='List of coincidence files used to calculate the '
                         'FAP, FAR, etc.')
parser.add_argument('--ifos', nargs='+',
                    help='List of ifos used in these coincidence files')
parser.add_argument('--verbose', action='count',
                    help='Verbosity counter passed to pycbc.init_logging; '
                         'may be given multiple times')
parser.add_argument('--cluster-window', type=float, default=10,
                    help='Length of time window in seconds to cluster coinc '
                         'events [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
                    help='Time around each zerolag trigger to window out '
                         '[default=.1s]')
parser.add_argument('--hierarchical-removal-window', type=float, default=1.0,
                    help='Time around each trigger to window out for a very '
                         'louder trigger in the hierarchical removal '
                         'procedure [default=1.0s]')
parser.add_argument('--max-hierarchical-removal', type=int, default=0,
                    help='Maximum number of hierarchical removals to carry '
                         'out. Choose -1 for unlimited hierarchical removal '
                         'until no foreground triggers are louder than the '
                         '(inclusive/exclusive) background. Choose 0 to do '
                         'no hierarchical removals, choose 1 to do at most '
                         '1 hierarchical removal and so on. If given, must '
                         'also provide --hierarchical-removal-against to '
                         'indicate which background to remove triggers '
                         'against. [default=0]')
parser.add_argument('--hierarchical-removal-against', type=str,
                    default='none', choices=['none', 'inclusive', 'exclusive'],
                    help='If doing hierarchical removal, remove foreground '
                         'triggers that are louder than either the "inclusive"'
                         ' (little-dogs-in) background, or the "exclusive" '
                         '(little-dogs-out) background. [default="none"]')
parser.add_argument('--output-file', required=True,
                    help='Path of the HDF5 file the results are written to')
args = parser.parse_args()

# Validate the hierarchical removal options: a background ('inclusive' or
# 'exclusive') must be chosen exactly when removals are enabled.
# The two flags are defined unconditionally so that later code can rely on
# them whatever --max-hierarchical-removal is set to.
is_bkg_inc = (args.hierarchical_removal_against == 'inclusive')
is_bkg_exc = (args.hierarchical_removal_against == 'exclusive')

if args.max_hierarchical_removal == 0:
    # Compare by value, not identity: a 'none' typed on the command line is
    # not guaranteed to be the same object as the literal 'none'.
    if args.hierarchical_removal_against != 'none':
        parser.error("User Error: 0 maximum hierarchical removals chosen but "
                     "option for --hierarchical-removal-against was given. "
                     "These are conflicting options. Use with --help for more "
                     "information.")
elif not (is_bkg_inc or is_bkg_exc):
    parser.error("--max-hierarchical-removal requires a choice of which "
                 "background to remove foreground triggers against, "
                 "inclusive or exclusive. Use with --help for more "
                 "information.")

pycbc.init_logging(args.verbose)

logging.info("Loading coinc triggers")
logging.info("IFO input: %s", args.ifos)
all_trigs = pycbc.io.MultiifoStatmapData(files=args.coinc_files, ifos=args.ifos)

# Use the detector list stored with the input data when there is one,
# otherwise fall back to the list given on the command line.
if 'ifos' not in all_trigs.attrs:
    ifos = args.ifos
    logging.info('using ifos from command line input')
else:
    ifos = all_trigs.attrs['ifos'].split(' ')
    logging.info('using ifos from file {}'.format(args.coinc_files[0]))

logging.info("We have %s triggers", len(all_trigs.stat))
# Mask of the triggers that have not been time-slid (the foreground)
fore_locs = all_trigs.timeslide_id == 0

# Foreground trigger times, keyed by ifo
fore_time = {ifo: all_trigs.data['%s/time' % ifo][fore_locs] for ifo in ifos}

# One average time per foreground coincidence. Coincs where not all ifos have
# triggers contain -1 sentinel values, hence mean_if_greater_than_zero.
ave_fore_time = numpy.array([coinc.mean_if_greater_than_zero(times)[0]
                             for times in zip(*fore_time.values())])

# Start and end of the window removed around every average foreground time
remove_start_time = ave_fore_time - args.veto_window
remove_end_time = ave_fore_time + args.veto_window

# Total amount of time removed around foreground triggers
veto_segs = veto.start_end_to_segments(remove_start_time, remove_end_time)
veto_time = abs(veto_segs.coalesce())

# Build the exclusive background: starting from a copy of all triggers, drop
# every trigger whose time in any ifo falls inside one of the windows.
exc_zero_trigs = all_trigs.remove([])
for ifo in ifos:
    ifo_times = exc_zero_trigs.data['%s/time' % ifo]
    fg_veto_ids = veto.indices_within_times(ifo_times, remove_start_time,
                                            remove_end_time)
    exc_zero_trigs = exc_zero_trigs.remove(fg_veto_ids)

logging.info("Clustering coinc triggers (inclusive of zerolag)")
all_trigs = all_trigs.cluster(args.cluster_window)

# The trigger set changed, so rebuild the mask of non-time-slid triggers
fore_locs = all_trigs.timeslide_id == 0
logging.info("%s clustered foreground triggers", fore_locs.sum())

logging.info("Clustering coinc triggers (exclusive of zerolag)")
exc_zero_trigs = exc_zero_trigs.cluster(args.cluster_window)

logging.info("Dumping foreground triggers")
f = fw(args.output_file)

# Carry the input attributes over to the output file
for attr_name in ('num_of_ifos', 'pivot', 'fixed'):
    f.attrs[attr_name] = all_trigs.attrs[attr_name]
if 'ifos' not in all_trigs.attrs:
    f.attrs['ifos'] = ' '.join(sorted(args.ifos))
else:
    f.attrs['ifos'] = all_trigs.attrs['ifos']
f.attrs['timeslide_interval'] = all_trigs.attrs['timeslide_interval']

# Copy over the segment for coincs and singles
for key in all_trigs.seg.keys():
    seg = all_trigs.seg[key]
    f['segments/%s/start' % key] = seg['start'][:]
    f['segments/%s/end' % key] = seg['end'][:]

if fore_locs.sum() == 0:
    # No foreground: still write placeholder datasets (a single zero-length
    # veto segment and empty trigger arrays) to avoid failures later.
    f['segments/foreground_veto/start'] = numpy.array([0])
    f['segments/foreground_veto/end'] = numpy.array([0])
    for k in all_trigs.data:
        f['foreground/' + k] = numpy.array([], dtype=all_trigs.data[k].dtype)
else:
    f['segments/foreground_veto/start'] = remove_start_time
    f['segments/foreground_veto/end'] = remove_end_time
    for k in all_trigs.data:
        f['foreground/' + k] = all_trigs.data[k][fore_locs]

# Mask of the time-slid triggers, i.e. the background
# (timeslide_id different from 0).
back_locs = all_trigs.timeslide_id != 0

if back_locs.sum() == 0:
    # logging.warning is the supported spelling; logging.warn is a deprecated
    # alias that newer Python versions no longer provide.
    logging.warning("There were no background events, so we could not assign "
                    "any statistic values")
    # Close the output file explicitly so what was written so far is flushed
    # before leaving.
    f.f.close()
    sys.exit()

logging.info("Dumping background triggers (inclusive of zerolag)")
for k in all_trigs.data:
    f['background/' + k] = all_trigs.data[k][back_locs]

logging.info("Dumping background triggers (exclusive of zerolag)")
for k in exc_zero_trigs.data:
    f['background_exc/' + k] = exc_zero_trigs.data[k]

# Largest and smallest '<ifo>_foreground_time' attribute over the pivot ifo
# and every ifo in the analysis.
pivot_fg_time = all_trigs.attrs['%s_foreground_time' % f.attrs['pivot']]
ifo_fg_times = [all_trigs.attrs['%s_foreground_time' % ifo] for ifo in ifos]
maxtime = max([pivot_fg_time] + ifo_fg_times)
mintime = min([pivot_fg_time] + ifo_fg_times)

# Same quantities with the time vetoed around foreground triggers taken out
maxtime_exc = maxtime - veto_time
mintime_exc = mintime - veto_time

# Inclusive background time and foreground (coincident) time
slide_interval = all_trigs.attrs['timeslide_interval']
background_time = int(maxtime / slide_interval) * mintime
coinc_time = float(all_trigs.attrs['coinc_time'])

# Exclusive counterparts
background_time_exc = int(maxtime_exc / slide_interval) * mintime_exc
coinc_time_exc = coinc_time - veto_time

logging.info("Calculating FAN from background statistic values")
# Ranking statistic of background and foreground
back_stat = all_trigs.stat[back_locs]
fore_stat = all_trigs.stat[fore_locs]

# Inclusive background: cumulative array of background triggers and, for
# each foreground trigger, the number of background triggers louder than it.
back_decimation = all_trigs.decimation_factor[back_locs]
back_cnum, fnlouder = coinc.calculate_n_louder(back_stat, fore_stat,
                                               back_decimation)

# Same pair of quantities against the exclusive background
back_cnum_exc, fnlouder_exc = coinc.calculate_n_louder(
    exc_zero_trigs.stat, fore_stat, exc_zero_trigs.decimation_factor)

# Background ifar values, converted from seconds to years
f['background/ifar'] = conv.sec_to_year(background_time / (back_cnum + 1))
f['background_exc/ifar'] = conv.sec_to_year(
    background_time_exc / (back_cnum_exc + 1))

# Record the analysis times used above
for attr_name, attr_value in (('background_time', background_time),
                              ('foreground_time', coinc_time),
                              ('background_time_exc', background_time_exc),
                              ('foreground_time_exc', coinc_time_exc)):
    f.attrs[attr_name] = attr_value

logging.info("calculating ifar/fap values")

if fore_locs.sum() == 0:
    # No foreground triggers: write empty datasets under the same names
    for dset_name in ('ifar', 'fap', 'ifar_exc', 'fap_exc'):
        f['foreground/' + dset_name] = numpy.array([])
else:
    # Against the inclusive background
    ifar = background_time / (fnlouder + 1)
    fap = 1 - numpy.exp(-coinc_time / ifar)
    f['foreground/ifar'] = conv.sec_to_year(ifar)
    f['foreground/fap'] = fap

    # Against the exclusive background
    ifar_exc = background_time_exc / (fnlouder_exc + 1)
    fap_exc = 1 - numpy.exp(-coinc_time_exc / ifar_exc)
    f['foreground/ifar_exc'] = conv.sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc

# Propagate the optional 'name' attribute of the input
if 'name' in all_trigs.attrs:
    f.attrs['name'] = all_trigs.attrs['name']

# Incorporate hierarchical removal for any other loud triggers
logging.info("Beginning hierarchical removal of foreground triggers")

# Number of hierarchical removals carried out so far.
h_iterations = 0

# Step 1: Keep a handle on the foreground ranking statistics as they were
#         before any removal. fore_stat is rebound (never modified in place)
#         inside the loop, so this keeps referring to the original array and
#         lets us map a trigger back to its row in the 'foreground/ifar' and
#         'foreground/fap' datasets.
orig_fore_stat = fore_stat

# Which background the removal is judged against. Computed here so the loop
# does not depend on flags set during option validation.
remove_against_inclusive = (args.hierarchical_removal_against == 'inclusive')

# Number of louder background triggers per foreground trigger, for the chosen
# background. With removals disabled the choice does not matter: the loop
# below stops before doing any work.
if args.max_hierarchical_removal != 0 and not remove_against_inclusive:
    louder_foreground = fnlouder_exc
else:
    louder_foreground = fnlouder

# Step 2 : Loop until we don't have to hierarchically remove anymore. This
#          will happen when louder_foreground has no elements that equal 0.

while numpy.any(louder_foreground == 0):
    # Stop after the requested number of removals. A negative limit never
    # matches the counter, which is what makes -1 mean "unlimited".
    if h_iterations == args.max_hierarchical_removal:
        break

    # Write foreground trigger info before hierarchical removals for
    # downstream codes.
    if h_iterations == 0:
        f['background_h%s/stat' % h_iterations] = back_stat
        f['background_h%s/ifar' % h_iterations] = \
            conv.sec_to_year(background_time / (back_cnum + 1))
        f['background_h%s/timeslide_id' % h_iterations] = \
            all_trigs.data['timeslide_id'][back_locs]
        f['foreground_h%s/stat' % h_iterations] = fore_stat
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(ifar)
        f['foreground_h%s/fap' % h_iterations] = fap
        f['foreground_h%s/template_id' % h_iterations] = \
            all_trigs.data['template_id'][fore_locs]
        for ifo in ifos:
            trig_id = all_trigs.data['%s/trigger_id' % ifo][fore_locs]
            trig_time = all_trigs.data['%s/time' % ifo][fore_locs]
            f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
            f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id

    # Add the iteration number of hierarchical removals done.
    h_iterations += 1

    # Among foreground triggers, find the one with the largest ranking
    # statistic and mark it for removal.
    max_stat_idx = fore_stat.argmax()

    # Step 3: Remove that trigger from the list of zerolag triggers

    # Position of that trigger among all triggers. Going through the
    # foreground mask (rather than matching the statistic value over all
    # triggers) guarantees a foreground trigger is picked even if a
    # background trigger has exactly the same statistic.
    rm_trig_idx = numpy.flatnonzero(fore_locs)[max_stat_idx]
    # Row of the same trigger in the original foreground datasets.
    orig_fore_idx = numpy.where(
        orig_fore_stat == fore_stat[max_stat_idx])[0][0]

    # Store any foreground trigger's information that we want to
    # hierarchically remove.
    f['foreground/ifar'][orig_fore_idx] = conv.sec_to_year(ifar[max_stat_idx])
    f['foreground/fap'][orig_fore_idx] = fap[max_stat_idx]

    logging.info("Removing foreground trigger that is louder than the %s "
                 "background.", args.hierarchical_removal_against)

    # Remove the foreground trigger and all of the background triggers that
    # are associated with it. The mean time skips the -1 sentinel stored for
    # ifos that did not take part in the coincidence, in the same way as the
    # average foreground times used for the veto windows.
    rm_trig_times = [all_trigs.data['%s/time' % ifo][rm_trig_idx]
                     for ifo in ifos]
    ave_rm_time = coinc.mean_if_greater_than_zero(rm_trig_times)[0]

    rm_start = [ave_rm_time - args.hierarchical_removal_window]
    rm_end = [ave_rm_time + args.hierarchical_removal_window]
    # Collect the indices from every ifo; cast to int because these are used
    # as array indices.
    indices_to_rm = numpy.concatenate(
        [veto.indices_within_times(all_trigs.data['%s/time' % ifo],
                                   rm_start, rm_end)
         for ifo in ifos]).astype(int)
    all_trigs = all_trigs.remove(indices_to_rm)
    logging.info("We have %s triggers after hierarchical removal.",
                 len(all_trigs.stat))

    # Step 4: Re-cluster the triggers and calculate the inclusive ifar/fap
    logging.info("Clustering coinc triggers (inclusive of zerolag)")
    all_trigs = all_trigs.cluster(args.cluster_window)
    fore_locs = all_trigs.timeslide_id == 0

    logging.info("%s clustered foreground triggers", fore_locs.sum())
    logging.info("%s hierarchically removed foreground trigger(s)",
                 h_iterations)

    back_locs = all_trigs.timeslide_id != 0

    logging.info("Dumping foreground triggers")
    logging.info("Dumping background triggers (inclusive of zerolag)")
    for k in all_trigs.data:
        f['background_h%s/' % h_iterations + k] = all_trigs.data[k][back_locs]

    # Largest and smallest '<ifo>_foreground_time' attribute over the pivot
    # ifo and every ifo in the analysis.
    fg_times = [all_trigs.attrs['%s_foreground_time' % f.attrs['pivot']]]
    fg_times += [all_trigs.attrs['%s_foreground_time' % ifo] for ifo in ifos]
    maxtime = max(fg_times)
    mintime = min(fg_times)

    background_time = \
        int(maxtime / all_trigs.attrs['timeslide_interval']) * mintime
    coinc_time = float(all_trigs.attrs['coinc_time'])

    logging.info("Calculating FAN from background statistic values")
    back_stat = all_trigs.stat[back_locs]
    fore_stat = all_trigs.stat[fore_locs]
    back_cnum, fnlouder = coinc.calculate_n_louder(
        back_stat, fore_stat, all_trigs.decimation_factor[back_locs])

    # Update the louder_foreground criteria depending on whether foreground
    # triggers are being removed via inclusive or exclusive background.
    if remove_against_inclusive:
        louder_foreground = fnlouder
    else:
        # Exclusive background doesn't change when removing foreground
        # triggers, so back_cnum_exc is not needed again; just repopulate
        # fnlouder_exc for the remaining foreground triggers.
        _, fnlouder_exc = coinc.calculate_n_louder(
            exc_zero_trigs.stat, fore_stat, exc_zero_trigs.decimation_factor)
        louder_foreground = fnlouder_exc
    # louder_foreground has been updated and the code can continue.

    logging.info("Calculating ifar/fap values")
    f['background_h%s/ifar' % h_iterations] = \
        conv.sec_to_year(background_time / (back_cnum + 1))
    f.attrs['background_time_h%s' % h_iterations] = background_time
    f.attrs['foreground_time_h%s' % h_iterations] = coinc_time

    if fore_locs.sum() > 0:
        # Write ranking statistic to file just for downstream plotting code
        f['foreground_h%s/stat' % h_iterations] = fore_stat

        ifar = background_time / (fnlouder + 1)
        fap = 1 - numpy.exp(- coinc_time / ifar)
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(ifar)
        f['foreground_h%s/fap' % h_iterations] = fap

        # Update ifar and fap for other foreground triggers, locating each
        # one in the original foreground arrays by its ranking statistic.
        for stat_val, ifar_val, fap_val in zip(fore_stat, ifar, fap):
            orig_fore_idx = numpy.where(orig_fore_stat == stat_val)[0][0]
            f['foreground/ifar'][orig_fore_idx] = conv.sec_to_year(ifar_val)
            f['foreground/fap'][orig_fore_idx] = fap_val

        # Save trigger ids for foreground triggers for downstream plotting
        # code. These don't change with the iterations but should be written
        # at every level.
        f['foreground_h%s/template_id' % h_iterations] = \
            all_trigs.data['template_id'][fore_locs]
        for ifo in ifos:
            trig_id = all_trigs.data['%s/trigger_id' % ifo][fore_locs]
            trig_time = all_trigs.data['%s/time' % ifo][fore_locs]
            f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
            f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id
    else:
        # No foreground left at this level: write empty datasets under the
        # same names.
        f['foreground_h%s/stat' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap' % h_iterations] = numpy.array([])
        f['foreground_h%s/template_id' % h_iterations] = numpy.array([])
        for ifo in ifos:
            f['foreground_h%s/%s/time' % (h_iterations, ifo)] = \
                numpy.array([])
            f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = \
                numpy.array([])

# Write to file how many hierarchical removals were implemented.
f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. The value is stored as a
# numpy byte string; numpy.bytes_ is used because the numpy.string_ alias no
# longer exists in NumPy 2.
if h_iterations != 0:
    hrm_method = args.hierarchical_removal_against
    f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

# Close the output file explicitly so everything is flushed to disk.
f.f.close()

logging.info("Done")
