#!/bin/env python
"""
The program combines coincident output files generated
by pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, along
with producing the combined foreground and background triggers
"""
import argparse, h5py, logging, itertools, copy, pycbc.io, numpy, lal
from pycbc.events import veto, coinc
import pycbc.version
import pycbc.conversions as conv

# Command line interface.  Options are registered in a fixed order, which is
# also the order in which they appear in the generated --help text.
parser = argparse.ArgumentParser()
parser.add_argument('--verbose', action='count')
parser.add_argument(
    '--version',
    action='version',
    version=pycbc.version.git_verbose_msg,
)
parser.add_argument(
    '--cluster-window',
    type=float,
    default=10,
    help='Length of time window in seconds to cluster coinc events, '
         '[default=10s]',
)

# Input coincidence files
parser.add_argument(
    '--zero-lag-coincs',
    nargs='+',
    help="Files containing the injection zerolag coincidences",
)
parser.add_argument(
    '--mixed-coincs-inj-full',
    nargs='+',
    help="Files containing the mixed injection/clean "
         "data time slides",
)
parser.add_argument(
    '--mixed-coincs-full-inj',
    nargs='+',
    help="Files containing the mixed clean/injection "
         "data time slides",
)
parser.add_argument(
    '--full-data-background',
    help='background file from full data for use in '
         'analyzing injection coincs',
)

# Tuning of the injection-inclusive background estimate
parser.add_argument(
    '--veto-window',
    type=float,
    default=0.1,
    help='Time around each zerolag trigger to window out, [default=.1s]',
)
parser.add_argument(
    "--ranking-statistic-threshold",
    type=float,
    help="Minimum value of the ranking statistic to "
         "calculate a unique inclusive background.",
)

parser.add_argument('--output-file')
args = parser.parse_args()

# Logging is only configured when --verbose is given; otherwise the root
# logger keeps its default configuration.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level,
                        format='%(asctime)s : %(message)s')

# Every trigger set below is clustered with the same --cluster-window value.
window = args.cluster_window

# Zerolag coincidences of the injection run
logging.info("Loading coinc zerolag triggers")
zdata = pycbc.io.StatmapData(files=args.zero_lag_coincs).cluster(window)

# Mixed time slides: clean data paired with injection data
logging.info("Loading coinc full inj triggers")
fidata = pycbc.io.StatmapData(files=args.mixed_coincs_full_inj)
fidata = fidata.cluster(window)

# Mixed time slides: injection data paired with clean data
logging.info("Loading coinc inj full triggers")
ifdata = pycbc.io.StatmapData(files=args.mixed_coincs_inj_full)
ifdata = ifdata.cluster(window)

# Create the output file and carry over the attributes describing the
# detector pair and the time slide interval from the zerolag data.
f = h5py.File(args.output_file, "w")
for attr_name in ('detector_1', 'detector_2', 'timeslide_interval'):
    f.attrs[attr_name] = zdata.attrs[attr_name]

# Copy over the segment for coincs and singles
for key in zdata.seg.keys():
    seg = zdata.seg[key]
    f['segments/%s/start' % key] = seg['start'][:]
    f['segments/%s/end' % key] = seg['end'][:]

# Write every zerolag column to the 'foreground' group.  When there are no
# zerolag triggers an empty array of the column's dtype is stored instead,
# so the datasets always exist in the output.
logging.info('writing zero lag triggers')
have_zerolag = len(zdata) > 0
for key in zdata.data:
    column = zdata.data[key]
    if not have_zerolag:
        column = numpy.array([], dtype=column.dtype)
    f['foreground/%s' % key] = column

logging.info('calculating statistics excluding zerolag')
fb = h5py.File(args.full_data_background, "r")

# Exclusive background from the full data analysis: the ranking statistic of
# each background event and its decimation factor.
back_stat = fb['background_exc/stat'][:]
dec_fac = fb['background_exc/decimation_factor'][:]

# Allow all time to contribute to injection calculated IFARs, as for zerolag
# This prevents loud injections from being assigned to a bin where they have
# bad chisq if that bin has slightly larger exclusive bg livetime due to
# different censoring
background_time = float(fb.attrs['background_time'])
coinc_time = float(fb.attrs['foreground_time'])

