#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import h5py, numpy as np, argparse
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import re, datetime, logging
import pycbc

# Command-line interface.  The usage line is left for argparse to generate:
# passing an explicit empty string blanks it out in --help and error output.
parser = argparse.ArgumentParser(
    description="Combine daily trigger fits into overall fit coefficients "
                "and counts per ifo, optionally plotting them against time")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information")
parser.add_argument("--trfits-files", nargs="+", required=True,
                    help="Files containing daily trigger fits")
parser.add_argument("--output-file", required=True,
                    help="Output hdf file for final results")
parser.add_argument("--output-plot-file-format",
                    help="Output file format for alphas vs time plots, needs to "
                         "contain '{}' as marker for where ifo string is included")
parser.add_argument("--ifos", required=True, nargs="+",
                    help="List of ifos to collect info for")

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Line colours for the per-bin curves, looked up by bin number
colours = [
    'r',
    (1.0, 0.6, 0),
    'y',
    'g',
    'c',
    'b',
    'm',
    'k',
    (0.8, 0.25, 0),
    (0.25, 0.8, 0),
]

# Input files, in the order they were given on the command line
files = args.trfits_files

# Accumulators filled while reading the input files: the dictionaries are
# keyed by ifo, and analysis_dates collects one date value per file
counts_all = {ifo: [] for ifo in args.ifos}
alphas_all = {ifo: [] for ifo in args.ifos}
live_times = {ifo: [] for ifo in args.ifos}
analysis_dates = []

# The first file supplies the bin boundaries.  Its handle is kept open
# because its attributes are read again further down the script.
fit_f0 = h5py.File(files[0], 'r')
bl, bu = (fit_f0[dataset][:] for dataset in ('bins_lower', 'bins_upper'))

# Reference date from which analysis dates are counted, in whole days
DATE_EPOCH = datetime.date(2000, 1, 1)

# This is the regex to match a date in format YYYY_MM_DD
# FutureWarning: this will become obsolete in the year 3000
# It has to be a raw string so that \d reaches the regex engine instead of
# being treated as an (invalid) string escape sequence.
DATE_REGEX = re.compile(r'([12]\d{3}_(0[1-9]|1[0-2])_(0[1-9]|[12]\d|3[01]))')


def _days_since_epoch(date_str):
    """Return the number of days from DATE_EPOCH to the given date.

    Parameters
    ----------
    date_str :
        Date laid out as four year digits, one separator character, two
        month digits, one separator character and two day digits, for
        example '2020_05_17'.  Anything after the tenth character is
        ignored.

    Returns
    -------
    int
        Whole days between DATE_EPOCH and the date.

    Raises
    ------
    ValueError
        If the fields are not numeric or do not form a valid date.
    """
    year = int(date_str[:4])
    month = int(date_str[5:7])
    day = int(date_str[8:10])
    return (datetime.date(year, month, day) - DATE_EPOCH).days


# Read every input file, collecting one date per file and, for each ifo,
# one entry per file of live time, counts and fit coefficients.
for f in files:
    with h5py.File(f, 'r') as fits_f:
        # Work out the analysis date, preferring what the file itself says
        # and falling back on a date embedded in the file name.
        if 'days_since_epoch' in fits_f.attrs:
            # NOTE(review): presumably counted from the same 2000-01-01
            # epoch used for the other branches - confirm against the
            # code which writes this attribute.
            analysis_dates.append(fits_f.attrs['days_since_epoch'])
        elif 'analysis_date' in fits_f.attrs:
            analysis_dates.append(
                _days_since_epoch(fits_f.attrs['analysis_date']))
        else:
            m = DATE_REGEX.search(f)
            if m is None:
                raise ValueError("Cannot determine the analysis date of "
                                 + f + ": no date attribute in the file "
                                 "and no YYYY_MM_DD date in its name")
            analysis_dates.append(_days_since_epoch(m.group(0)))

        for ifo in args.ifos:
            if ifo not in fits_f:
                # Keep every per-file list the same length as the list of
                # files: record zero live time and -1 placeholders, which
                # are discarded later by the positive-coefficient check.
                # The placeholder length comes from the bin definitions so
                # that this also works when the first file lacks the ifo.
                live_times[ifo].append(0)
                counts_all[ifo].append(np.full(len(bl), -1))
                alphas_all[ifo].append(np.full(len(bl), -1.0))
                logging.info(f + " has no " + ifo + " triggers")
                continue
            fit_coeff = fits_f[ifo + '/fit_coeff'][:]
            live_times[ifo].append(fits_f[ifo].attrs['live_time'])
            counts_all[ifo].append(fits_f[ifo + '/counts'][:])
            alphas_all[ifo].append(fit_coeff)
            if np.isnan(fit_coeff).any():
                logging.info("nan in " + f + ", " + ifo)
                logging.info(fit_coeff)

# Sort by date.  A stable sort keeps files sharing a date in the order they
# were given on the command line.
ad_order = np.argsort(np.array(analysis_dates), kind='stable')

