#!/usr/bin/env python
import h5py, argparse, logging, numpy, numpy.random
from ligo.segments import infinity
from pycbc.events import veto, coinc, stat
import pycbc.version
from numpy.random import seed, shuffle
from pycbc.io.hdf import ReadByTemplate

# Command line interface.
# Options declared with nargs='*' and action='append' produce a list of lists
# so that they can be repeated and can take several values per use; those
# lists are flattened right after parsing.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version", version=pycbc.version.git_verbose_msg)
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--trigger-files", nargs='*', action='append', default=[],
                    help="Files containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--pivot-ifo", required=True,
                    help="Add the ifo to use as the pivot for multi "
                         "detector coincidence")
parser.add_argument("--fixed-ifo", required=True,
                    help="Add the ifo to use as the fixed ifo for "
                         "multi detector coincidence")
# produces a list of lists to allow multiple invocations and multiple args
parser.add_argument("--statistic-files", nargs='*', action='append', default=[],
                    help="Files containing ranking statistic info")
parser.add_argument("--ranking-statistic", choices=stat.statistic_dict.keys(),
                    default='newsnr',
                    help="The coinc ranking statistic to calculate")
parser.add_argument("--statistic-keywords", nargs='*',
                    default=[],
                    help="Provide additional key-word arguments to be sent to "
                         "the statistic class when it is initialized. Should "
                         "be given in format --statistic-keywords "
                         "KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ...")
# NOTE(review): --use-maxalpha is parsed but not read anywhere in the visible
# part of this script - confirm whether it is still needed.
parser.add_argument("--use-maxalpha", action="store_true")
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Seconds to add to time-of-flight coincidence window")
parser.add_argument("--timeslide-interval", type=float,
                    help="Interval between timeslides in seconds. Timeslides are"
                         " disabled if the option is omitted.")
parser.add_argument("--loudest-keep-values", type=str, nargs='*',
                    default=['6:1'],
                    help="Apply successive multiplicative levels of decimation"
                         " to coincs with stat value below the given thresholds"
                         " Ex. 9:10 8.5:30 8:30 7.5:30. Default: no decimation")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
parser.add_argument("--cluster-window", type=float,
                    help="Optional, window size in seconds to cluster "
                         "coincidences over the bank")
# The output file is always opened for writing at the end of the run, so it is
# required up front: a missing path is reported before any analysis is done
# rather than after all of it.
parser.add_argument("--output-file", required=True,
                    help="File to store the coincident triggers")
parser.add_argument("--batch-singles", default=5000, type=int,
                    help="Number of single triggers to process at once")
parser.add_argument("--legacy-output", action="store_true",
                    help="Write the output in the legacy two-detector layout "
                         "(time1/trigger_id1 for H1, time2/trigger_id2 for "
                         "L1). Requires --pivot-ifo L1 and --fixed-ifo H1")
args = parser.parse_args()

# Legacy output hard-codes H1/L1 dataset and attribute names, so any other
# pivot/fixed choice is rejected. parser.error() is used rather than assert
# because assert statements are removed when Python runs with -O, which would
# let an unsupported combination through silently.
if args.legacy_output and (args.pivot_ifo != 'L1' or args.fixed_ifo != 'H1'):
    parser.error("--legacy-output requires --pivot-ifo L1 and --fixed-ifo H1")

# flatten the list of lists of filenames to a single list (may be empty)
args.statistic_files = sum(args.statistic_files, [])
args.segment_name = sum(args.segment_name, [])
args.veto_files = sum(args.veto_files, [])
args.trigger_files = sum(args.trigger_files, [])

# Logging is only configured when --verbose was given at least once
if args.verbose:
    logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)


