#!/usr/bin/env python
import h5py, argparse, logging, numpy, numpy.random
from ligo.segments import infinity
from pycbc.events import veto, coinc, stat
import pycbc.conversions as conv
import pycbc.version
from numpy.random import seed, shuffle
from pycbc.io.hdf import ReadByTemplate

# Command-line interface.
# Options declared with nargs='*' and action='append' are collected as a list
# of lists, so they can be given several times, each time with several values.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version", version=pycbc.version.git_verbose_msg)
parser.add_argument("--veto-files", nargs='*', action='append', default=[],
                    help="Optional veto file. Triggers within veto segments "
                         "contained in the file are ignored")
parser.add_argument("--segment-name", nargs='*', action='append', default=[],
                    help="Optional, name of veto segment in veto file")
# The trigger file is opened unconditionally further down, so fail at parse
# time with a clear message rather than later with a None file name.
parser.add_argument("--trigger-file", type=str, required=True,
                    help="File containing single-detector triggers")
parser.add_argument("--template-bank", required=True,
                    help="Template bank file in HDF format")
# produces a list of lists to allow multiple invocations and multiple args
parser.add_argument("--statistic-files", nargs='*', action='append', default=[],
                    help="Files containing ranking statistic info")
parser.add_argument("--ranking-statistic", choices=stat.statistic_dict.keys(),
                    default='newsnr',
                    help="The single-detector ranking statistic to calculate")
parser.add_argument("--statistic-keywords", nargs='*',
                    default=[],
                    help="Provide additional key-word arguments to be sent to "
                         "the statistic class when it is initialized. Should "
                         "be given in format --statistic-keywords "
                         "KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ...")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. Format"
                         " PART/NUM_PARTS")
parser.add_argument("--randomize-template-order", action="store_true",
                    help="Random shuffle templates with fixed seed "
                         "before selecting range to analyze")
parser.add_argument("--cluster-window", type=float,
                    help="Optional, window size in seconds to cluster "
                         "over the bank")
# The output file is always written, so it is required as well.
parser.add_argument("--output-file", required=True,
                    help="File to store the candidate triggers")
# NOTE(review): --batch-singles is parsed but not read anywhere in the
# visible part of this script.
parser.add_argument("--batch-singles", default=5000, type=int,
                    help="Number of single triggers to process at once")
args = parser.parse_args()

# These options were collected as lists of lists (one inner list per
# occurrence on the command line); flatten each into a single plain list.
for opt_name in ('statistic_files', 'segment_name', 'veto_files'):
    nested = getattr(args, opt_name)
    setattr(args, opt_name, [entry for group in nested for entry in group])

# Logging is only configured when --verbose was given
if args.verbose:
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s : %(message)s')


def parse_template_range(num_templates, rangestr):
    """Convert a "PART/NUM_PARTS" string into a half-open template range.

    Parameters
    ----------
    num_templates : int
        Total number of templates in the bank.
    rangestr : str
        String of the form "PART/NUM_PARTS", e.g. "0/1" for the whole bank.

    Returns
    -------
    tmin, tmax : int
        Templates with index tmin <= i < tmax belong to the requested part.

    Raises
    ------
    ValueError
        If rangestr is not two integers separated by a single '/'.
    ZeroDivisionError
        If NUM_PARTS is zero.
    """
    part, pieces = (int(field) for field in rangestr.split('/'))
    # Use exact integer arithmetic. The floating-point form
    # int(num_templates / float(pieces) * part) can round to just below an
    # exact boundary (for example 1 / 49. * 49 < 1), which silently dropped
    # the last template of the bank from the final part.
    tmin = num_templates * part // pieces
    tmax = num_templates * (part + 1) // pieces
    return tmin, tmax

logging.info('Starting...')

# Only the size of the bank is needed here, so close the bank file again
# straight away instead of leaving the handle open for the whole run.
with h5py.File(args.template_bank, "r") as bank_file:
    num_templates = len(bank_file['template_hash'])
tmin, tmax = parse_template_range(num_templates, args.template_fraction_range)
logging.info('Analyzing template %s - %s' % (tmin, tmax-1))

class Trigs(object):
    """Minimal container exposing a ``singles`` list.

    The list starts out empty; this script appends its trigger-file
    reader to it.
    """
    def __init__(self):
        # A fresh list per instance, filled in by the caller
        self.singles = list()

# Open the single-detector trigger file. The template bank, the segment
# names and the veto files from the command line are handed to the reader.
logging.info('Opening trigger file : %s', args.trigger_file)
reader = ReadByTemplate(
    args.trigger_file,
    args.template_bank,
    args.segment_name,
    args.veto_files,
)
trigs = Trigs()
trigs.singles.append(reader)
ifo = reader.ifo

# Segments reported by the reader, plus the same segments passed through
# veto.segments_to_start_end
fg_segs = reader.segs
valid = veto.segments_to_start_end(fg_segs)

