#!/usr/bin/env python
"""
The program combines output files generated
by pycbc_sngls_findtrigs to generate a mapping between SNR and FAP/FAR, along
with producing the combined foreground and background triggers.
"""

import argparse, h5py, itertools
import lal, logging, numpy
from pycbc.events import veto, coinc
from pycbc.events import triggers, trigger_fits as trstats
from pycbc.events import significance
import pycbc.version, pycbc.pnutils, pycbc.io
import sys
import pycbc.conversions as conv

class fw(object):
    """Thin dict-style writer around an HDF5 file opened with mode 'w'.

    Assigning to a key creates a compressed dataset the first time that key
    is seen; assigning to an existing key overwrites the stored values in
    place. Reading a key is delegated to the underlying file object.
    """

    def __init__(self, name):
        handle = h5py.File(name, 'w')
        self.f = handle
        # Expose the file-level attributes directly on the wrapper.
        self.attrs = handle.attrs

    def __setitem__(self, name, data):
        if name in self.f:
            # Already present: replace the stored values in place.
            self.f[name][:] = data
            return

        # First write for this key: create a gzip-compressed, shuffled
        # dataset whose maximum shape is the shape of the data given.
        self.f.create_dataset(name, data=data, maxshape=data.shape,
                              compression="gzip", compression_opts=9,
                              shuffle=True)

    def __getitem__(self, *args):
        # Pass lookups straight through to the underlying file object.
        return self.f.__getitem__(*args)

# Command-line interface. The input files, the ifo and the output file are
# all used unconditionally further down, so they are declared as required:
# omitting one now gives a clean usage error instead of a traceback.
parser = argparse.ArgumentParser(description=__doc__)
# General required options
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--sngls-files', nargs='+', required=True,
                    help='List of files containing trigger and statistic '
                         'information.')
# nargs=1 means this parses to a one-element list, not a bare string.
parser.add_argument('--ifos', nargs=1, required=True,
                    help='The single ifo whose triggers are contained in '
                         'the input files (exactly one value is accepted)')
parser.add_argument('--verbose', action='count')
parser.add_argument('--cluster-window', type=float, default=10,
                    help='Length of time window in seconds to cluster coinc '
                         'events [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
                    help='Time around each zerolag trigger to window out '
                         '[default=.1s]')
# Options controlling how significance is calculated are added by the
# pycbc.events.significance module.
significance.insert_significance_option_group(parser)
parser.add_argument('--hierarchical-removal-window', type=float, default=1.0,
                    help='Time around each trigger to window out for a '
                         'louder trigger in the hierarchical removal '
                         'procedure [default=1.0s]')
parser.add_argument('--max-hierarchical-removal', type=int, default=0,
                    help='Maximum number of hierarchical removals to carry '
                         'out. Choose -1 for unlimited hierarchical removal '
                         'until no foreground triggers are louder than the '
                         '(inclusive/exclusive) background. Choose 0 to do '
                         'no hierarchical removals, choose 1 to do at most '
                         '1 hierarchical removal and so on. If given, must '
                         'also provide --hierarchical-removal-against to '
                         'indicate which background to remove triggers '
                         'against. [default=0]')
parser.add_argument('--hierarchical-removal-against', type=str,
                    default='none', choices=['none', 'inclusive', 'exclusive'],
                    help='If doing hierarchical removal, remove foreground '
                         'triggers that are louder than either the "inclusive"'
                         ' (little-dogs-in) background, or the "exclusive" '
                         '(little-dogs-out) background. [default="none"]')
parser.add_argument('--hierarchical-removal-ifar-thresh', type=float,
                    default=100,
                    help="Minimum IFAR for a foreground event to be "
                         "hierarchically removed from background of quieter "
                         "events (years) [default=100yr]")
parser.add_argument('--output-file', required=True,
                    help='Path of the HDF5 file to write the results to.')
args = parser.parse_args()


