#!/usr/bin/python

# Copyright 2019 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.


import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np

from pycbc import events, bin_utils, results
from pycbc.events import triggers as trigs
from pycbc.events import trigger_fits as trstats
from pycbc.events import stat as pystat
from pycbc.types.optparse import MultiDetOptionAction
import pycbc.version

# Command line interface.  Options cover the input/output files, how the
# templates are binned and split over the plot grid, the vetoes applied to
# the triggers, and the fitting / pruning settings.
parser = argparse.ArgumentParser(usage="",
    description="Plot histograms of triggers split over various parameters")
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                         "Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument('--output-file', required=True,
                    help="Output image file. Required")
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--bin-param', default='template_duration',
                    help="Parameter for binning within plots, default "
                         "'template_duration'")
parser.add_argument('--bin-spacing', choices=['linear', 'log'], default='log',
                    help="How to space bin-param bin edges. "
                         "Choices=[linear, log], default log")
parser.add_argument('--num-bins', type=int, default=6,
                    help="Number of bins over which to split bin-param, default 6")
parser.add_argument('--max-bin-param', type=float, default=None,
                    help="Maximum allowed value of bin-param")
parser.add_argument('--split-param-one', default='eta',
                    help="Parameter for splitting plot grid in y-direction, "
                         "default 'eta'")
parser.add_argument('--split-param-two', default='chi_eff',
                    help="Parameter for splitting plot grid in x-direction, "
                         "default 'chi_eff'")
parser.add_argument('--split-one-nbins', default=3, type=int,
                    help="Number of split plots over split-param-one, default 3")
parser.add_argument('--split-two-nbins', default=4, type=int,
                    help="Number of split plots over split-param-two, default 4")
parser.add_argument('--split-one-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-one bin edges. "
                         "Choices=[linear, log], default=linear")
parser.add_argument('--split-two-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-two bin edges. "
                         "Choices=[linear, log], default=linear")
# The detector name is used as the top-level group for every read from the
# trigger file, so it must be enforced rather than only documented as required
parser.add_argument('--ifo', required=True, help="Detector. Required")
parser.add_argument('--plot-max-x', type=float, default=None,
                    help="Maximum stat value to plot, if not given "
                         "1.05 * largest stat value will be used")
parser.add_argument("--veto-file",
                    help="File(s) in .xml format with veto segments to apply "
                         "to triggers before fitting")
parser.add_argument("--veto-segment-name",
                    help="Name(s) of veto segments to apply. Optional, if not "
                         "given all triggers for the given ifo will be used")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-fit-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                         "threshold. Required")
