#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import h5py, numpy as np, argparse
import logging
import pycbc

# Command-line interface: the daily trigger-fit files to combine, the
# detectors (ifos) to collect, the percentile used for the conservative
# combination and the output file location.
parser = argparse.ArgumentParser(usage="",
    description="Combine fitting parameters from several different files")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trfits-files", nargs="+", required=True,
                    help="Files containing daily trigger fits")
parser.add_argument("--conservative-percentile", type=int, default=95,
                    help="What percentile to use for the conservative "
                         "combined fit. Integer in range 50-99. Default=95")
parser.add_argument("--output", required=True,
                    help="Output file for combined fit parameters")
parser.add_argument("--ifos", required=True, nargs="+",
                    help="list of ifos to collect info for")

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Enforce sensible limits on the arguments: parser.error() prints the
# message and exits, so nothing below runs with an out-of-range percentile.
if not 50 <= args.conservative_percentile <= 99:
    parser.error("--conservative-percentile must be between 50 and 99, "
                 "otherwise it is either not a percentile, or not "
                 "conservative.")

# Per-detector accumulators. Every input file appends exactly one entry to
# each list, so position i in any list refers to the i-th file read.
counts_all, alphas_all, live_times = {}, {}, {}
for ifo in args.ifos:
    counts_all[ifo] = []
    alphas_all[ifo] = []
    live_times[ifo] = []

analysis_dates = []

# One entry per input file: the latest trigger end time found in that file
trigger_file_times = []

# The first file acts as the reference: its binning and fit settings are
# remembered here so that the other files can be compared against them.
with h5py.File(args.trfits_files[0], 'r') as fit_f0:
    # Lower and upper edges of the bins
    bl, bu = fit_f0['bins_lower'][:], fit_f0['bins_upper'][:]

    # Settings which must be shared by all files for the fit coefficients
    # to be comparable
    sngl_rank, fit_thresh, fit_func = (
        fit_f0.attrs[key]
        for key in ('sngl_ranking', 'fit_threshold', 'fit_function')
    )

n_files = len(args.trfits_files)
logging.info("Checking through %d files", n_files)

# Read every fit file in turn, checking it is comparable with the reference
# file and collecting the per-detector live time, counts and fit
# coefficients, plus the time of the last trigger in the file.
for f in args.trfits_files:

    # The context manager closes the file even if one of the consistency
    # checks or dataset reads below raises.
    with h5py.File(f, 'r') as fits_f:
        # Check that the file uses the same setup as file 0, to make sure
        # all coefficients are comparable. Explicit exceptions are used
        # rather than assert so the checks survive `python -O` and report
        # which file and which setting disagree.
        for attr_name, expected in (('sngl_ranking', sngl_rank),
                                    ('fit_threshold', fit_thresh),
                                    ('fit_function', fit_func)):
            if fits_f.attrs[attr_name] != expected:
                raise ValueError(
                    f"{f}: attribute '{attr_name}' is "
                    f"{fits_f.attrs[attr_name]!r}, but "
                    f"{args.trfits_files[0]} has {expected!r}"
                )
        for bin_name, expected in (('bins_lower', bl), ('bins_upper', bu)):
            # array_equal also copes with a different number of bins, where
            # an elementwise == would not give a per-element answer
            if not np.array_equal(fits_f[bin_name][:], expected):
                raise ValueError(
                    f"{f}: dataset '{bin_name}' does not match "
                    f"{args.trfits_files[0]}"
                )

        # Get the time of the last trigger in the trigger_fits file. This
        # stays at zero if none of the requested ifos are in the file.
        gps_last = 0
        for ifo in args.ifos:
            if ifo not in fits_f:
                continue
            trig_times = fits_f[ifo]['triggers']['end_time'][:]
            gps_last = max(gps_last, trig_times.max())
        trigger_file_times.append(gps_last)

        for ifo in args.ifos:
            if ifo not in fits_f:
                # Detector absent from this file: record zero live time and
                # -1 placeholders so every list keeps one entry per file
                live_times[ifo].append(0)
                counts_all[ifo].append(-1 * np.ones_like(bl))
                alphas_all[ifo].append(-1 * np.ones_like(bl))
                continue

            # Read the coefficients once and reuse them for the nan check
            fit_coeff = fits_f[ifo + '/fit_coeff'][:]
            live_times[ifo].append(fits_f[ifo].attrs['live_time'])
            counts_all[ifo].append(fits_f[ifo + '/counts'][:])
            alphas_all[ifo].append(fit_coeff)
            if np.isnan(fit_coeff).any():
                logging.info("nan in %s, %s", f, ifo)
                logging.info(fit_coeff)

# Set up the date array, this is stored as an offset from the start date of
# the combination to the end of the trigger_fits file.
# This allows for missing or empty days

trigger_file_times = np.array(trigger_file_times)
# A stable sort keeps files with identical times in their input order
ad_order = np.argsort(trigger_file_times, kind='stable')
start_date_n = trigger_file_times[ad_order[0]]
ad = trigger_file_times[ad_order] - start_date_n

