#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import h5py, numpy as np, argparse
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import logging
from lal import gpstime
import pycbc



# Command-line interface: which combined fits file to read, how to name the
# output plots, which ifos to plot and how to colour the duration bins.
parser = argparse.ArgumentParser(usage="",
    description="Combine fitting parameters from several different files")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--combined-fits-file", required=True,
                    help="File containing information on the combined trigger "
                         "fits.")
default_plot_format = "{ifo}-TRIGGER-FITS-{type}.png"
# The '{type}' field is filled with 'fit_coeffs' or 'counts' when the plots
# are saved, so the help and error text must name exactly those values.
parser.add_argument("--output-plot-name-format",
                    default=default_plot_format,
                    help="Format to save plots, must contain '{ifo}' and "
                         "'{type}' to indicate ifo and 'fit_coeffs' or "
                         "'counts' in filename. Default: " +
                         default_plot_format)
parser.add_argument("--ifos", nargs="+",
                    help="List of ifos to plot. If not given, use all "
                         "in the combined file.")
parser.add_argument("--colormap", default="rainbow_r", choices=plt.colormaps(),
                    help="Colormap to use for choosing the colours of the "
                         "duration bin lines. Default=rainbow_r")
parser.add_argument("--log-colormap", action='store_true',
                    help="Use log spacing for choosing colormap values "
                         "based on duration bins.")
args = parser.parse_args()

# Both placeholders are required: without them every ifo / plot type would
# be written to the same filename.
if not all(field in args.output_plot_name_format
           for field in ('{ifo}', '{type}')):
    parser.error("--output-plot-name-format must contain '{ifo}' and "
                 "'{type}' to indicate ifo and 'fit_coeffs' or 'counts' "
                 "in filename.")

# Set the logging verbosity requested on the command line
pycbc.init_logging(args.verbose)

# Per-ifo containers for the combined values read from the file:
# total live time, plus 'mean' and 'conservative' counts / fit coefficients
live_total = {}
mean_count = {}
mean_alpha = {}
cons_count = {}
cons_alpha = {}

# Per-ifo containers for the 'separate_fits' group: dates, live times, and
# per-bin fit coefficients / counts
separate_dates = {}
separate_times = {}
separate_alphas = {}
separate_counts = {}

logging.info("Loading Data")
with h5py.File(args.combined_fits_file, 'r') as fits_file:
    # Use the ifos given on the command line, otherwise those in the file
    ifos = args.ifos if args.ifos else fits_file.attrs['ifos']

    # File-level information shared by every ifo
    bins_edges = fits_file['bins_edges'][:]
    n_bins = len(bins_edges) - 1
    conservative_percentile = fits_file.attrs['conservative_percentile']
    fits_dates = fits_file['fits_dates'][:]

    for ifo in ifos:
        logging.info(ifo)
        ifo_group = fits_file[ifo]
        live_total[ifo] = ifo_group.attrs['live_time']

        # Combined values, one entry per bin
        mean_group = ifo_group['mean']
        mean_count[ifo] = mean_group['counts'][:]
        mean_alpha[ifo] = mean_group['fit_coeff'][:]
        cons_group = ifo_group['conservative']
        cons_count[ifo] = cons_group['counts'][:]
        cons_alpha[ifo] = cons_group['fit_coeff'][:]

        # Individual fits: first axis of the stacked arrays is the bin index
        sep_group = ifo_group['separate_fits']
        separate_dates[ifo] = sep_group['date'][:]
        separate_times[ifo] = sep_group['live_times'][:]

        per_bin_groups = [sep_group[f'bin_{idx}'] for idx in range(n_bins)]
        separate_alphas[ifo] = np.array(
            [bin_group['fit_coeff'][:] for bin_group in per_bin_groups]
        )
        separate_counts[ifo] = np.array(
            [bin_group['counts'][:] for bin_group in per_bin_groups]
        )

# Lower and upper edge of each bin, and the overall range covered
bin_starts, bin_ends = bins_edges[:-1], bins_edges[1:]

bin_min = min(bin_starts)
bin_max = max(bin_ends)

def bin_proportion(upper, lower, log_spacing=False):
    """Return where the centre of a bin lies within the full bin range.

    The result is the fractional position of the bin centre between the
    module-level ``bin_min`` and ``bin_max``; it is used to pick a colour
    from the colormap for that bin.

    Parameters
    ----------
    upper : float
        Upper edge of the bin. Note that the upper edge comes first.
    lower : float
        Lower edge of the bin.
    log_spacing : bool, optional
        If True, compute the centre and the fraction in natural-log space
        (i.e. use the geometric centre of the bin). Default False.

    Returns
    -------
    float
        Fractional position of the bin centre within [bin_min, bin_max].
    """
    if log_spacing:
        # NOTE: np.log needs strictly positive edges to give finite values
        log_lower = np.log(lower)
        # Use the *upper* edge here; taking the log of the lower edge twice
        # would place the "centre" on the lower edge of the bin
        log_upper = np.log(upper)
        log_centre = (log_lower + log_upper) / 2.
        log_min = np.log(bin_min)
        log_max = np.log(bin_max)
        return (log_centre - log_min) / (log_max - log_min)

    return ((lower + upper) / 2. - bin_min) / (bin_max - bin_min)


