#!/usr/bin/env python
""" Bin triggers by their dq value and calculate trigger rates in each bin
"""
import logging
import argparse
import pycbc
import pycbc.events
from pycbc.events import stat as pystat
from pycbc.types.timeseries import load_timeseries
import numpy as np
import h5py as h5
from pycbc.version import git_verbose_msg as version

# ---- Command line interface -------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--version',
    action='version',
    version=version,
)
parser.add_argument(
    '--verbose',
    action='store_true',
)

# Input selection: which detector, which triggers, which DQ data.
parser.add_argument(
    '--ifo',
    type=str,
    required=True,
)
parser.add_argument(
    '--trig-file',
    required=True,
)
parser.add_argument(
    '--stat-threshold',
    type=float,
    help="Only consider triggers with statistic value "
         "above this threshold",
)
parser.add_argument(
    '--dq-file',
    required=True,
    nargs='+',
)
parser.add_argument(
    '--dq-channel',
    required=False,
    type=str,
    help='name of channel to read in from '
         'provided dq file. Required if '
         'provided file is .hdf5 format',
)

# Template bank and the optional split of it into background bins.
parser.add_argument(
    '--bank-file',
    help='hdf format template bank file',
    required=True,
)
parser.add_argument(
    '--background-bins',
    nargs='+',
    help='list of background bin format strings',
)

parser.add_argument(
    '--output-file',
    required=True,
)

# Pruning of the loudest triggers before the rates are computed.
parser.add_argument(
    '--prune-number',
    type=int,
    default=0,
    help="Number of loudest events to remove from each split "
         "histogram, default 0",
)
parser.add_argument(
    '--prune-window',
    type=float,
    default=0.1,
    help="Time (s) to remove all triggers around a trigger "
         "which is loudest in each split, default 0.1s",
)

# Ranking-statistic options are supplied by pycbc.events.stat.
pystat.insert_statistic_option_group(
    parser,
    default_ranking_statistic='single_ranking_only',
)

args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Build locs_dict (bin name -> array of bank indices belonging to that bin)
# and locs_names (the bin names, in the order they will be processed).
with h5.File(args.bank_file, 'r') as bank:
    if args.background_bins:
        logging.info('Sorting bank into bins...')
        # Read every parameter fully into memory with [:] so the binning
        # helper receives plain arrays for all of them, including f_lower,
        # instead of a dataset handle tied to the open file.
        data = {'mass1': bank['mass1'][:], 'mass2': bank['mass2'][:],
                'spin1z': bank['spin1z'][:], 'spin2z': bank['spin2z'][:],
                'f_lower': bank['f_lower'][:]}
        locs_dict = pycbc.events.background_bin_from_string(
            args.background_bins, data)
        del data
        # The bin name is the part of each bin string before the first ':'.
        locs_names = [b.split(':')[0] for b in args.background_bins]
    else:
        # No bins requested: a single bin containing every template. Only
        # the dataset length is needed, so the data itself is not read.
        locs_dict = {'all_bin': np.arange(len(bank['mass1']))}
        locs_names = ['all_bin']

logging.info('Reading trigger file...')
ifo = args.ifo
# Compare against None rather than relying on truthiness, so that an
# explicitly requested threshold of 0 is still applied.
apply_threshold = args.stat_threshold is not None
with h5.File(args.trig_file, 'r') as trig_file:
    trig_times = trig_file[ifo+'/end_time'][:]
    trig_ids = trig_file[ifo+'/template_id'][:]

    # The ranking statistic is only needed for thresholding and/or pruning.
    if apply_threshold or args.prune_number > 0:
        logging.info('Calculating stat and filtering...')
        rank_method = pystat.get_statistic_from_opts(args, [ifo])
        stat = rank_method.get_sngl_ranking(trig_file[ifo])
        if apply_threshold:
            # Keep only the triggers at or above the requested threshold.
            abovethresh = stat >= args.stat_threshold
            trig_ids = trig_ids[abovethresh]
            trig_times = trig_times[abovethresh]
            stat = stat[abovethresh]
            del abovethresh
        if args.prune_number < 1:
            # stat is only kept alive for the pruning step further down.
            del stat

# Integer version of the trigger times, used later to look up the DQ bin
# of each trigger.
trig_times_int = trig_times.astype('int')

# Gather the DQ values and their sample times from every input file.
# The pieces are collected in lists and joined once at the end, so the
# already-loaded data is not re-copied for every additional file. Each list
# starts with an empty float array so the result dtype matches what joining
# onto np.array([]) would give.
dq_logl_chunks = [np.array([])]
dq_times_chunks = [np.array([])]

for filename in args.dq_file:
    logging.info('Reading DQ file %s...', filename)
    dq_data = load_timeseries(filename, group=args.dq_channel)
    dq_logl_chunks.append(dq_data[:])
    dq_times_chunks.append(dq_data.sample_times)
    del dq_data