# The dates above are in chronological order, so put the per-file data into
# the same order. Otherwise the per-file live times, counts and alphas would
# be paired with the wrong dates whenever the input files are not supplied
# in time order. New dicts are built from the old ones so that each list is
# permuted exactly once.
live_times = {ifo: [live_times[ifo][i] for i in ad_order]
              for ifo in args.ifos}
counts_all = {ifo: [counts_all[ifo][i] for i in ad_order]
              for ifo in args.ifos}
alphas_all = {ifo: [alphas_all[ifo][i] for i in ad_order]
              for ifo in args.ifos}

# Get the counts and alphas sorted by bin rather than by date: after the
# transposition, entry b holds one value per file for bin b
counts_bin = {ifo: list(zip(*counts_all[ifo])) for ifo in args.ifos}
alphas_bin = {ifo: list(zip(*alphas_all[ifo])) for ifo in args.ifos}

# Per-bin output arrays, filled in when the combined values are calculated
alphas_out = {ifo: np.zeros(len(alphas_bin[ifo])) for ifo in args.ifos}
counts_out = {ifo: np.full(len(counts_bin[ifo]), np.inf)
              for ifo in args.ifos}
cons_alphas_out = {ifo: np.zeros(len(alphas_bin[ifo])) for ifo in args.ifos}
cons_counts_out = {ifo: np.full(len(alphas_bin[ifo]), np.inf)
                   for ifo in args.ifos}

# Combine the per-file values into mean and conservative fit parameters for
# each detector and bin, and write everything to the output file.
logging.info("Writing results")

# Overall mean alpha per detector, also stored as an attribute in the output
save_allmeanalpha = {}

# The context manager makes sure the output file is flushed and closed even
# if something below raises.
with h5py.File(args.output, 'w') as fout:
    fout.attrs['fit_threshold'] = fit_thresh
    fout.attrs['conservative_percentile'] = args.conservative_percentile
    fout.attrs['ifos'] = args.ifos
    # All of the lower bin edges, closed off by the last upper edge
    fout['bins_edges'] = list(bl) + [bu[-1]]
    fout['fits_dates'] = ad + start_date_n

    for ifo in args.ifos:
        logging.info(ifo)
        fout_ifo = fout.create_group(ifo)
        fout_ifo.attrs['live_time'] = sum(live_times[ifo])

        fout_ifo['separate_fits/live_times'] = np.array(live_times[ifo])
        fout_ifo['separate_fits/date'] = ad + start_date_n

        # NOTE(review): files in which this ifo was absent contribute their
        # -1 placeholder counts and alphas to the statistics below - confirm
        # whether those entries should be masked out first.
        # bu and bl stay in the zip so the number of bins processed is still
        # limited by the bin edges, although their values are not needed.
        for counter, (a, c, _upper, _lower) in enumerate(
                zip(alphas_bin[ifo], counts_bin[ifo], bu, bl)):
            # One value per input file for this bin
            a = np.array(a)
            c = np.array(c)

            # Mean alpha is sum(c) / sum(c / a), i.e. the count-weighted
            # harmonic mean of the per-file alphas
            invalphan = c / a
            alphas_out[ifo][counter] = c.mean() / invalphan.mean()
            counts_out[ifo][counter] = c.sum()

            # Conservative values: a low percentile of alpha together with
            # a high percentile of the counts, the latter scaled by the
            # number of files so it is comparable to the summed counts
            cons_alphas_out[ifo][counter] = np.percentile(
                a, 100 - args.conservative_percentile)
            cons_count = np.percentile(c, args.conservative_percentile)
            cons_counts_out[ifo][counter] = cons_count * len(c)

            fout_ifo[f'separate_fits/bin_{counter:d}/fit_coeff'] = a
            fout_ifo[f'separate_fits/bin_{counter:d}/counts'] = c

        # Output the mean average values
        fout_ifo['mean/fit_coeff'] = alphas_out[ifo]
        fout_ifo['mean/counts'] = counts_out[ifo]

        # Output the conservative values
        fout_ifo['conservative/fit_coeff'] = cons_alphas_out[ifo]
        fout_ifo['conservative/counts'] = cons_counts_out[ifo]

        # Take an average over the bins as a summary value
        overall_invalphan = counts_out[ifo] / alphas_out[ifo]
        overall_meanalpha = counts_out[ifo].mean() / overall_invalphan.mean()
        save_allmeanalpha[ifo] = overall_meanalpha

        # For the fixed version, we just set the counts to 1 and the fit
        # coefficient to 0 in every bin
        fout_ifo['fixed/counts'] = [1] * len(counts_out[ifo])
        fout_ifo['fixed/fit_coeff'] = [0] * len(alphas_out[ifo])

        # Add some useful info to the output file
        fout_ifo.attrs['mean_alpha'] = save_allmeanalpha[ifo]
        fout_ifo.attrs['total_counts'] = counts_out[ifo].sum()