# Let the significance module check its own group of options.
significance.check_significance_options(args, parser)

# Hierarchical removal needs to know which background (inclusive or
# exclusive) the foreground triggers are compared against. Conversely,
# naming a background while asking for zero removals is contradictory.
if args.max_hierarchical_removal != 0:
    is_bkg_inc = args.hierarchical_removal_against == 'inclusive'
    is_bkg_exc = args.hierarchical_removal_against == 'exclusive'
    if not is_bkg_inc and not is_bkg_exc:
        parser.error("--max-hierarchical-removal requires a choice of which "
                     "background to remove foreground triggers against, "
                     "inclusive or exclusive. Use with --help for more "
                     "information.")
elif args.hierarchical_removal_against != 'none':
    parser.error("User Error: 0 maximum hierarchical removals chosen but "
                 "option for --hierarchical-removal-against was given. "
                 "These are conflicting options. Use with --help for more "
                 "information.")

pycbc.init_logging(args.verbose)

logging.info("Loading triggers")
# --ifos is declared with nargs=1, so it parses to a one-element list.
ifo = args.ifos[0]
logging.info("IFO input: %s", ifo)
all_trigs = pycbc.io.MultiifoStatmapData(files=args.sngls_files,
                                         ifos=[ifo])
# Explicit check instead of an assert, so the validation is not stripped
# when Python runs with -O.
if ifo + '/time' not in all_trigs.data:
    raise ValueError("Input trigger data has no '%s/time' entry; check that "
                     "--ifos matches the contents of --sngls-files" % ifo)

logging.info("We have %s triggers", len(all_trigs.stat))
logging.info("Clustering triggers")
all_trigs = all_trigs.cluster(args.cluster_window)

# For now, all triggers are both in the foreground and background
fore_locs = numpy.flatnonzero(all_trigs.timeslide_id == 0)
back_locs = numpy.flatnonzero(all_trigs.timeslide_id == 0)

fg_time = float(all_trigs.attrs['foreground_time'])
# Both operands of the subtraction are the same length, so the excluded
# time is always zero and fg_time_exc equals fg_time.
# NOTE(review): presumably a placeholder for a separate exclusive
# background selection -- confirm before relying on the *_exc outputs.
exc_veto_time = (len(back_locs) - len(back_locs)) * args.veto_window
fg_time_exc = fg_time - exc_veto_time


logging.info("Dumping foreground triggers")
# Open the output file and record which (single) detector it describes.
f = fw(args.output_file)
f.attrs['num_of_ifos'] = 1
f.attrs['ifos'] = ifo

f.attrs['timeslide_interval'] = all_trigs.attrs['timeslide_interval']

# Carry the segment lists over from the input, start then end for each key.
for seg_name in all_trigs.seg.keys():
    seg = all_trigs.seg[seg_name]
    for edge in ('start', 'end'):
        f['segments/%s/%s' % (seg_name, edge)] = seg[edge][:]

# The foreground veto segment list is written as a single entry with start
# and end both equal to 0.
for edge in ('start', 'end'):
    f['segments/foreground_veto/%s' % edge] = numpy.array([0])

# Every trigger column goes into the foreground group and into both
# background groups; the two background groups share one selection.
for col_name in all_trigs.data:
    column = all_trigs.data[col_name]
    f['foreground/' + col_name] = column[fore_locs]
    f['background/' + col_name] = column[back_locs]
    f['background_exc/' + col_name] = column[back_locs]


logging.info("Estimating FAN from background statistic values")
# Ranking statistic of foreground and background. The exclusive background
# currently uses the same selection as the inclusive one.
fore_stat = all_trigs.stat[fore_locs]
back_stat = all_trigs.stat[back_locs]
back_stat_exc = all_trigs.stat[back_locs]

bkg_dec_facs = all_trigs.decimation_factor[back_locs]
bkg_exc_dec_facs = all_trigs.decimation_factor[back_locs]

