#!/usr/bin/env python
""" Merge and rerank clustered DQ data
"""
import logging, argparse, numpy, h5py
import pycbc
from pycbc.types.timeseries import load_timeseries
from os import path
from pycbc.version import git_verbose_msg as version

# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version', version=version)
parser.add_argument('--verbose', action="store_true", default=False)
parser.add_argument('--no-low-dq-cutoff', action="store_true",
                    help="Allow dq values less than 0. "
                         "Otherwise values less than 0 are set to 0")
parser.add_argument("--ifo", type=str, required=True,
                    help="interferometer to analyze")
parser.add_argument("--max-stat", type=float, default=8,
                    help="maximum allowed dq value. "
                         "Only used if --rate-file is not used. "
                         "[default=8]")
parser.add_argument("--input-file", required=True, nargs='+',
                    help="List of files containing the dq timeseries")
parser.add_argument("--input-channel", required=False, type=str,
                    help='name of channel to read in from '
                         'provided input file. Required if '
                         'provided file is .hdf5 format')

# Exactly one of --rate-file / --bank-file has to be given: when --rate-file
# is absent the script opens --bank-file unconditionally, so omitting both
# can only fail later (after all input data has been read). Making the group
# required turns that into an immediate, readable usage error.
rate_group = parser.add_mutually_exclusive_group(required=True)
rate_group.add_argument("--rate-file", type=str, default=None,
                        help="File containing the mapping from each "
                             "dq bin to the trigger rate")
rate_group.add_argument("--bank-file", type=str, default=None,
                        help="File containing the template bank")

parser.add_argument("--dq-type", type=str, default='dq',
                    help="Name of dq type to be included in output file "
                         "metadata [default='dq']")
parser.add_argument("--output-file", required=True,
                    help="name of output file")

args = parser.parse_args()
pycbc.init_logging(args.verbose)

ifo = args.ifo

# Gather the dq samples and their sample times from every input file into
# two flat arrays, appended in the order the files were given.
log_like = numpy.array([])
times = numpy.array([])
for filename in args.input_file:
    logging.info('Reading file %s...', filename)
    dq_series = load_timeseries(filename, group=args.input_channel)
    times = numpy.concatenate([times, dq_series.sample_times])
    log_like = numpy.concatenate([log_like, dq_series[:]])
    # Drop the reference to the loaded series before reading the next file.
    del dq_series

len_log_like = len(log_like)

# Reorder both arrays by ascending dq value so that entry i of `times`
# still belongs to entry i of `log_like_vals`.
sort_order = numpy.argsort(log_like)
log_like_vals = numpy.array(log_like[sort_order])
times = times[sort_order]

# Per-bin results, keyed by bin name:
#   log_like_dict -> one dq statistic value per (sorted) input sample
#   percent_dict  -> log rate of each percentile bin (rate-file mode only)
log_like_dict = {}
percent_dict = {}

if args.rate_file is not None:
    # Empirical mode: the rate file supplies, for every bin, one rate per
    # equal-width percentile interval of the dq values.
    locs_dict = {}
    with h5py.File(args.rate_file, 'r') as f:
        bin_names = f.attrs['names'][:]
        for bin_name in bin_names:
            logging.info('Processing bin %s...', bin_name)
            rates = f['%s/rates' % bin_name][:]
            locs_dict[bin_name] = f['%s/locs' % bin_name][:]

            if not args.no_low_dq_cutoff:
                # Clip rates below 1 so their logarithm is never negative.
                # Cast to float64 first to keep the precision of the log.
                rates = numpy.maximum(
                    numpy.asarray(rates, dtype=numpy.float64), 1.)

            log_rates = numpy.log(rates)
            percent_dict[bin_name] = log_rates

            # Upper edge of each of the n_bins percentile intervals of the
            # sorted dq values (the 0th percentile is dropped).
            n_bins = len(log_rates)
            percentiles = numpy.linspace(0, 100, n_bins + 1)
            dq_percentiles = numpy.percentile(log_like_vals, percentiles)[1:]

            # Index of the first interval whose upper edge is >= the sample,
            # i.e. n_bins minus the number of edges that are >= the sample.
            rates_bin = numpy.searchsorted(dq_percentiles, log_like_vals,
                                           side='left')
            log_like_dict[bin_name] = log_rates[rates_bin]

else:
    # Use analytic model from Godwin et al. arXiv:2010.15282
    bin_names = ['all_bin']
    with h5py.File(args.bank_file, 'r') as bank:
        locs_dict = {'all_bin': numpy.arange(0, len(bank['mass1'][:]), 1)}

    # Percentile rank (0 <= rank < 100) of every sample in sorted order.
    analytic_stat = numpy.arange(len_log_like) * 100. / len_log_like

    # Samples below the median get 0; samples between the median and the
    # rank at which the logit reaches --max-stat get logit(rank); everything
    # above that rank is capped at --max-stat.
    p_min = int(len_log_like / 2.)
    p_max = int(len_log_like / (1. + numpy.exp(-args.max_stat)))

    mid_ranks = analytic_stat[p_min:p_max]
    analytic_stat[p_min:p_max] = numpy.log(mid_ranks / (100. - mid_ranks))
    analytic_stat[p_max:] = args.max_stat
    analytic_stat[:p_min] = 0
    log_like_dict['all_bin'] = analytic_stat


# Write the reranked dq reference file. Everything is stored under a group
# named after the interferometer; the bin names go in the file attributes.
with h5py.File(args.output_file, 'w') as f:
    f.attrs['names'] = bin_names
    for bin_name in bin_names:
        f[ifo + '/dq_vals/' + bin_name] = numpy.array(log_like_dict[bin_name],
                                                      dtype=numpy.float32)
        # percent_dict is only filled when --rate-file is used; with
        # --bank-file there are no per-percentile rates, so skip the dataset
        # instead of raising KeyError and leaving a partial output file.
        if bin_name in percent_dict:
            f[ifo + '/dq_percentiles/' + bin_name] = numpy.array(
                percent_dict[bin_name], dtype=numpy.float32)
        # NOTE(review): locs are stored as float32, as before; presumably
        # they are integer template indices -- confirm with readers of this
        # file before changing the on-disk dtype.
        f[ifo + '/locs/' + bin_name] = numpy.array(locs_dict[bin_name],
                                                   dtype=numpy.float32)
    f[ifo + '/times'] = numpy.array(times, dtype=numpy.uint32)
    f[ifo + '/dq_input_vals'] = numpy.array(log_like_vals, dtype=numpy.float32)
    f.attrs.create('stat', data=ifo+'-'+args.dq_type+'-dq_ts_reference')

logging.info('Done!')