def _finalise_plot(fig, ax, lines, xlabel, ylabel, xlim, filename):
    """Label, decorate and save one figure, then close it.

    Parameters
    ----------
    fig, ax : matplotlib figure and axes to finish off.
    lines : list of plotted artists to show in the legend.
    xlabel, ylabel : str axis labels.
    xlim : [min, max] x-axis limits, or None to keep matplotlib's choice.
    filename : str path the figure is saved to.
    """
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.grid(zorder=-30)
    # Legend sits above the axes so that it does not cover the data
    ax.legend(lines, [line.get_label() for line in lines],
              loc='lower center', ncol=3, bbox_to_anchor=(0.5, 1.01))
    fig.tight_layout()
    fig.savefig(filename)
    # pyplot keeps every figure alive until it is closed explicitly, so
    # release it now rather than accumulating two figures per ifo
    plt.close(fig)


logging.info("Plotting fits information")
# The colormap does not depend on the ifo or bin, so look it up once
colormap = plt.get_cmap(args.colormap)
for ifo in ifos:
    logging.info(ifo)

    # One figure for the fit coefficients and one for the count rates
    fig_alpha, ax_alpha = plt.subplots(1, figsize=(12, 7.5))
    fig_count, ax_count = plt.subplots(1, figsize=(12, 7.5))
    alpha_lines = []
    count_lines = []

    # The x axis is in days relative to the first date for this ifo
    start_date_dt = gpstime.gps_to_utc(separate_dates[ifo][0])
    start_date = start_date_dt.strftime("%Y-%m-%d %H:%M:%S")
    xlabel = 'Days since ' + start_date

    # x values of every curve drawn, used to set limits covering all bins
    plotted_days = []

    for i, (bl, bu) in enumerate(zip(bin_starts, bin_ends)):
        alphas = separate_alphas[ifo][i]
        counts = separate_counts[ifo][i]

        # Only plot entries with a positive, finite fit coefficient
        valid = np.logical_and(alphas > 0, np.isfinite(alphas))
        if not valid.any():
            continue

        a = alphas[valid]
        l_times = separate_times[ifo][valid]
        rate = counts[valid] / l_times
        # Convert seconds since the first date into days
        ad = (separate_dates[ifo][valid] - separate_dates[ifo][0]) / 86400.
        plotted_days.append(ad)

        # Combined values for this bin; counts are divided by the total
        # live time to be comparable with the per-fit rates
        ma = mean_alpha[ifo][i]
        ca = cons_alpha[ifo][i]
        mr = mean_count[ifo][i] / live_total[ifo]
        cr = cons_count[ifo][i] / live_total[ifo]

        bin_prop = bin_proportion(bu, bl,
                                  log_spacing=args.log_colormap)
        bin_colour = colormap(bin_prop)

        alpha_lines += ax_alpha.plot(ad, a, c=bin_colour,
                                     label="duration %.2f-%.2f" % (bl, bu))
        alpha_lines.append(ax_alpha.axhline(ma,
                                            label="total fit = %.2f" % ma,
                                            c=bin_colour, linestyle='--'))
        alpha_lab = f"{conservative_percentile:d}th %ile = {ca:.2f}"
        alpha_lines.append(ax_alpha.axhline(ca,
                                            c=bin_colour, linestyle=':',
                                            label=alpha_lab))

        count_lines += ax_count.plot(ad, rate,
                                     c=bin_colour,
                                     label="duration %.2f-%.2f" % (bl, bu))
        count_lines.append(ax_count.axhline(mr,
                                            c=bin_colour, linestyle='--',
                                            label=f"mean = {mr:.3f}"))
        count_lab = f"{conservative_percentile:d}th %ile = {cr:.3f}"
        count_lines.append(ax_count.axhline(cr,
                                            c=bin_colour, linestyle=':',
                                            label=count_lab))

    if plotted_days:
        # Span every plotted bin, not just the last one drawn
        all_days = np.concatenate(plotted_days)
        xlim = [all_days.min(), all_days.max()]
    else:
        # Nothing was drawn; there is no data range to take limits from
        logging.warning("No valid fits found for %s, plots will be empty",
                        ifo)
        xlim = None

    _finalise_plot(
        fig_alpha, ax_alpha, alpha_lines, xlabel, 'Fit coefficient', xlim,
        args.output_plot_name_format.format(ifo=ifo, type='fit_coeffs'),
    )
    _finalise_plot(
        fig_count, ax_count, count_lines, xlabel, 'Counts per live time',
        xlim,
        args.output_plot_name_format.format(ifo=ifo, type='counts'),
    )
