#!/bin/env python
"""
The program combines coincident output files generated
by pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, along
with producing the combined foreground and background triggers
"""
import argparse, h5py, logging, itertools, copy, pycbc.io, numpy, lal
from pycbc.events import veto, coinc
import pycbc.version
import pycbc.conversions as conv

# Command line options, listed in the order in which they are registered
# with the parser. Each entry pairs an option flag with the keyword
# arguments handed to ``add_argument``. None of the options is marked as
# required at the argparse level.
option_specs = [
    # General options
    ('--verbose', dict(action='count')),
    ('--version', dict(action='version',
                       version=pycbc.version.git_verbose_msg)),
    ('--cluster-window', dict(
        type=float, default=10,
        help='Length of time window in seconds to cluster coinc '
             'events [default=10s]')),
    # Input coincidence files
    ('--zero-lag-coincs', dict(
        nargs='+',
        help='Files containing the injection zerolag coincidences')),
    ('--mixed-coincs-inj-full', dict(
        nargs='+',
        help='Files containing the mixed injection/clean data '
             'time slides')),
    ('--mixed-coincs-full-inj', dict(
        nargs='+',
        help='Files containing the mixed clean/injection data '
             'time slides')),
    ('--full-data-background', dict(
        help='background file from full data for use in analyzing '
             'injection coincs')),
    # Tuning of the inclusive background calculation
    ('--veto-window', dict(
        type=float, default=.1,
        help='Time around each zerolag trigger to window out '
             '[default=.1s]')),
    ('--ranking-statistic-threshold', dict(
        type=float,
        help='Minimum value of the ranking statistic to calculate'
             ' a unique inclusive background.')),
    ('--ifos', dict(
        nargs='+',
        help='List of ifos used in these coincidence files')),
    ('--output-file', dict()),
]

parser = argparse.ArgumentParser()
for option_flag, option_kwargs in option_specs:
    parser.add_argument(option_flag, **option_kwargs)
args = parser.parse_args()

# Logging is only configured when --verbose is given; every progress message
# in this script is emitted through logging.info.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# All three trigger sets are read from their file lists and clustered with
# the same --cluster-window value.
window = args.cluster_window

logging.info("Loading coinc zerolag triggers")
zdata = pycbc.io.MultiifoStatmapData(files=args.zero_lag_coincs,
                                     ifos=args.ifos).cluster(window)

logging.info("Loading coinc full inj triggers")
fidata = pycbc.io.MultiifoStatmapData(files=args.mixed_coincs_full_inj,
                                      ifos=args.ifos)
fidata = fidata.cluster(window)

logging.info("Loading coinc inj full triggers")
ifdata = pycbc.io.MultiifoStatmapData(files=args.mixed_coincs_inj_full,
                                      ifos=args.ifos)
ifdata = ifdata.cluster(window)

# Output file, opened in "w" mode; it is populated incrementally below
f = h5py.File(args.output_file, "w")

# Carry the coincidence configuration attributes over from the zerolag data
for attr_name in ('num_of_ifos', 'pivot', 'fixed', 'timeslide_interval'):
    f.attrs[attr_name] = zdata.attrs[attr_name]
f.attrs['ifos'] = ' '.join(sorted(args.ifos))

# Copy over the segment for coincs and singles
for key in zdata.seg.keys():
    segment = zdata.seg[key]
    f['segments/%s/start' % key] = segment['start'][:]
    f['segments/%s/end' % key] = segment['end'][:]

logging.info('writing zero lag triggers')
has_foreground = len(zdata) > 0
for key in zdata.data:
    column = zdata.data[key]
    if has_foreground:
        f['foreground/%s' % key] = column
    else:
        # No zerolag triggers: still create the dataset, empty but with the
        # dtype of the corresponding column
        f['foreground/%s' % key] = numpy.array([], dtype=column.dtype)

logging.info('calculating statistics excluding zerolag')
fb = h5py.File(args.full_data_background, "r")

# we expect the injfull file to contain injection data as pivot
# and fullinj to contain full data as pivot
fb_attrs = fb.attrs
background_time = float(fb_attrs['background_time'])
coinc_time = float(fb_attrs['foreground_time'])
back_stat = fb['background_exc/stat'][:]
dec_fac = fb['background_exc/decimation_factor'][:]