# Turn the KEY:VALUE strings given to --statistic-keywords into keyword
# arguments for the statistic class. Each entry must contain exactly one ':'.
extra_kwargs = {}
for inputstr in args.statistic_keywords:
    pair = inputstr.split(':')
    if len(pair) != 2:
        raise ValueError("--statistic-keywords must take input in the form "
                         "KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... "
                         "Received {}".format(args.statistic_keywords))
    extra_kwargs[pair[0]] = pair[1]

# Look up the statistic class by name and instantiate it for this detector
stat_class = stat.get_statistic(args.ranking_statistic)
rank_method = stat_class(args.statistic_files, ifos=[ifo], **extra_kwargs)

# Select the templates this job analyzes. By default that is the contiguous
# range [tmin, tmax); with --randomize-template-order the whole bank is first
# shuffled with a fixed seed and the same slice is taken from the result.
template_ids = range(tmin, tmax)
if args.randomize_template_order:
    seed(0)
    shuffled_ids = numpy.arange(0, num_templates)
    shuffle(shuffled_ids)
    template_ids = shuffled_ids[tmin:tmax]

# 'data' collects the candidate triggers: every key maps to a list that
# receives one array per template with triggers. Trigger times and ids are
# kept under ifo-specific keys.
time_key = '%s/time' % ifo
trigger_id_key = '%s/trigger_id' % ifo
data = {name: [] for name in ('stat', 'decimation_factor', 'timeslide_id',
                              'template_id', time_key, trigger_id_key)}

num_selected = len(template_ids)
for tid_counter, tnum in enumerate(template_ids):
    # Progress report every 500 templates
    if tid_counter % 500 == 0:
        logging.info('Obtaining trigs for template %i out of %i',
                     tid_counter, num_selected)

    tids_full = reader.set_template(tnum)
    times_full = reader['end_time']
    sds_full = rank_method.single(reader)

    if not len(tids_full):
        logging.info('No triggers in template %i, skipping', tnum)
        continue

    num_trigs = len(sds_full)
    data['stat'].append(rank_method.single_multiifo((ifo, sds_full)))
    # All triggers are foreground and therefore not decimated
    data['decimation_factor'].append(numpy.repeat(1, num_trigs))
    # No timeslides performed - all zeros
    data['timeslide_id'].append(numpy.repeat(0, num_trigs))
    data['template_id'].append(numpy.repeat(tnum, num_trigs))
    data[time_key].append(times_full)
    data[trigger_id_key].append(tids_full)

# Stop here if not a single template produced triggers
if len(data['stat']) == 0:
    raise RuntimeError("No triggers in any of the templates - expand template "
                       "range or run for longer")

# Join the per-template pieces into one array per key
for key in data:
    data[key] = numpy.concatenate(data[key])
    

# Optionally cluster the candidates in time over the whole bank. Without
# --cluster-window every candidate is kept unchanged. The selection is
# applied once here, so nothing below needs 'cid'; previously the IFAR step
# indexed with 'cid' even when no clustering had been done, which raised
# a NameError.
if args.cluster_window:
    cid = coinc.cluster_over_time(data['stat'], data['%s/time' % ifo],
                                  args.cluster_window)
    kept = {key: values[cid] for key, values in data.items()}
else:
    kept = data

logging.info('saving clustered triggers')
# The context manager makes sure the output file is closed, also when one of
# the steps below fails.
with h5py.File(args.output_file, 'w') as f:
    for key in kept:
        f.create_dataset('foreground/' + key, data=kept[key],
                         compression='gzip',
                         compression_opts=9,
                         shuffle=True)

    # Store segments, using the start/end pair that was computed from the
    # reader's segments with veto.segments_to_start_end.
    # NOTE(review): this previously unpacked 'reader.valid'; confirm that the
    # precomputed 'valid' pair is what is meant to be stored here.
    f['segments/%s/start' % ifo], f['segments/%s/end' % ifo] = valid

    fg_time = abs(fg_segs)
    f.attrs['foreground_time'] = fg_time
    f.attrs['num_of_ifos'] = 1
    f.attrs['pivot'] = ifo
    f.attrs['fixed'] = ifo
    f.attrs['ifos'] = ifo

    fore_stat = kept['stat']

    logging.info('Assigning IFARs according to X model')
    # Currently a simple "count N louder" model is used because it is easy to
    # implement; it cannot assign a false-alarm rate below one per live time.
    _, fnlouder = coinc.calculate_n_louder(fore_stat, fore_stat,
                                           numpy.ones_like(fore_stat))
    ifar = fg_time / (fnlouder + 1)
    fap = 1 - numpy.exp(-fg_time / ifar)

    f['foreground/ifar'] = conv.sec_to_year(ifar)
    f['foreground/fap'] = fap

# TODO: hierarchical removal is not implemented; the disabled stub was:
#h_iterations = 0
#if args.max_hierarchical_removal != 0:

logging.info('Done')