# There is no default functional form, and both the legend label and the fit
# itself need one: fail at parse time instead of after all the processing
parser.add_argument("--fit-function", required=True,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Required")
parser.add_argument("--prune-number", type=int, default=0,
                    help="Number of loudest events to remove from each split "
                         "histogram, default 0")
parser.add_argument("--prune-window", type=float, default=0.1,
                    help="Time (s) to remove all triggers around a trigger "
                         "which is loudest in each split, default 0.1s")

# Adds the ranking statistic options (including args.sngl_ranking, which is
# used for axis labels and captions further down)
pystat.insert_statistic_option_group(parser,
    default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# The trigger file handle stays open: datasets are read from it on demand
# by the later stages of the script
logging.info('Opening trigger file: %s', args.trigger_file)
trigf = h5py.File(args.trigger_file, mode='r')

logging.info('Opening template file: %s', args.bank_file)
bank = h5py.File(args.bank_file, mode='r')
logging.info('setting up template bank parameters')
# Copy the per-template bank datasets into memory so that the bank file
# can be closed straight away
params = {}
for bank_key in ('mass1', 'mass2', 'spin1z', 'spin2z'):
    params[bank_key] = bank[bank_key][:]
bank.close()

# Derived parameters are only evaluated when one of the binning or
# splitting options actually asks for them
usedparams = [args.split_param_two, args.split_param_one, args.bin_param]
extparams = []
for candidate in ('mchirp', 'eta', 'chi_eff', 'template_duration'):
    if candidate in usedparams:
        extparams.append(candidate)

for ex_p in extparams:
    if ex_p != 'template_duration':
        # Everything except the duration follows from the bank masses and spins
        logging.info("Calculating " + ex_p + " from template parameters")
        params[ex_p] = trigs.get_param(ex_p, args,
                                       params['mass1'], params['mass2'],
                                       params['spin1z'], params['spin2z'])
        continue

    logging.info('Reading duration from trigger file')
    duration_dset = trigf[args.ifo + '/template_duration']
    region_refs = trigf[args.ifo + '/template_duration_template'][:]
    # One region reference per template; the first entry of the referenced
    # region is taken as that template's duration.  If a template has no
    # triggers, accessing the 0th entry of its region reference will return
    # zero due to a quirk of h5py.
    params[ex_p] = np.array([duration_dset[ref][0] for ref in region_refs])

# string formats for labels, logging etc.
# The raw bank parameters are included as well: they are valid choices for
# the bin / split parameters, and the axis labels look their format up here.
formats = {
        "mchirp": '{:.2f}',
        "eta": '{:.3f}',
        "chi_eff": '{:.3f}',
        'template_duration': '{:.0f}',
        "mass1": '{:.2f}',
        "mass2": '{:.2f}',
        "spin1z": '{:.3f}',
        "spin2z": '{:.3f}'
}

logging.info('setting up {}'.format(usedparams))
# (min, max, number of bins) over the full bank for each split parameter
sp_one_bin_input = (params[args.split_param_one].min(),
                    params[args.split_param_one].max(), args.split_one_nbins)
sp_two_bin_input = (params[args.split_param_two].min(),
                    params[args.split_param_two].max(), args.split_two_nbins)

# Compare against None so that an explicitly given value of 0 is not
# silently treated as "option not given"
if args.max_bin_param is not None:
    logging.info(('setting maximum {} value: '
                  + formats[args.bin_param]).format(args.bin_param,
                                                    args.max_bin_param))
    pbin_upper_lim = float(args.max_bin_param)
else:
    pbin_upper_lim = params[args.bin_param].max()

# For templates with no triggers, the duration will be read as zero
if args.bin_param == 'template_duration' and params[args.bin_param].min() == 0:
    logging.warning('WARNING: Some templates do not contain triggers')
    # Use the lowest nonzero template duration as lower limit for bins
    pbin_lower_lim = params[args.bin_param][params[args.bin_param] > 0].min()
else:
    pbin_lower_lim = params[args.bin_param].min()

bb_input = (pbin_lower_lim, pbin_upper_lim, args.num_bins)

def _make_bins(spacing, bin_input, param_name):
    """Build the bins for one parameter.

    Parameters
    ----------
    spacing : str
        'log' for logarithmic bins, anything else gives linear bins.
    bin_input : tuple
        (lower limit, upper limit, number of bins).
    param_name : str
        Name of the binned parameter, used only in the error message.

    Raises
    ------
    ValueError
        If logarithmic spacing is requested and the lower limit is not
        strictly positive.
    """
    if spacing != 'log':
        return bin_utils.LinearBins(*bin_input)
    # Explicit check rather than assert, which is skipped under python -O
    if not bin_input[0] > 0:
        raise ValueError('Cannot use log spacing for {}: lower limit {} is '
                         'not positive'.format(param_name, bin_input[0]))
    return bin_utils.LogarithmicBins(*bin_input)


logging.info('splitting {} into bins'.format(args.bin_param))
pbins = _make_bins(args.bin_spacing, bb_input, args.bin_param)
# Use sentinel value -1 for templates outside range.  The limits themselves
# are inside the range (the bin classes accept values equal to their minimum
# and maximum), so the comparison is inclusive at both ends; otherwise the
# templates sitting exactly on the limits would be dropped from every plot.
pind = np.array([pbins[par] if pbin_lower_lim <= par <= pbin_upper_lim
                 else -1 for par in params[args.bin_param]])

sp_one_bounds = _make_bins(args.split_one_spacing, sp_one_bin_input,
                           args.split_param_one)
sp_two_bounds = _make_bins(args.split_two_spacing, sp_two_bin_input,
                           args.split_param_two)

def _template_ids_per_bin(values, bounds):
    """Return one sorted array of template ids for each bin in bounds.

    Parameters
    ----------
    values : numpy array
        Value of the split parameter for every template.
    bounds : bins object
        Object whose lower() and upper() methods give the bin edges.

    Returns
    -------
    list of numpy arrays
        Entry k holds the indices into values of the templates in bin k.
    """
    ids = []
    for num, (lower, upper) in enumerate(zip(bounds.lower(), bounds.upper())):
        # Bins are (lower, upper], except the first one which also includes
        # its lower edge: with a strict comparison the template(s) at the
        # minimum value would not be assigned to any split
        above_lower = values >= lower if num == 0 else values > lower
        ids.append(np.flatnonzero(above_lower & (values <= upper)))
    return ids


logging.info('assigning template ids to different splits')
id_in_bin1 = _template_ids_per_bin(params[args.split_param_one],
                                   sp_one_bounds)
id_in_bin2 = _template_ids_per_bin(params[args.split_param_two],
                                   sp_two_bounds)

logging.info('Getting template boundaries from trigger file')
# boundaries[k] is used below as the index of the first trigger of template k
boundaries = trigf[args.ifo + '/template_boundaries'][:]
max_boundary_id = np.argmax(boundaries)
sorted_boundary_list = np.sort(boundaries)

logging.info('Processing template boundaries')
# The triggers of a template end where the next boundary (in sorted order)
# starts.  searchsorted with side='left' gives the first occurrence of each
# boundary in the sorted list in one vectorised pass, instead of one full
# scan of the sorted list per template.
first_pos = np.searchsorted(sorted_boundary_list, boundaries, side='left')
# Only the template with the largest boundary can point past the end of the
# sorted list; clip it here, its value is overwritten just below
next_pos = np.minimum(first_pos + 1, sorted_boundary_list.size - 1)
where_idx_end = sorted_boundary_list[next_pos]
# The last block of triggers runs to the end of the trigger datasets
where_idx_end[max_boundary_id] = trigf[args.ifo + '/end_time'].size

logging.info('Calculating single stat values from trigger file')
rank_method = pystat.get_statistic_from_opts(args, [args.ifo])
stat = rank_method.get_sngl_ranking(trigf[args.ifo])

# Trigger end times are read exactly once, whatever options are given: they
# are needed by both vetoes below and by the later pruning step.  Reading
# them only inside the veto branches left `time` undefined when pruning was
# requested without any veto, and re-reading them for the gating veto threw
# away the zeroing already applied by the DQ veto.
time = trigf[args.ifo + '/end_time'][:]

if args.veto_file:
    logging.info('Applying DQ vetoes')
    remove, _ = events.veto.indices_within_segments(time, [args.veto_file],
                         ifo=args.ifo, segment_name=args.veto_segment_name)
    # Set stat to zero for triggers being vetoed: given that the fit threshold is
    # >0 these will not be fitted or plotted.  Avoids complications from changing
    # the number of triggers, ie changes of template boundary.
    stat[remove] = 0
    time[remove] = 0
    logging.info('{} out of {} trigs removed after vetoing with {} from {}'.format(
                      remove.size, stat.size, args.veto_segment_name, args.veto_file))

if args.gating_veto_windows:
    logging.info('Applying veto to triggers near gates')
    # Window for this detector is given as "before,after" in seconds
    gating_veto = args.gating_veto_windows[args.ifo].split(',')
    gveto_before = float(gating_veto[0])
    gveto_after = float(gating_veto[1])
    if gveto_before > 0 or gveto_after < 0:
        raise ValueError("Gating veto window values must be negative before "
                         "gates and positive after gates.")
    # A window of exactly 0,0 means there is nothing to veto
    if not (gveto_before == 0 and gveto_after == 0):
        autogate_times = np.unique(trigf[args.ifo + '/gating/auto/time'][:])
        # Gates read from a file are optional in the trigger file
        if args.ifo + '/gating/file' in trigf:
            detgate_times = trigf[args.ifo + '/gating/file/time'][:]
        else:
            detgate_times = []
        gate_times = np.concatenate((autogate_times, detgate_times))
        gveto_remove = events.veto.indices_within_times(time, gate_times + gveto_before,
                                                        gate_times + gveto_after)
        # Zeroed rather than deleted, as for the DQ vetoes above
        stat[gveto_remove] = 0
        time[gveto_remove] = 0
        logging.info('{} out of {} trigs removed after vetoing triggers near gates'.format(
                          gveto_remove.size, stat.size))

# Pruning: in every split of the plot grid, repeatedly find the loudest
# remaining trigger and zero out all triggers (in any template) within
# prune_window seconds of it.
if not args.prune_number:
    logging.info('Not performing any pruning')
else:
    logging.info('Applying pruning around loudest triggers in each split')
    for x in range(args.split_one_nbins):
        id_bin1 = id_in_bin1[x]
        for y in range(args.split_two_nbins):
            id_bin2 = id_in_bin2[y]
            # Finding ids of templates in split
            id_in_both = np.intersect1d(id_bin1, id_bin2)
            if len(id_in_both) == 0:
                continue
            # Copies of the stat values and times of the triggers that are
            # in these templates
            vals_inbin = np.concatenate(
                [stat[boundaries[idx]:where_idx_end[idx]]
                 for idx in id_in_both])
            time_inbin = np.concatenate(
                [time[boundaries[idx]:where_idx_end[idx]]
                 for idx in id_in_both])

            logging.info('Pruning in split {}-{} {}-{}'.format(
                         args.split_param_one, x, args.split_param_two, y))
            if vals_inbin.size == 0:
                # argmax of an empty array would raise
                logging.info('No triggers in this split, nothing to prune')
                continue

            for count_pruned in range(args.prune_number):
                # Getting loudest statistic value in split
                max_val_arg = vals_inbin.argmax()
                if vals_inbin[max_val_arg] == 0:
                    # A stat of zero marks an already removed trigger, so
                    # there is nothing left worth pruning around
                    logging.info('No triggers left to prune in this split')
                    break
                # Taken before zeroing, and from the in-split times:
                # max_val_arg indexes the split, not the full trigger list
                prune_time = time_inbin[max_val_arg]

                remove = np.flatnonzero(abs(prune_time - time)
                                        < args.prune_window)
                # Remove from inbin triggers as well in case there
                # are more pruning iterations
                remove_inbin = np.flatnonzero(abs(prune_time - time_inbin)
                                              < args.prune_window)
                logging.info('Prune {}: removing {} triggers around time {:.2f},'
                             ' {} in this split'.format(count_pruned, remove.size,
                                                        prune_time,
                                                        remove_inbin.size))
                # Set pruned triggers' stat values to zero, as above for vetoes
                vals_inbin[remove_inbin] = 0
                time_inbin[remove_inbin] = 0
                stat[remove] = 0
                time[remove] = 0

# Nothing further is read from the trigger file
trigf.close()

logging.info('Setting up plotting and fitting limit values')
# Plots start at the smallest nonzero stat value (zeros mark removed
# triggers), but never more than 1 below the fit threshold
minplot = max(stat[stat != 0].min(), args.stat_fit_threshold - 1)
# The fitted curves are drawn from the threshold (or the start of the plot,
# if that is higher) up to slightly beyond the loudest trigger
min_fit = max(minplot, args.stat_fit_threshold)
max_fit = 1.05 * stat.max()
maxplot = args.plot_max_x if args.plot_max_x else max_fit
fitrange = np.linspace(min_fit, max_fit, num=100)

logging.info('Setting up plotting variables')
histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
# The palette is indexed by bin number, so repeat it when more bins than
# colours are requested instead of failing with an IndexError
if args.num_bins > len(histcolors):
    histcolors = histcolors * (-(-args.num_bins // len(histcolors)))

# Grid of panels: split-param-one runs down the rows, split-param-two along
# the columns.  One extra column's worth of width is reserved in the figure
# size for the overall legend.
fig, axes = plt.subplots(args.split_one_nbins, args.split_two_nbins,
                         sharex=True, sharey=True, squeeze=False,
                         figsize=(3 * (args.split_two_nbins + 1),
                                  3 * args.split_one_nbins))

# Setting up overall legend outside the split-up plots
lines = []
labels = []
for i, (lower, upper) in enumerate(zip(pbins.lower(), pbins.upper())):
    binlabel = r"%.3g - %.3g" % (lower, upper)
    # Zero-length dummy line, drawn only to obtain a legend handle
    line, = axes[0,0].plot([0,0], [0,0], linewidth=2,
                     color=histcolors[i], alpha=0.6)
    lines.append(line)
    labels.append(binlabel)

# Dummy handle for the dashed fit curves
line_fit, = axes[0,0].plot([0,0], [0,0], linestyle='--',
                           color='k', alpha=0.6)
lines.append(line_fit)
labels.append(args.fit_function + ' fit to counts')
fig.legend(lines, labels, labelspacing=0.2,
           loc='upper left', title=args.bin_param)

# Template ids belonging to each bin of the bin parameter
pidx = [np.flatnonzero(pind == i) for i in range(args.num_bins)]

logging.info('Starting bin, histogram and plot loop')
# Largest cumulative count over every plotted histogram; sets the y-range
maxyval = 0
for x in range(args.split_one_nbins):
    id_bin1 = id_in_bin1[x]
    for y in range(args.split_two_nbins):
        id_bin2 = id_in_bin2[y]
        # Templates belonging to this panel of the grid
        id_in_both = np.intersect1d(id_bin1, id_bin2)
        logging.info('split {}: {}, {}: {}'.format(args.split_param_one, x,
                                                   args.split_param_two, y))
        ax = axes[x,y]
        for i, lower, upper in zip(range(args.num_bins), pbins.lower(),
                                   pbins.upper()):
            indices_all_conditions = np.intersect1d(pidx[i], id_in_both)
            logging.info('{} split {}-{}'.format(args.bin_param, lower, upper))
            if len(indices_all_conditions) == 0:
                continue
            # Stat values of all triggers from the selected templates
            vals_inbin = np.concatenate(
                [stat[boundaries[idx]:where_idx_end[idx]]
                 for idx in indices_all_conditions])
            vals_above_thresh = vals_inbin[vals_inbin >= args.stat_fit_threshold]
            if not len(vals_above_thresh):
                logging.info('No triggers above threshold')
                continue
            logging.info('{} triggers out of {} above threshold'.format(
                         len(vals_above_thresh), len(vals_inbin)))

            # sig_alpha is returned alongside the fitted alpha; the upper
            # and lower 1-sigma bounds on the fit (alpha +/- sig_alpha) are
            # not currently plotted
            alpha, sig_alpha = trstats.fit_above_thresh(args.fit_function,
                                 vals_above_thresh, args.stat_fit_threshold)
            # Scale the fitted curve by the number of triggers above threshold
            fitted_cum_counts = len(vals_above_thresh) * \
                                trstats.cum_fit(args.fit_function, fitrange,
                                                alpha, args.stat_fit_threshold)

            # Histogram of all values in the bin, accumulated from the loud
            # end so each point counts the triggers at or above that edge
            histcounts, edges = np.histogram(vals_inbin, bins=50)
            cum_counts = histcounts[::-1].cumsum()[::-1]
            # Plot the lines!
            ax.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=histcolors[i], alpha=0.6)
            ax.semilogy(fitrange, fitted_cum_counts, "--", color=histcolors[i],
                     label=r"$\alpha = $%.2f" % alpha)
            ax.legend(fontsize='small', framealpha=0.5)

            # Updated per plotted histogram: doing this once per panel used
            # whatever cum_counts was left over from the last plotted bin,
            # or failed if nothing had been plotted yet
            maxyval = max(maxyval, cum_counts.max())

        ax.grid()

# Keep the upper y limit above the lower limit of 1 even if nothing was plotted
maxyval = max(maxyval, 1)

# Mark the fit threshold on every panel with a faint dotted vertical line
thresh_x = [args.stat_fit_threshold, args.stat_fit_threshold]
for panel in axes.flat:
    panel.semilogy(thresh_x, [1, 5 * maxyval], 'k', linestyle=':', alpha=0.2)
# The panels were created with shared x and y axes, so the limits are only
# set on the first one
axes[0,0].set_ylim(1, 5 * maxyval)
axes[0,0].set_xlim(minplot, maxplot)

# Columns: the statistic name goes under the bottom row and, when there is
# more than one column, the split-param-two range goes on top of each column
bottom_row = args.split_one_nbins - 1
for j in range(args.split_two_nbins):
    axes[bottom_row, j].set_xlabel(args.sngl_ranking, size="large")
    if args.split_two_nbins > 1:
        range_two = (formats[args.split_param_two] + ' to ' +
                     formats[args.split_param_two]).format(
                         sp_two_bounds.lower()[j], sp_two_bounds.upper()[j])
        axes[0, j].set_xlabel(args.split_param_two + ': ' + range_two,
                              size="large")
        axes[0, j].xaxis.set_label_position("top")

# Rows: a plain label when there is a single row, otherwise the
# split-param-one range of each row followed by the quantity plotted
if args.split_one_nbins == 1:
    axes[0, 0].set_ylabel('Cumulative number', size='large')
else:
    for i in range(args.split_one_nbins):
        range_one = (formats[args.split_param_one] + ' to ' +
                     formats[args.split_param_one]).format(
                         sp_one_bounds.lower()[i], sp_one_bounds.upper()[i])
        axes[i, 0].set_ylabel(args.split_param_one + ': ' + range_one +
                              '\ncumulative number', size="large")

# Leave the leftmost 1/(ncols + 1) of the figure free for the overall legend
fig.tight_layout(rect=(1./(args.split_two_nbins+1), 0, 1, 1))

logging.info('Saving to file ' + args.output_file)
# Title and caption stored with the figure alongside the command line used
plot_title = ("{}: {} histogram of single detector triggers split by"
              " {} and {}").format(args.ifo, args.sngl_ranking,
                                   args.split_param_one, args.split_param_two)
plot_caption = (r"Histogram of {} single detector {} values binned by {}, split by "
                "{} and {}, with fitted {} distribution parameterized by"
                " &alpha;").format(args.ifo, args.sngl_ranking, args.bin_param,
                                   args.split_param_one, args.split_param_two,
                                   args.fit_function)
results.save_fig_with_metadata(fig, args.output_file,
                               title=plot_title,
                               caption=plot_caption,
                               cmd=" ".join(sys.argv))
logging.info('Done!')