# The '_exc' and the plain time attributes are written with the same values
time_attrs = (
    ('background_time_exc', background_time),
    ('foreground_time_exc', coinc_time),
    ('background_time', background_time),
    ('foreground_time', coinc_time),
)
for attr_name, attr_value in time_attrs:
    f.attrs[attr_name] = attr_value

if len(zdata) > 0:
    # Number of louder background events, IFAR and FAP for every zerolag
    # trigger, using only the 'background_exc' statistics read from the
    # full data background file.
    fnlouder_exc = coinc.calculate_n_louder(back_stat, zdata.stat, dec_fac,
                                            skip_background=True)
    ifar_exc = background_time / (fnlouder_exc + 1)
    fap_exc = 1 - numpy.exp(- coinc_time / ifar_exc)
    f['foreground/ifar_exc'] = conv.sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc

    logging.info('calculating injection backgrounds')
    # Mean trigger time over the ifos for each zerolag coinc. The per-ifo
    # time arrays are stacked into an explicit (n_ifos, n_coincs) array;
    # handing a zip() object to numpy.concatenate only works on Python 2,
    # because on Python 3 zip() returns an iterator rather than a list.
    ifotimes = numpy.array([zdata.data['%s/time' % ifo]
                            for ifo in args.ifos])
    ftimes = ifotimes.mean(axis=0)
    # Time window around each zerolag trigger used to select mixed coincs
    start, end = ftimes - args.veto_window, ftimes + args.veto_window

    fnlouder = numpy.zeros(len(ftimes), dtype=numpy.float32)
    ifar = numpy.zeros(len(ftimes), dtype=numpy.float32)
    fap = numpy.zeros(len(ftimes), dtype=numpy.float32)

    # We are relying on the injection data set to be the first one,
    # this is determined by the argument order to pycbc_coinc_findtrigs
    pivot_ifo = zdata.attrs['pivot']

    # Sort the mixed coincs by pivot-ifo time so that the coincs falling in
    # each [start, end) window can be located with a binary search.
    ifsort = ifdata.data['%s/time' % pivot_ifo].argsort()
    ifsorted = ifdata.data['%s/time' % pivot_ifo][ifsort]
    if_start, if_end = numpy.searchsorted(ifsorted, start), \
                       numpy.searchsorted(ifsorted, end)

    fisort = fidata.data['%s/time' % pivot_ifo].argsort()
    fisorted = fidata.data['%s/time' % pivot_ifo][fisort]
    fi_start, fi_end = numpy.searchsorted(fisorted, start), \
                       numpy.searchsorted(fisorted, end)

    # most of the triggers are here so presort to speed up later sorting
    bsort = back_stat.argsort()
    dec_fac = dec_fac[bsort]
    back_stat = back_stat[bsort]

    for i, fstat in enumerate(zdata.stat):
        # If the trigger is quiet enough, then don't calculate a separate
        # background type, as it would not be significantly different
        if args.ranking_statistic_threshold and \
                    fstat < args.ranking_statistic_threshold:
            fnlouder[i] = fnlouder_exc[i]
            ifar[i] = ifar_exc[i]
            fap[i] = fap_exc[i]
            continue

        # Indices of the mixed coincs whose pivot-ifo time lies inside the
        # window around this zerolag trigger
        v1 = fisort[fi_start[i]:fi_end[i]]
        v2 = ifsort[if_start[i]:if_end[i]]

        # Background for this trigger: the selected mixed coincs (given a
        # decimation factor of 1) plus the full data background
        inj_stat = numpy.concatenate(
                     [ifdata.stat[v2], fidata.stat[v1], back_stat])
        inj_dec = numpy.concatenate(
                     [numpy.repeat(1, len(v1) + len(v2)), dec_fac])

        fnlouder[i] = coinc.calculate_n_louder(inj_stat, fstat, inj_dec,
                                               skip_background=True)
        ifar[i] = background_time / (fnlouder[i] + 1)
        fap[i] = 1 - numpy.exp(- coinc_time / ifar[i])
        logging.info('processed %s, %s', i, fstat)

    f['foreground/ifar'] = conv.sec_to_year(ifar)
    f['foreground/fap'] = fap
else:
    # No zerolag triggers: write empty significance datasets so the output
    # file still contains all four of them
    f['foreground/ifar_exc'] = numpy.array([])
    f['foreground/fap_exc'] = numpy.array([])
    f['foreground/ifar'] = numpy.array([])
    f['foreground/fap'] = numpy.array([])

# Close both HDF5 files explicitly rather than relying on interpreter exit
fb.close()
f.close()

logging.info("Done")
