#!/usr/bin/env python
import argparse, logging, h5py, numpy as np
from ligo.segments import infinity
from numpy.random import seed, shuffle
from pycbc.events import veto, coinc, stat
import pycbc.conversions as conv
import pycbc.version
from pycbc import io
from pycbc.events import cuts, trigger_fits as trfits
from pycbc.events.veto import indices_outside_times
from pycbc.types.optparse import MultiDetOptionAction
from pycbc import init_logging

# Command-line definition for the script. Options that the script cannot run
# without are marked required so that a missing one is reported as a usage
# error up front rather than as an unrelated failure later on.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action='count')
parser.add_argument("--version", action='version',
                    version=pycbc.version.git_verbose_msg)
# Basic file input options
# nargs=1 stores the value as a one-element list
parser.add_argument("--trigger-files", type=str, nargs=1, required=True,
                    help="File containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
# Options to define the vetoes
# nargs='*' together with action='append' produces a list of lists, to allow
# multiple invocations and multiple args per invocation
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
# Clustering and output options
parser.add_argument('--cluster-window', type=float,
                    help='Window (seconds) during which to keep the trigger '
                         'with the loudest statistic value. '
                         'Default=do not cluster')
parser.add_argument("--output-file", required=True,
                    help="File to store the candidate triggers")
# Option groups contributed by the pycbc.events.stat and pycbc.events.cuts
# modules (presumably including the --statistic-keywords option read later
# on as args.statistic_keywords -- confirm against those modules)
stat.insert_statistic_option_group(parser)
cuts.insert_cuts_option_group(parser)
args = parser.parse_args()

# --trigger-files is declared with nargs=1, so it arrives as a one-element list
trigger_file = args.trigger_files[0]

if (args.veto_files and not args.segment_name) or \
    (args.segment_name and not args.veto_files):
    raise RuntimeError('--veto-files and --segment-name are mutually required')

# Both options are collected as a list of lists (one inner list per
# invocation). Flatten them BEFORE comparing lengths, so that the check counts
# individual files and segment names rather than option invocations;
# otherwise e.g. "--veto-files a b --segment-name x" would slip through.
args.segment_name = sum(args.segment_name, [])
args.veto_files = sum(args.veto_files, [])

# NOTE(review): the two lists are handed to the trigger reader together and
# are presumably paired entry by entry -- confirm against ReadByTemplate.
if len(args.veto_files) != len(args.segment_name):
    raise RuntimeError('--segment-name options are required for each --veto-files')

init_logging(args.verbose)

# Cuts to be applied to triggers and to templates respectively
trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info('Opening trigger file: %s', trigger_file)
trigf = io.HFile(trigger_file, 'r')
# The first top-level key of the trigger file is taken as the detector name
ifo = list(trigf.keys())[0]

# Set up to only load triggers from the templates of interest

def parse_template_range(num_templates, rangestr):
    """Return the half-open index range of one part of the template bank.

    The bank is divided into NUM_PARTS consecutive pieces of (nearly) equal
    size and the bounds of piece number PART (counted from zero) are returned.

    Parameters
    ----------
    num_templates : int
        Total number of templates in the bank.
    rangestr : str
        Range specification in the form "PART/NUM_PARTS". No check is made
        that PART lies within [0, NUM_PARTS).

    Returns
    -------
    tmin : int
        Index of the first template in the part.
    tmax : int
        One past the index of the last template in the part.

    Raises
    ------
    ValueError
        If either field of ``rangestr`` is not an integer.
    IndexError
        If ``rangestr`` contains no '/' separator.
    ZeroDivisionError
        If NUM_PARTS is zero.
    """
    fields = rangestr.split('/')
    part = int(fields[0])
    pieces = int(fields[1])
    num_templates = int(num_templates)
    # Use exact integer arithmetic. Evaluating num / float(pieces) * part in
    # floating point can land just below an integer boundary (for example
    # 1 / 49.0 * 49 == 0.9999999999999999), which int() would then truncate,
    # silently dropping the last template of the bank.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

# Only the size of the bank is needed here, so close the file again as soon
# as it has been read instead of leaving the handle open.
with io.HFile(args.template_bank, "r") as bank_file:
    num_templates = bank_file['template_hash'].size
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s', tmin, tmax - 1)

if args.randomize_template_order:
    # The seed is fixed, so the shuffled order is reproducible from run to run
    seed(0)
    template_ids = np.arange(0, num_templates)
    shuffle(template_ids)
    template_ids = template_ids[tmin:tmax]
else:
    # np.arange should keep an integer dtype even when the range is empty
    template_ids = np.arange(tmin, tmax)

from pycbc.io.hdf import ReadByTemplate
# Reader giving access to the triggers one template at a time; it is also
# given the veto segment names / files and the gating veto windows.
trigs = ReadByTemplate(trigger_file,
                       args.template_bank,
                       args.segment_name,
                       args.veto_files,
                       args.gating_veto_windows)
logging.info("%d triggers in file", trigf[ifo + '/snr'].size)

# Accumulators for the results gathered over all analyzed templates
stat_all = []
trigger_ids_all = []
template_ids_all = []
trigger_times_all = []

# Ranking statistic instance, built from the command-line options
rank_method = stat.get_statistic_from_opts(args, [ifo])

# Apply cuts to templates
template_ids = cuts.apply_template_cuts(
    trigs.bank,
    template_cut_dict,
    statistic=rank_method,
    ifos=[ifo],
    template_ids=template_ids)

# Extra keyword arguments forwarded to the single-detector ranking call.
# They depend only on the command line, so they are parsed once here rather
# than once per template; a malformed entry is therefore reported before any
# trigger is processed.
extra_kwargs = {}
for inputstr in args.statistic_keywords:
    try:
        key, value = inputstr.split(':')
        extra_kwargs[key] = value
    except ValueError:
        err_txt = "--statistic-keywords must take input in the " \
                  "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... " \
                  "Received {}".format(args.statistic_keywords)
        raise ValueError(err_txt)

for tnum in template_ids:
    # Select this template in the reader; the return value is used below as
    # the array of trigger ids belonging to the template
    tids_uncut = trigs.set_template(tnum)

    # Keep only the triggers that pass the trigger cuts
    trigger_keep_ids = cuts.apply_trigger_cuts(trigs, trigger_cut_dict)
    tids_full = tids_uncut[trigger_keep_ids]
    logging.info('%s:%s', tnum, len(tids_uncut))
    if len(tids_full) < len(tids_uncut):
        logging.info("%s triggers cut",
                     len(tids_uncut) - len(tids_full))

    # Nothing left to rank for this template
    if not tids_full.size:
        continue

    # Compute the ranking statistic of the triggers of the current template
    sds = rank_method.single(trigs)
    stat_t = rank_method.rank_stat_single((ifo, sds),
                                          **extra_kwargs)
    trigger_times = sds['end_time']
    if args.cluster_window:
        logging.info("Clustering")
        # cid indexes the triggers that are kept after clustering
        cid = coinc.cluster_over_time(stat_t, trigger_times,
                                      args.cluster_window)
        stat_t = stat_t[cid]
        tids_full = tids_full[cid]
        trigger_times = trigger_times[cid]

    # Append the results for this template, one entry per kept trigger
    trigger_ids_all += list(tids_full)
    template_ids_all += list(tnum * np.ones_like(tids_full))
    trigger_times_all += list(trigger_times)
    stat_all += list(stat_t)

# Datasets to be written to the output file, keyed by their path in the file.
# The decimation factor is one and the timeslide id zero for every trigger.
data = {"stat": stat_all,
        "decimation_factor": np.ones_like(stat_all),
        "timeslide_id": np.zeros_like(stat_all),
        "template_id": template_ids_all,
        "%s/time" % ifo: trigger_times_all,
        "%s/trigger_id" % ifo: trigger_ids_all}

logging.info("saving triggers")
# The context manager guarantees that the output file is flushed and closed,
# also when writing fails part way through.
with io.HFile(args.output_file, 'w') as f:
    for key, values in data.items():
        f.create_dataset(key, data=values,
                         compression="gzip",
                         compression_opts=9,
                         shuffle=True)
    # Store segments
    f['segments/%s/start' % ifo], f['segments/%s/end' % ifo] = trigs.valid
    fg_segs = veto.start_end_to_segments(*trigs.valid)
    # presumably abs() of the segment list is its total duration -- confirm
    fg_time = abs(fg_segs)
    # The same time is recorded as both foreground and background time
    f.attrs['foreground_time'] = fg_time
    f.attrs['background_time'] = fg_time
    f.attrs['num_of_ifos'] = 1
    f.attrs['pivot'] = ifo
    f.attrs['fixed'] = ifo
    f.attrs['ifos'] = ifo
    f.attrs['timeslide_interval'] = 0

logging.info("Done")
