#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import numpy as np
import pycbc
from pycbc import bin_utils
from pycbc.events import trigger_fits as trstats
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import argparse, logging, h5py

# Command-line interface. Arguments are registered in the order in which
# they should appear in the --help output.
parser = argparse.ArgumentParser(
    usage="",
    description="Plot histograms of triggers split over various parameters",
)
parser.add_argument(
    "--verbose",
    action="store_true",
    help="Print extra debugging information",
)
parser.add_argument(
    "--trigger-fits-file",
    required=True,
    help="Trigger fits file to plot",
)
# One plot is written per ifo, so the output name is a template
default_plot_format = "{ifo}-TRIGGER-FITS.png"
parser.add_argument(
    "--output-plot-name-format",
    default=default_plot_format,
    help="Format to save plots, must contain '{ifo}' to indicate ifo in "
         "filename. Default: " + default_plot_format,
)
parser.add_argument(
    "--colormap",
    default="rainbow_r",
    choices=plt.colormaps(),
    help="Colormap to use for choosing the colours of the "
         "duration bin lines. Default=rainbow_r",
)
parser.add_argument(
    "--log-colormap",
    action='store_true',
    help="Use log spacing for choosing colormap values "
         "based on duration bins.",
)
parser.add_argument(
    "--x-lim-upper",
    type=float,
    help="Add an upper limit to the x-axis of the plot",
)
parser.add_argument(
    "--y-lim-upper",
    type=float,
    help="Add an upper limit to the y-axis of the plot",
)

args = parser.parse_args()

# Input sanitisation: without the '{ifo}' placeholder every ifo would be
# written to the same output filename
if '{ifo}' not in args.output_plot_name_format:
    parser.error("--output-plot-name-format must contain '{ifo}' "
                 "to indicate ifo in filename.")

pycbc.init_logging(args.verbose)

logging.info("Getting trigger fits file information")
with h5py.File(args.trigger_fits_file, 'r') as trfit_f:
    # Every top-level key that does not start with 'bins' is treated as
    # an ifo to plot
    ifos = [k for k in trfit_f.keys() if not k.startswith('bins')]

    # File-level attributes describing how the fits were made
    file_attrs = trfit_f.attrs
    sngl_ranking = file_attrs['sngl_ranking']
    fit_threshold = file_attrs['fit_threshold']
    fit_function = file_attrs['fit_function']
    analysis_date = file_attrs['analysis_date']

    # Per-ifo information, read fully into memory so that nothing depends
    # on the file once it has been closed.
    # (The triggers are the ones which passed the cuts in the fitting code)
    stats = {}
    durations = {}
    live_time = {}
    alphas = {}
    counts = {}
    for ifo in ifos:
        ifo_group = trfit_f[ifo]
        trig_group = ifo_group['triggers']
        stats[ifo] = trig_group[sngl_ranking][:]
        durations[ifo] = trig_group['template_duration'][:]
        live_time[ifo] = ifo_group.attrs['live_time']
        alphas[ifo] = ifo_group['fit_coeff'][:]
        counts[ifo] = ifo_group['counts'][:]

    # Upper and lower edges of the bins, stored at the top level of the file
    bu = trfit_f['bins_upper'][:]
    bl = trfit_f['bins_lower'][:]

# Overall extent covered by the bins
bin_min, bin_max = min(bl), max(bu)

def bin_proportion(upper, lower, log_spacing=False):
    """Return where a bin's centre sits within the full range of all bins.

    The overall range is taken from the module-level ``bin_min`` and
    ``bin_max``. The result is used to choose a colour from a colormap.

    Parameters
    ----------
    upper : float
        Upper edge of the bin.
    lower : float
        Lower edge of the bin.
    log_spacing : bool, optional
        If True, measure the position in log space, i.e. using the mean of
        the logs of the two edges. This is only meaningful for strictly
        positive bin edges. Default False (linear spacing).

    Returns
    -------
    float
        Fractional position of the bin centre between ``bin_min`` (0) and
        ``bin_max`` (1). If the overall range has zero width, 0.5 is
        returned rather than dividing by zero.
    """
    if log_spacing:
        # Centre of the bin in log space: both edges must contribute
        centre = (np.log(lower) + np.log(upper)) / 2.
        range_min = np.log(bin_min)
        range_max = np.log(bin_max)
    else:
        centre = (lower + upper) / 2.
        range_min = bin_min
        range_max = bin_max

    span = range_max - range_min
    if span == 0:
        # Degenerate range (e.g. one zero-width bin): avoid 0/0
        return 0.5
    return (centre - range_min) / span