dq_logl = np.concatenate(dq_logl_chunks)
del dq_logl_chunks
dq_times = np.concatenate(dq_times_chunks)
del dq_times_chunks

# todo: make this configurable
# Width of each DQ bin, expressed as a percentage of all DQ samples.
percent_bin = 0.5
n_bins = int(100./percent_bin)
percentiles = np.linspace(0, 100, n_bins+1)
# Upper edge (in DQ value) of each bin; the 0th percentile is dropped, so
# the last edge is the maximum DQ value.
dq_percentiles = np.percentile(dq_logl, percentiles)[1:]

# seconds_bin tells which bin each DQ sample ends up in: the number of bin
# upper edges that lie strictly below the sample's value. The edges are
# sorted, so for NaN-free data a left-sided binary search gives exactly
# n_bins - count(edges >= value) without a Python-level loop over samples.
seconds_bin = np.searchsorted(dq_percentiles, dq_logl,
                              side='left').astype('int')
del dq_percentiles

# bin_times tells how many DQ samples end up in each bin (the amount of
# time per bin, presumably in seconds for 1 Hz DQ data -- TODO confirm).
# The slice drops any index outside 0..n_bins-1.
bin_times = np.bincount(seconds_bin,
                        minlength=n_bins)[:n_bins].astype('float')
full_time = float(len(seconds_bin))
# Bins that contain some time; only these can be given a rate.
times_nz = (bin_times > 0)
del dq_logl

# create a dict to look up the dq percentile (lower edge of the bin, as a
# fraction between 0 and 1) at any sample time
dq_percentiles_time = dict(zip(dq_times, seconds_bin*percent_bin/100))
del dq_times

# Optionally remove the loudest triggers of each bin, together with every
# trigger (in any bin) within prune_window of them. Removal is recorded by
# zeroing the statistic; triggers with zero statistic are dropped at the end.
if args.prune_number > 0:
    for bin_name in locs_names:
        logging.info('Processing bin %s...', bin_name)
        bin_locs = locs_dict[bin_name]
        # Select this bin's triggers once; boolean indexing yields copies,
        # so trig_stats_bin is tracked separately from stat below.
        in_bin = np.isin(trig_ids, bin_locs)
        trig_times_bin = trig_times[in_bin]
        trig_stats_bin = stat[in_bin]
        del in_bin

        if len(trig_stats_bin) == 0:
            # Nothing to prune here; argmax on an empty array would raise.
            continue

        for _ in range(args.prune_number):
            max_stat_arg = np.argmax(trig_stats_bin)
            loudest_time = trig_times_bin[max_stat_arg]
            # Indices into the full trigger list and into this bin's
            # triggers that lie within the window around the loudest one.
            remove = np.nonzero(abs(loudest_time - trig_times)
                                < args.prune_window)[0]
            remove_inbin = np.nonzero(abs(loudest_time - trig_times_bin)
                                      < args.prune_window)[0]
            stat[remove] = 0
            trig_stats_bin[remove_inbin] = 0
    # Keep only the triggers whose statistic is still non-zero.
    keep = np.nonzero(stat)[0]
    trig_times_int = trig_times_int[keep]
    trig_ids = trig_ids[keep]
    del stat
    del keep

del trig_times

# For every bin: histogram its triggers over the DQ percentile bins, turn the
# counts into rates relative to the bin's overall mean trigger rate, and store
# the result alongside the template ids of the bin.
with h5.File(args.output_file, 'w') as f:
    # Histogram edges as fractions (0..1), matching dq_percentiles_time.
    hist_edges = percentiles / 100

    for name in locs_names:
        template_locs = locs_dict[name]
        in_bin = np.isin(trig_ids, template_locs)
        # DQ percentile at the (integer) time of each trigger in this bin.
        trig_percentile = np.array(
            [dq_percentiles_time[t] for t in trig_times_int[in_bin]])
        n_trigs = len(trig_percentile)
        logging.info('Processing %d triggers...', n_trigs)

        counts, _ = np.histogram(trig_percentile, bins=hist_edges)
        counts = counts.astype('float')

        # Triggers per unit time in each DQ bin; bins holding no time stay 0.
        rates = np.zeros(len(bin_times))
        np.divide(counts, bin_times, out=rates, where=times_nz)

        # Normalise by the mean rate over all time when there is one.
        mean_rate = n_trigs / full_time
        if mean_rate > 0.:
            rates /= mean_rate

        logging.info('Writing rates to output file %s...', args.output_file)
        grp = f.create_group(name)
        grp['rates'] = rates
        grp['locs'] = template_locs

    f.attrs['names'] = locs_names

logging.info('Done!')