# Per-ifo keyword arguments for get_n_louder, built from the command line.
significance_dict = significance.digest_significance_options([ifo], args)

# Cumulative array of inclusive background triggers and the number of
# inclusive background triggers louder than each foreground trigger
back_cnum, fnlouder = significance.get_n_louder(
    back_stat,
    fore_stat,
    bkg_dec_facs,
    **significance_dict[ifo])

# Cumulative array of exclusive background triggers and the number
# of exclusive background triggers louder than each foreground trigger
back_cnum_exc, fnlouder_exc = significance.get_n_louder(
    back_stat_exc,
    fore_stat,
    bkg_exc_dec_facs,
    **significance_dict[ifo])

# IFAR in seconds = livetime / (number louder + 1). Inclusive quantities use
# the inclusive livetime and exclusive quantities the exclusive livetime.
# The two livetimes were previously swapped for fg_ifar and bg_ifar_exc;
# this has no numerical effect while fg_time_exc equals fg_time.
bg_ifar = fg_time / (back_cnum + 1)
fg_ifar = fg_time / (fnlouder + 1)
bg_ifar_exc = fg_time_exc / (back_cnum_exc + 1)
fg_ifar_exc = fg_time_exc / (fnlouder_exc + 1)

# IFARs are stored in years; the livetimes are stored in seconds.
f['background/ifar'] = conv.sec_to_year(bg_ifar)
f['background_exc/ifar'] = conv.sec_to_year(bg_ifar_exc)
f.attrs['background_time'] = fg_time
f.attrs['foreground_time'] = fg_time
f.attrs['background_time_exc'] = fg_time_exc
f.attrs['foreground_time_exc'] = fg_time_exc

logging.info("calculating foreground ifar/fap values")

# False-alarm probability over the livetime: 1 - exp(-T / IFAR).
fap = 1 - numpy.exp(- fg_time / fg_ifar)
f['foreground/ifar'] = conv.sec_to_year(fg_ifar)
f['foreground/fap'] = fap
fap_exc = 1 - numpy.exp(- fg_time_exc / fg_ifar_exc)
f['foreground/ifar_exc'] = conv.sec_to_year(fg_ifar_exc)
f['foreground/fap_exc'] = fap_exc

if 'name' in all_trigs.attrs:
    f.attrs['name'] = all_trigs.attrs['name']

# Hierarchical removal of any other loud foreground triggers.
logging.info("Beginning hierarchical removal of foreground triggers")

# Step 1: Keep a reference to the foreground ranking statistic as it is
#         before any removal. The removal loop uses it to find each
#         trigger's position in the original foreground when updating the
#         ifar and fap of hierarchically removed foreground triggers. The
#         loop rebinds fore_stat rather than modifying it in place, so an
#         alias is sufficient here.
orig_fore_stat = fore_stat

# Number of hierarchical removals carried out so far.
h_iterations = 0

# Pick the ifar array that decides whether another removal is needed:
# exclusive when removing against the exclusive background, inclusive
# otherwise. With zero removals requested the choice is irrelevant because
# the loop breaks immediately, but ifar_louder must still be defined; the
# short-circuit also keeps is_bkg_inc from being read in that case, where it
# was never assigned.
if args.max_hierarchical_removal != 0 and not is_bkg_inc:
    ifar_louder = fg_ifar_exc
else:
    ifar_louder = fg_ifar

# Step 2 : Loop until we don't have to hierarchically remove anymore. Inside
#          the loop ifar = fg_time / (n_louder + 1), so the condition holds
#          only while some foreground trigger has no louder background.
# NOTE(review): --hierarchical-removal-ifar-thresh is parsed but never used
# by this condition -- confirm whether it was meant to replace fg_time here.