# Root logger and its current level, kept so that the level can be put
# back after it is changed later in the script
logger = logging.getLogger()
init_level = logger.level

# Bin edges: every lower edge followed by the upper edge of the last bin
# (assumes the bins are contiguous and in ascending order - TODO confirm)
duration_bin_edges = [*bl, bu[-1]]
tbins = bin_utils.IrregularBins(duration_bin_edges)

logging.info("Plotting fits")

# The colormap does not depend on the ifo or the bin, so look it up once
cmap = plt.get_cmap(args.colormap)

# Make one plot per ifo, showing for each duration bin the cumulative rate
# of triggers above each statistic value together with the fitted curve
for ifo in ifos:
    ifo_stats = stats[ifo]
    # Skip if no triggers in this IFO
    if not len(ifo_stats):
        continue

    ifo_live_time = live_time[ifo]

    # Keep track of some maxima for use in setting the plot limits
    maxstat = ifo_stats.max()
    max_rate = 0

    fig, ax = plt.subplots(1)

    # Histogram edges; also the x values at which the fit is evaluated
    plotbins = np.linspace(fit_threshold, 1.05 * maxstat)

    logging.info("Putting events into bins")
    # Index of the duration bin that each trigger falls into
    event_bin = np.array([tbins[d] for d in durations[ifo]])

    bin_edge_pairs = zip(duration_bin_edges[:-1], duration_bin_edges[1:])
    for bin_num, (lower, upper) in enumerate(bin_edge_pairs):
        binlabel = f"{lower:.3g} - {upper:.3g}"

        inbin = event_bin == bin_num
        # Skip if there are no triggers in this bin in this IFO
        if not inbin.any():
            continue
        binned_sngl_stats = ifo_stats[inbin]

        # Histogram the triggers, then reverse-cumulate so that each entry
        # counts the triggers at or above that histogram bin
        histcounts, edges = np.histogram(binned_sngl_stats,
                                         bins=plotbins)
        cum_rate = histcounts[::-1].cumsum()[::-1] / ifo_live_time

        max_rate = max(max_rate, cum_rate[0])

        # Fitted cumulative curve, scaled by the stored count for this bin
        # and the live time so it is comparable with cum_rate
        cum_fit = counts[ifo][bin_num] / ifo_live_time * \
                      trstats.cum_fit(fit_function, plotbins,
                                      alphas[ifo][bin_num], fit_threshold)

        # Get the colour from the centre of the bin vs the full range
        # of all bins
        bin_prop = bin_proportion(upper, lower,
                                  log_spacing=args.log_colormap)
        bin_colour = cmap(bin_prop)
        ax.plot(edges[:-1], cum_rate, linewidth=2,
                color=bin_colour, label=binlabel, alpha=0.6)
        ax.plot(plotbins, cum_fit, "--", color=bin_colour,
                label=r"$\alpha = $%.2f" % alphas[ifo][bin_num])

    oput_plot = args.output_plot_name_format.format(ifo=ifo)
    ax.semilogy()
    ax.grid()
    # Fall back to limits derived from the data when none were given
    x_upper = args.x_lim_upper or 1.05 * maxstat
    y_upper = args.y_lim_upper or 1.5 * max_rate
    ax.set_xlim(fit_threshold, x_upper)
    ax.set_ylim(0.5 / ifo_live_time, y_upper)
    ax.set_xlabel(sngl_ranking)
    ax.set_ylabel("Number of louder triggers per live time")
    title = f"{ifo} {analysis_date} trigger fits"
    ax.set_title(title)
    ax.legend(loc='upper right')
    logging.info(f"Saving {oput_plot}")
    # Raise the root logger level to WARNING while saving, then always put
    # the initial level back and release the figure, even if saving fails
    logger.setLevel(logging.WARNING)
    try:
        fig.savefig(oput_plot)
    finally:
        logger.setLevel(init_level)
        # pyplot keeps figures registered until they are closed
        plt.close(fig)
    