# Put the per-file records in the same (date) order as ad, so that entry i
# of each list belongs to date ad[i].
for ifo in args.ifos:
    counts_all[ifo] = [counts_all[ifo][i] for i in ad_order]
    alphas_all[ifo] = [alphas_all[ifo][i] for i in ad_order]
    live_times[ifo] = [live_times[ifo][i] for i in ad_order]

# Dates relative to the earliest one
ad = np.array(analysis_dates)[ad_order] - np.array(analysis_dates)[ad_order[0]]
start_date_n = analysis_dates[ad_order[0]]
# Values read from HDF5 attributes are numpy scalars, and timedelta does not
# accept numpy integers, so convert to a plain float first.
start_date_dt = DATE_EPOCH + datetime.timedelta(days=float(start_date_n))

start_date = "{:04d}_{:02d}_{:02d}".format(start_date_dt.year,
                                           start_date_dt.month,
                                           start_date_dt.day)

# Transpose from one array of bins per file to one tuple of per-file values
# per bin
counts_bin = {ifo: list(zip(*counts_all[ifo])) for ifo in args.ifos}
alphas_bin = {ifo: list(zip(*alphas_all[ifo])) for ifo in args.ifos}

# Combined results, one value per bin for each ifo
alphas_out = {ifo: np.zeros(len(alphas_bin[ifo])) for ifo in args.ifos}
counts_out = {ifo: np.zeros(len(counts_bin[ifo])) for ifo in args.ifos}
q05_alphas_out = {ifo: np.zeros(len(alphas_bin[ifo])) for ifo in args.ifos}
q95_counts_out = {ifo: np.zeros(len(alphas_bin[ifo])) for ifo in args.ifos}

# Create the output file and store the metadata shared by all ifos
fout = h5py.File(args.output_file, 'w')
out_attrs = fout.attrs
# NOTE(review): ad is measured relative to its own first entry, so
# 'start_date' is always written as zero and 'end_date' as the length of
# the span - confirm this is what readers of the file expect.
out_attrs['start_date'] = ad[0]
out_attrs['end_date'] = ad[-1]
out_attrs['fit_threshold'] = fit_f0.attrs['fit_threshold']

# All lower bin boundaries followed by the final upper boundary
bin_edges = list(bl)
bin_edges.append(bu[-1])
fout['bins_edges'] = bin_edges

# One group per ifo, carrying the live time summed over all input files
for ifo in args.ifos:
    ifo_group = fout.create_group(ifo)
    ifo_group.attrs['live_time'] = sum(live_times[ifo])


# Combine the per-file fits of every ifo, bin by bin, write the results to
# the output file and, if requested, plot them against time.
save_allmeanalpha = {}
# Plots are optional: only make them when a file name pattern was supplied
make_plots = bool(args.output_plot_file_format)