def parse_template_range(num_templates, rangestr):
    """Return the half-open template index range for one part of the bank.

    The bank of ``num_templates`` templates is divided into NUM_PARTS
    contiguous, nearly equal pieces and the bounds of piece PART (counted
    from zero) are returned.

    Parameters
    ----------
    num_templates : int
        Total number of templates in the bank.
    rangestr : str
        Selection in the form ``PART/NUM_PARTS``, e.g. ``"0/1"`` for the
        whole bank.

    Returns
    -------
    tmin, tmax : int
        First template index of the piece and one past its last index.

    Raises
    ------
    ValueError
        If either field is not an integer, NUM_PARTS is not positive or
        PART is outside ``0 .. NUM_PARTS - 1``.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    if pieces < 1 or part < 0 or part >= pieces:
        raise ValueError("Invalid template fraction range %r, expected "
                         "PART/NUM_PARTS with 0 <= PART < NUM_PARTS"
                         % (rangestr,))
    # Exact integer arithmetic: the floating point form
    # int(num_templates / float(pieces) * part) can land just below the true
    # boundary (e.g. 1 / 49.0 * 49 == 0.9999999999999999) and then silently
    # drops a template from the piece.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

logging.info('Starting...')

# Only the size of the bank is needed here. Opening the file in a context
# manager closes the handle again immediately instead of leaving it open for
# the rest of the run.
with h5py.File(args.template_bank, "r") as bank:
    num_templates = len(bank['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s' % (tmin, tmax-1))

class MultiifoTrigs(object):
    """Container of three index-aligned lists describing the detectors.

    Entry ``k`` of ``ifos`` (detector name), ``to_shift`` (time shift
    multiple) and ``singles`` (trigger reader) all refer to the same
    detector. All three start out empty and are filled by the caller.
    """
    def __init__(self):
        # each attribute gets its own, independent empty list
        self.ifos, self.to_shift, self.singles = [], [], []

# Open one reader per trigger file and record, in parallel lists, the
# detector it belongs to and the time shift multiple applied to it
trigs = MultiifoTrigs()
for file_num, trig_file in enumerate(args.trigger_files):
    logging.info('Opening trigger file %s: %s' % (file_num, trig_file))
    reader = ReadByTemplate(trig_file,
                            args.template_bank,
                            args.segment_name,
                            args.veto_files)
    ifo = reader.ifo
    # time shift is subtracted from pivot ifo time: only the pivot ifo gets
    # a non-zero multiple
    shift = -1 if ifo == args.pivot_ifo else 0
    trigs.ifos.append(ifo)
    trigs.to_shift.append(shift)
    logging.info('Applying time shift multiple %i to ifo %s' % (shift, ifo))
    trigs.singles.append(reader)

# Begin with a single segment spanning (-infinity, infinity) and intersect it
# with the segments of every reader, so that coinc_segs ends up containing
# only segments where all ifos are analyzed
coinc_segs = veto.start_end_to_segments([-infinity()], [infinity()])
for single in trigs.singles:
    coinc_segs = coinc_segs & single.segs

# Every reader is then restricted to those common segments, and its 'valid'
# start/end arrays are rebuilt from them
for single in trigs.singles:
    single.segs = coinc_segs
    single.valid = veto.segments_to_start_end(single.segs)

# Turn the KEY:VALUE strings from --statistic-keywords into keyword arguments
# for the statistic class; the values are passed on as strings
extra_kwargs = {}
for inputstr in args.statistic_keywords:
    pieces = inputstr.split(':')
    if len(pieces) != 2:
        # anything other than exactly one ':' separator is rejected
        err_txt = ("--statistic-keywords must take input in the "
                   "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... "
                   "Received {}").format(args.statistic_keywords)
        raise ValueError(err_txt)
    extra_kwargs[pieces[0]] = pieces[1]

# Stat class instance to calculate the coinc ranking statistic
stat_class = stat.get_statistic(args.ranking_statistic)
rank_method = stat_class(args.statistic_files, ifos=trigs.ifos, **extra_kwargs)

# Sanity check: a requested time slide interval has to be larger than
# TWOEARTH, the value (0.085 seconds) this script uses for twice the Earth
# crossing time.
TWOEARTH = 0.085
if args.timeslide_interval is None:
    # slide = 0 means don't do timeslides
    args.timeslide_interval = 0
elif args.timeslide_interval <= TWOEARTH:
    # parser.error() prints the usage message and exits the program itself,
    # so it is called directly; there is nothing to 'raise' from it.
    parser.error("The time slide interval should be larger "
                 "than twice the Earth crossing time.")

# By default the selected part of the bank is the contiguous index range
# [tmin, tmax)
template_ids = range(tmin, tmax)
if args.randomize_template_order:
    # Shuffle all template indices with a fixed seed and take the same slice
    # of the shuffled order instead
    seed(0)
    shuffled_ids = numpy.arange(0, num_templates)
    shuffle(shuffled_ids)
    template_ids = shuffled_ids[tmin:tmax]

# 'data' accumulates the output of the coinc finding as lists of array
# chunks: four entries describing each coinc, plus the trigger time and
# trigger id in each ifo
data = {name: [] for name in ('stat', 'decimation_factor',
                              'timeslide_id', 'template_id')}
for ifo in trigs.ifos:
    for field in ('time', 'trigger_id'):
        data['%s/%s' % (ifo, field)] = []

# Main analysis loop. For every selected template: read the single-detector
# triggers of each ifo, find coincidences batch by batch, compute their
# ranking statistic, decimate the time-slid (background) coincs and append
# what is kept to the lists in 'data'.
for tnum in template_ids:
    # Per-template values for each ifo, keyed by ifo name: trigger times,
    # single-detector statistic values and trigger ids
    times_full = {}
    sds_full = {}
    tids_full = {}
    logging.info('Obtaining trigs for template %i ..' % (tnum))
    for i, sngl in zip(trigs.ifos, trigs.singles):
        # select this template in the reader; the returned value is stored
        # as this ifo's trigger ids
        tids_full[i] = sngl.set_template(tnum)
        times_full[i] = sngl['end_time']
        logging.info('%s:%s' % (i, len(tids_full[i])))

        # get single-detector statistic
        sds_full[i] = rank_method.single(sngl)

    # this template is skipped entirely unless every ifo has at least one
    # trigger for it
    mintrigs = min([len(ti) for ti in tids_full.values()])
    if mintrigs == 0:
        logging.info('No triggers in at least one ifo for template %i, '
                     'skipping' % tnum)
        continue

    # Loop over the single triggers and calculate the coincs they can form.
    # The pivot ifo (start0:end0) and fixed ifo (start1:end1) triggers are
    # processed in batches of at most args.batch_singles, covering every
    # pivot-batch / fixed-batch pair.
    start0 = 0
    while start0 < len(sds_full[args.pivot_ifo]):
        start1 = 0
        while start1 < len(sds_full[args.fixed_ifo]):
            end0 = start0 + args.batch_singles
            end1 = start1 + args.batch_singles
            # clip the final, possibly shorter, batch to the array length
            if end0 > len(sds_full[args.pivot_ifo]):
                end0 = len(sds_full[args.pivot_ifo])
            if end1 > len(sds_full[args.fixed_ifo]):
                end1 = len(sds_full[args.fixed_ifo])

            # Copies of the per-ifo dicts in which only the pivot and fixed
            # entries are replaced by the current batch slices; any other
            # ifo keeps its full arrays
            times = times_full.copy()
            times[args.pivot_ifo] = times_full[args.pivot_ifo][start0:end0]
            times[args.fixed_ifo] = times_full[args.fixed_ifo][start1:end1]

            sds = sds_full.copy()
            sds[args.pivot_ifo] = sds_full[args.pivot_ifo][start0:end0]
            sds[args.fixed_ifo] = sds_full[args.fixed_ifo][start1:end1]

            tids = tids_full.copy()
            tids[args.pivot_ifo] = tids_full[args.pivot_ifo][start0:end0]
            tids[args.fixed_ifo] = tids_full[args.fixed_ifo][start1:end1]

            # find the coincs: 'ids' is indexed by ifo and used below to
            # index the batch arrays, 'slide' holds one slide id per coinc
            ids, slide = coinc.time_multi_coincidence(times,
                                                      args.timeslide_interval,
                                                      args.coinc_threshold,
                                                      args.pivot_ifo,
                                                      args.fixed_ifo)
            logging.info('Coincident trigs: %s' % (len(ids[args.pivot_ifo])))

            logging.info('Calculating multi-detector combined statistic')
            # list in ifo order of remaining trigger data
            single_info = [(i, sds[i][ids[i]]) for i in trigs.ifos]
            cstat = rank_method.coinc_multiifo(
                single_info, slide, args.timeslide_interval,
                to_shift=trigs.to_shift,
                time_addition=args.coinc_threshold)

            # index values of the zerolag triggers
            fi = numpy.where(slide == 0)[0]
            # index values of the background triggers
            bi = numpy.where(slide != 0)[0]
            logging.info('%s foreground triggers' % len(fi))
            logging.info('%s background triggers' % len(bi))

            # coincs will be decimated by successive (multiplicative) levels
            # tracked by 'total_factor'
            # bi_dec: indices of the background coincs still kept
            # dec: decimation factor recorded for each entry of bi_dec
            bi_dec = bi.copy()
            dec = numpy.ones(len(bi))
            total_factor = 1
            for decstr in args.loudest_keep_values:
                # each entry has the form THRESH:FACTOR
                thresh, factor = decstr.split(':')
                thresh = float(thresh)
                # throws an error if 'factor' is not the string representation
                # of an integer
                total_factor *= int(factor)

                # triggers not being further decimated
                upper = cstat[bi_dec] >= thresh
                idx_keep = bi_dec[upper]

                # decimate the remaining triggers: below the threshold only
                # coincs whose slide id is a multiple of total_factor survive
                idx = bi_dec[cstat[bi_dec] < thresh]
                idx = idx[slide[idx] % total_factor == 0]

                # coincs at or above the threshold keep the factor they
                # already had; the survivors below it get total_factor
                bi_dec = numpy.concatenate([idx_keep, idx])
                dec = numpy.concatenate([dec[upper],
                                         numpy.repeat(total_factor, len(idx))])

            # indices of everything that is stored: the decimated background
            # followed by all zerolag coincs, which get a factor of 1
            ti = numpy.concatenate([bi_dec, fi]).astype(numpy.uint32)
            dec_fac = numpy.concatenate([dec, numpy.ones(len(fi))])
            logging.info('%s after decimation' % len(ti))

            # temporary storage for decimated trigger ids
            decid = {}
            for ifo in ids:
                decid[ifo] = ids[ifo][ti]
            del ids

            # append this batch's results as one array chunk per key; the
            # chunks are joined after the loop
            for ifo in decid:
                addtime = times[ifo][decid[ifo]]
                addtriggerid = tids[ifo][decid[ifo]]
                data['%s/time' % ifo] += [addtime]
                data['%s/trigger_id' % ifo] += [addtriggerid]
            data['stat'] += [cstat[ti]]
            data['decimation_factor'] += [dec_fac]
            data['timeslide_id'] += [slide[ti]]
            data['template_id'] += [numpy.repeat(tnum, len(ti))]

            start1 += args.batch_singles
        start0 += args.batch_singles

# Join the per-batch chunks of every key into one array each; when nothing
# was collected 'data' keeps its empty lists
if len(data['stat']) > 0:
    data = {name: numpy.concatenate(chunks) for name, chunks in data.items()}

# Optional clustering over the bank: 'cid' holds the indices of the coincs
# to keep. The length is checked again because the joined array can be empty.
if args.cluster_window and len(data['stat']) > 0:
    pivot_times = data['%s/time' % args.pivot_ifo]
    fixed_times = data['%s/time' % args.fixed_ifo]
    cid = coinc.cluster_coincs(data['stat'], pivot_times, fixed_times,
                               data['timeslide_id'], args.timeslide_interval,
                               args.cluster_window)

logging.info('saving coincident triggers')
# The context manager makes sure the output file is flushed and closed, also
# when one of the writes below raises.
with h5py.File(args.output_file, 'w') as f:
    # Datasets are only written when at least one coinc was found
    have_coincs = len(data['stat']) > 0
    if have_coincs:
        for key in data:
            # 'cid' only exists when clustering was requested
            var = data[key][cid] if args.cluster_window else data[key]
            f.create_dataset(key, data=var,
                             compression='gzip',
                             compression_opts=9,
                             shuffle=True)

    # Store coinc segments keyed by detector combination
    key = ''.join(sorted(trigs.ifos))
    if args.legacy_output:
        # Associated segments will all be equal here
        sngl = trigs.singles[0]
        f['segments/H1/start'], f['segments/H1/end'] = sngl.valid
        f['segments/L1/start'], f['segments/L1/end'] = sngl.valid
        f['segments/coinc/start'], f['segments/coinc/end'] = sngl.valid
    else:
        f['segments/%s/start' % key], f['segments/%s/end' % key] = trigs.singles[0].valid

    f.attrs['timeslide_interval'] = args.timeslide_interval
    f.attrs['coinc_time'] = abs(coinc_segs)
    if args.legacy_output:
        f.attrs['detector_1'] = 'H1'
        f.attrs['detector_2'] = 'L1'
        # These two will be equal, so order is irrelevant
        f.attrs['foreground_time1'] = abs(trigs.singles[0].segs)
        f.attrs['foreground_time2'] = abs(trigs.singles[1].segs)
    else:
        f.attrs['num_of_ifos'] = len(args.trigger_files)
        f.attrs['pivot'] = args.pivot_ifo
        f.attrs['fixed'] = args.fixed_ifo
        for i, sngl in zip(trigs.ifos, trigs.singles):
            f.attrs['%s_foreground_time' % i] = abs(sngl.segs)
        f.attrs['ifos'] = ' '.join(sorted(trigs.ifos))

    # num_slides is the largest per-ifo abs(segs) value (the quantity stored
    # above as foreground time) divided by the slide interval and truncated
    # to an integer; it is 0 when time slides are disabled.
    if args.timeslide_interval:
        maxtrigs = max([abs(sngl.segs) for sngl in trigs.singles])
        nslides = int(maxtrigs / args.timeslide_interval)
    else:
        nslides = 0
    f.attrs['num_slides'] = nslides

    # Legacy layout: expose the H1/L1 datasets under the old time1/time2
    # names and drop the per-ifo groups. Those groups only exist when coincs
    # were written, so this step is skipped for an empty result instead of
    # failing with a KeyError.
    if args.legacy_output and have_coincs:
        f['time1'] = f['H1/time'][:]
        f['trigger_id1'] = f['H1/trigger_id'][:]
        f['time2'] = f['L1/time'][:]
        f['trigger_id2'] = f['L1/trigger_id'][:]
        del f['H1']
        del f['L1']

logging.info('Done')