while numpy.any(ifar_louder >= fg_time):
    # If the user wants to stop doing hierarchical removals after a set
    # number of iterations then break when that happens. A limit of -1 is
    # never reached, which gives unlimited removals.
    if h_iterations == args.max_hierarchical_removal:
        break

    # Write foreground trigger info before hierarchical removals for
    # downstream codes.
    if h_iterations == 0:
        f['background_h%s/stat' % h_iterations] = back_stat
        f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
        f['background_h%s/timeslide_id' % h_iterations] = all_trigs.data['timeslide_id'][back_locs]
        f['foreground_h%s/stat' % h_iterations] = fore_stat
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(fg_ifar)
        f['foreground_h%s/ifar_exc' % h_iterations] = conv.sec_to_year(fg_ifar_exc)
        f['foreground_h%s/fap' % h_iterations] = fap
        f['foreground_h%s/template_id' % h_iterations] = all_trigs.data['template_id'][fore_locs]
        trig_id = all_trigs.data['%s/trigger_id' % ifo][fore_locs]
        trig_time = all_trigs.data['%s/time' % ifo][fore_locs]
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id
    # Add the iteration number of hierarchical removals done.
    h_iterations += 1

    # Among foreground triggers, find the one with the largest ifar
    # and mark it for removal.
    max_stat_idx = ifar_louder.argmax()

    # Step 3: Remove that trigger from the list of zerolag triggers

    # Find the index of the loud foreground trigger to remove next. And find
    # the index in the list of original foreground triggers. Triggers are
    # matched by statistic value; the first match is used.
    rm_trig_idx = numpy.where(all_trigs.stat[:] == fore_stat[max_stat_idx])[0][0]
    orig_fore_idx = numpy.where(orig_fore_stat == fore_stat[max_stat_idx])[0][0]

    # Store the significance of the trigger about to be removed, as it
    # stood before the removal.
    f['foreground/ifar'][orig_fore_idx] = conv.sec_to_year(fg_ifar[max_stat_idx])
    f['foreground/fap'][orig_fore_idx] = fap[max_stat_idx]

    # Name the background actually in use rather than always "inclusive".
    logging.info("Removing foreground trigger that is louder than the %s "
                 "background.", args.hierarchical_removal_against)

    # Remove the foreground trigger and every trigger within the removal
    # window around its time.
    ave_rm_time = all_trigs.data['%s/time' % ifo][rm_trig_idx]
    indices_to_rm = veto.indices_within_times(
        all_trigs.data['%s/time' % ifo],
        [ave_rm_time - args.hierarchical_removal_window],
        [ave_rm_time + args.hierarchical_removal_window])

    all_trigs = all_trigs.remove(numpy.asarray(indices_to_rm).astype(int))
    logging.info("We have %s triggers after hierarchical removal.",
                 len(all_trigs.stat))

    # Step 4: Re-cluster the triggers and calculate the inclusive ifar/fap
    logging.info("Clustering coinc triggers (inclusive of zerolag)")
    all_trigs = all_trigs.cluster(args.cluster_window)

    # Foreground and background share the same selection (boolean masks).
    fore_locs = all_trigs.timeslide_id == 0
    back_locs = all_trigs.timeslide_id == 0

    logging.info("%s clustered foreground triggers", fore_locs.sum())
    logging.info("%s hierarchically removed foreground trigger(s)",
                 h_iterations)

    logging.info("Dumping foreground triggers")
    logging.info("Dumping background triggers (inclusive of zerolag)")
    for k in all_trigs.data:
        f['background_h%s/' % h_iterations + k] = all_trigs.data[k][back_locs]

    logging.info("Calculating FAN from background statistic values")
    back_stat = all_trigs.stat[back_locs]
    fore_stat = all_trigs.stat[fore_locs]

    # NOTE(review): unit decimation factors are used here, whereas the
    # calculation before the loop uses all_trigs.decimation_factor --
    # confirm this is intended.
    back_cnum, fnlouder = significance.get_n_louder(
        back_stat,
        fore_stat,
        numpy.ones_like(back_stat),
        **significance_dict[ifo])

    bg_ifar = fg_time / (back_cnum + 1)
    fg_ifar = fg_time / (fnlouder + 1)

    # Recompute the exclusive foreground ifar on every iteration, whichever
    # background drives the removal, so that it always has one entry per
    # remaining foreground trigger. The exclusive background statistic
    # itself is the one taken before the loop started.
    _, fnlouder_exc = significance.get_n_louder(
        back_stat_exc,
        fore_stat,
        numpy.ones_like(back_stat_exc),
        **significance_dict[ifo])

    fg_ifar_exc = fg_time / (fnlouder_exc + 1)

    # Update the ifar_louder criteria depending on whether foreground
    # triggers are being removed via inclusive or exclusive background.
    ifar_louder = fg_ifar if is_bkg_inc else fg_ifar_exc

    logging.info("Calculating ifar/fap values")
    f['background_h%s/ifar' % h_iterations] = conv.sec_to_year(bg_ifar)
    f.attrs['background_time_h%s' % h_iterations] = fg_time
    f.attrs['foreground_time_h%s' % h_iterations] = fg_time

    if fore_locs.sum() > 0:
        # Write ranking statistic to file just for downstream plotting code
        f['foreground_h%s/stat' % h_iterations] = fore_stat

        fap = 1 - numpy.exp(- fg_time / fg_ifar)
        f['foreground_h%s/ifar' % h_iterations] = conv.sec_to_year(fg_ifar)
        f['foreground_h%s/fap' % h_iterations] = fap

        # Exclusive values get their own datasets so that they do not
        # overwrite the inclusive ones written just above.
        fap_exc = 1 - numpy.exp(- fg_time / fg_ifar_exc)
        f['foreground_h%s/ifar_exc' % h_iterations] = conv.sec_to_year(fg_ifar_exc)
        f['foreground_h%s/fap_exc' % h_iterations] = fap_exc

        # Update ifar and fap for other foreground triggers
        for stat_val, ifar_val, fap_val in zip(fore_stat, fg_ifar, fap):
            orig_fore_idx = numpy.where(orig_fore_stat == stat_val)[0][0]
            f['foreground/ifar'][orig_fore_idx] = conv.sec_to_year(ifar_val)
            f['foreground/fap'][orig_fore_idx] = fap_val

        # Save trigger ids for foreground triggers for downstream plotting code.
        # These don't change with the iterations but should be written at every
        # level.

        f['foreground_h%s/template_id' % h_iterations] = all_trigs.data['template_id'][fore_locs]
        trig_id = all_trigs.data['%s/trigger_id' % ifo][fore_locs]
        trig_time = all_trigs.data['%s/time' % ifo][fore_locs]
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = trig_time
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = trig_id
    else:
        # No foreground triggers remain: write empty datasets so that every
        # level has the same layout.
        f['foreground_h%s/stat' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap' % h_iterations] = numpy.array([])
        f['foreground_h%s/ifar_exc' % h_iterations] = numpy.array([])
        f['foreground_h%s/fap_exc' % h_iterations] = numpy.array([])
        f['foreground_h%s/template_id' % h_iterations] = numpy.array([])
        f['foreground_h%s/%s/time' % (h_iterations, ifo)] = numpy.array([])
        f['foreground_h%s/%s/trigger_id' % (h_iterations, ifo)] = numpy.array([])

# Write to file how many hierarchical removals were implemented.
f.attrs['hierarchical_removal_iterations'] = h_iterations

# Write whether hierarchical removals were removed against the
# inclusive background or the exclusive background. The value is stored as
# a bytes scalar; numpy.bytes_ is used because the numpy.string_ alias was
# removed in NumPy 2.0.
if h_iterations != 0:
    hrm_method = args.hierarchical_removal_against
    f.attrs['hierarchical_removal_method'] = numpy.bytes_(hrm_method)

# Close the underlying HDF5 file explicitly rather than relying on
# interpreter shutdown to do it.
f.f.close()

logging.info("Done")