for ifo in args.ifos:
    logging.info(ifo)

    if make_plots:
        fig_alpha = plt.figure(figsize=(12, 7.5))
        ax_alpha = fig_alpha.add_subplot(111)
        fig_count = plt.figure(figsize=(12, 7.5))
        ax_count = fig_count.add_subplot(111)
        alpha_lines = []
        count_lines = []
        # Live time of each input file; entry i must belong to date ad[i]
        day_live = np.array(live_times[ifo], dtype=float)
        if len(day_live) != len(ad):
            raise ValueError("Have %d live times but %d analysis dates for "
                             "%s, cannot plot counts per live time"
                             % (len(day_live), len(ad), ifo))

    # Each item holds, for one bin, the per-file fit coefficients, the
    # per-file counts and the upper and lower bin boundaries.
    per_bin = zip(alphas_bin[ifo], counts_bin[ifo], bu, bl)
    for counter, (a, c, upper, lower) in enumerate(per_bin):
        a = np.array(a)
        c = np.array(c)
        # Keep only files with a positive fit coefficient; this drops the
        # -1 placeholders of missing ifos as well as nan values
        valid_alpha = np.nonzero(a > 0)
        a = a[valid_alpha]
        c = c[valid_alpha]
        # Combine the coefficients through the mean of counts / alpha
        invalphan = c / a
        mean_alpha = c.mean() / invalphan.mean()
        q05_alpha = np.quantile(a, 0.05)
        q95_count = np.quantile(c, 0.95)

        alphas_out[ifo][counter] = mean_alpha
        q05_alphas_out[ifo][counter] = q05_alpha
        counts_out[ifo][counter] = c.sum()
        # Scale the per-file 95% quantile by the number of files used
        q95_counts_out[ifo][counter] = q95_count * len(c)

        fout[ifo + '/daily_fits/bin_%d/fit_coeff' % counter] = a
        fout[ifo + '/daily_fits/bin_%d/counts' % counter] = c
        fout[ifo + '/daily_fits/bin_%d/date' % counter] = ad[valid_alpha]

        if not make_plots:
            continue

        # Reuse colours cyclically if there are more bins than colours
        colour = colours[counter % len(colours)]
        span = [ad[0], ad[-1]]
        bin_label = "duration %.2f-%.2f" % (lower, upper)

        alpha_lines += ax_alpha.plot(ad[valid_alpha], a, c=colour,
                                     label=bin_label)
        alpha_lines += ax_alpha.plot(span, [mean_alpha, mean_alpha],
                                     c=colour, linestyle='--',
                                     label="total fit = %.2f" % mean_alpha)
        alpha_lines += ax_alpha.plot(span, [q05_alpha, q05_alpha],
                                     c=colour, linestyle=':',
                                     label="5th %%ile = %.2f" % q05_alpha)

        # Rates use the live time of the file each count came from
        live_valid = day_live[valid_alpha]
        rate = c / live_valid
        mean_rate = c.sum() / live_valid.sum()
        q95_rate = np.quantile(rate, 0.95)
        count_lines += ax_count.plot(ad[valid_alpha], rate, c=colour,
                                     label=bin_label)
        count_lines += ax_count.plot(span, [mean_rate, mean_rate],
                                     c=colour, linestyle='--',
                                     label="mean = %.3f" % mean_rate)
        count_lines += ax_count.plot(span, [q95_rate, q95_rate],
                                     c=colour, linestyle=':',
                                     label="95th %%ile = %.3f" % q95_rate)

    # Combine the per-bin results into a single coefficient for the ifo
    overall_invalphan = counts_out[ifo] / alphas_out[ifo]
    overall_meanalpha = counts_out[ifo].mean() / overall_invalphan.mean()
    save_allmeanalpha[ifo] = overall_meanalpha

    if make_plots:
        # Tables with one row per bin and one column per input file
        counts_grid = np.array(counts_bin[ifo])
        alphas_grid = np.array(alphas_bin[ifo])
        # Files without live time have no rate to show, so leave them out
        observed = day_live > 0
        counts_obs = counts_grid[:, observed]
        alphas_obs = alphas_grid[:, observed]
        count_all = counts_obs.sum(axis=0) / day_live[observed]
        invalphan_all = (counts_obs / alphas_obs).mean(axis=0)
        alpha_all = counts_obs.mean(axis=0) / invalphan_all
        sum_counts_out = counts_out[ifo].sum() / sum(live_times[ifo])

        alpha_lines += ax_alpha.plot(ad[observed], alpha_all,
                                     c='k', linestyle='-', linewidth=2,
                                     label="daily overall alpha")
        alpha_lines += ax_alpha.plot([ad[0], ad[-1]],
                                     [overall_meanalpha, overall_meanalpha],
                                     c='k', linestyle='--', linewidth=2,
                                     label="overall alpha = %.2f"
                                           % overall_meanalpha)
        count_lines += ax_count.plot(ad[observed], count_all,
                                     c='k', linestyle='-', linewidth=2,
                                     label="overall count")
        count_lines += ax_count.plot([ad[0], ad[-1]],
                                     [sum_counts_out, sum_counts_out],
                                     c='k', linestyle='--', linewidth=2,
                                     label="overall count per live time = %.3f"
                                           % sum_counts_out)

        ax_alpha.set_xlabel('Days since ' + start_date)
        ax_alpha.set_ylabel('Fit coefficient')
        alpha_labels = [line.get_label() for line in alpha_lines]
        ax_alpha.legend(alpha_lines, alpha_labels, loc='lower center',
                        ncol=3, bbox_to_anchor=(0.5, 1.01))
        fig_alpha.tight_layout()
        fig_alpha.savefig(
            args.output_plot_file_format.format(ifo + "-alphas"))
        ax_count.set_xlabel('Days since ' + start_date)
        ax_count.set_ylabel('Counts per live time')
        count_labels = [line.get_label() for line in count_lines]
        ax_count.legend(count_lines, count_labels, loc='lower center',
                        ncol=3, bbox_to_anchor=(0.5, 1.01))
        fig_count.tight_layout()
        fig_count.savefig(
            args.output_plot_file_format.format(ifo + "-counts"))
        # Release the figures so that they do not pile up across ifos
        plt.close(fig_alpha)
        plt.close(fig_count)

    n_bins = len(alphas_out[ifo])
    fout[ifo + '/mean/fit_coeff'] = alphas_out[ifo]
    fout[ifo + '/conservative/fit_coeff'] = q05_alphas_out[ifo]
    fout[ifo + '/fixed/fit_coeff'] = [0] * n_bins
    fout[ifo + '/mean/counts'] = counts_out[ifo]
    fout[ifo + '/conservative/counts'] = q95_counts_out[ifo]
    fout[ifo + '/fixed/counts'] = [1] * n_bins
    fout[ifo].attrs['mean_alpha'] = save_allmeanalpha[ifo]
    fout[ifo].attrs['total_counts'] = counts_out[ifo].sum()

fout.close()
# The first input file was kept open for its attributes; it is done with now
fit_f0.close()