# The same two livetimes are recorded under both the exclusive ('_exc') and
# the plain attribute names of the output file.
for attr_name, attr_value in (('background_time_exc', background_time),
                              ('foreground_time_exc', coinc_time),
                              ('background_time', background_time),
                              ('foreground_time', coinc_time)):
    f.attrs[attr_name] = attr_value

# Assign a significance (IFAR and FAP) to every clustered zerolag trigger.
# Two estimates are written to the 'foreground' group:
#   * ifar_exc / fap_exc : using only the full-data exclusive background
#   * ifar / fap         : additionally using the mixed injection/clean
#                          time-slide coincs close in time to the trigger
# With no zerolag triggers, empty datasets are written under the same names.
if len(zdata) > 0:
    fnlouder_exc = coinc.calculate_n_louder(back_stat, zdata.stat, dec_fac,
                                            skip_background=True)
    # The +1 keeps the IFAR finite when no background event is louder
    ifar_exc = background_time / (fnlouder_exc + 1)
    fap_exc = 1 - numpy.exp(- coinc_time / ifar_exc)
    f['foreground/ifar_exc'] = conv.sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc

    logging.info('calculating injection backgrounds')
    # Window of +/- veto_window around the mean of the two trigger times
    ftimes = (zdata.time1 + zdata.time2) / 2.0
    start, end = ftimes - args.veto_window, ftimes + args.veto_window

    fnlouder = numpy.zeros(len(ftimes), dtype=numpy.float32)
    ifar = numpy.zeros(len(ftimes), dtype=numpy.float32)
    fap = numpy.zeros(len(ftimes), dtype=numpy.float32)

    # We are relying on the injection data set to be the first one,
    # this is determined
    # by the argument order to pycbc_coinc_findtrigs
    # For each zerolag trigger, locate the range of time-sorted mixed coincs
    # whose time1 satisfies start <= time1 < end.
    ifsort = ifdata.time1.argsort()
    ifsorted = ifdata.time1[ifsort]
    if_start = numpy.searchsorted(ifsorted, start)
    if_end = numpy.searchsorted(ifsorted, end)

    fisort = fidata.time1.argsort()
    fisorted = fidata.time1[fisort]
    fi_start = numpy.searchsorted(fisorted, start)
    fi_end = numpy.searchsorted(fisorted, end)

    # most of the triggers are here so presort to speed up later sorting
    bsort = back_stat.argsort()
    dec_fac = dec_fac[bsort]
    back_stat = back_stat[bsort]

    # Compare against None explicitly: a threshold of 0.0 is a valid value
    # and must not be mistaken for "option not given".
    stat_threshold = args.ranking_statistic_threshold

    for i, fstat in enumerate(zdata.stat):
        # If the trigger is quiet enough, then don't calculate a separate
        # background type, as it would not be significantly different
        if stat_threshold is not None and fstat < stat_threshold:
            fnlouder[i] = fnlouder_exc[i]
            ifar[i] = ifar_exc[i]
            fap[i] = fap_exc[i]
            continue

        # Indices (into the unsorted arrays) of the mixed coincs that fall
        # inside this trigger's time window
        v1 = fisort[fi_start[i]:fi_end[i]]
        v2 = ifsort[if_start[i]:if_end[i]]

        # Background for this trigger: the nearby mixed coincs, each with a
        # decimation factor of 1, plus the full-data exclusive background
        inj_stat = numpy.concatenate([ifdata.stat[v2], fidata.stat[v1],
                                      back_stat])
        inj_dec = numpy.concatenate([numpy.repeat(1, len(v1) + len(v2)),
                                     dec_fac])

        fnlouder[i] = coinc.calculate_n_louder(inj_stat, fstat, inj_dec,
                                               skip_background=True)
        ifar[i] = background_time / (fnlouder[i] + 1)
        fap[i] = 1 - numpy.exp(- coinc_time / ifar[i])
        # Pass the values as arguments so the message is only formatted when
        # INFO logging is actually enabled
        logging.info('processed %s, %s', i, fstat)

    f['foreground/ifar'] = conv.sec_to_year(ifar)
    f['foreground/fap'] = fap
else:
    f['foreground/ifar_exc'] = numpy.array([])
    f['foreground/fap_exc'] = numpy.array([])
    f['foreground/ifar'] = numpy.array([])
    f['foreground/fap'] = numpy.array([])

# Close both HDF5 files explicitly so the output is flushed to disk here
# instead of relying on cleanup at interpreter shutdown.
f.close()
fb.close()

logging.info("Done")

