#!/usr/bin/python

# Copyright 2019 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np

from pycbc import events, bin_utils, results
from pycbc.events import triggers as trigs
from pycbc.events import trigger_fits as trstats
from pycbc.events.stat import sngl_statistic_dict
import pycbc.version

def get_stat(statchoice, trigs):
    """Evaluate the chosen single-detector statistic for a set of triggers.

    statchoice is the key used to look up the statistic class in
    sngl_statistic_dict. trigs is handed unchanged to that class's
    single() method, and whatever single() returns is returned here.
    """
    stat_class = sngl_statistic_dict[statchoice]
    # The class is built with an empty file list; statistics which need
    # input files would have to be given them at this point.
    return stat_class([]).single(trigs)

# Command-line interface. Options whose help text says "Required" are
# enforced by argparse, so a missing option is reported immediately rather
# than surfacing later as an obscure error part-way through the run.
parser = argparse.ArgumentParser(usage="",
    description="Plot histograms of triggers split over various parameters")

# Input and output files
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                         "Required")
parser.add_argument("--bank-file", default=None, required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument('--output-file', required=True,
                    help="Output image file. Required")
parser.add_argument('--verbose', action='store_true')

# Binning of the histograms drawn within each panel
parser.add_argument('--bin-param', default='template_duration',
                    help="Parameter for binning within plots, default "
                         "'template_duration'")
parser.add_argument('--bin-spacing', choices=['linear', 'log'], default='log',
                    help="How to space bin-param bin edges. "
                         "Choices=[linear, log], default log")
parser.add_argument('--num-bins', type=int, default=6,
                    help="Number of bins over which to split bin-param, default 6")
parser.add_argument('--max-bin-param', type=float, default=None,
                    help="Maximum allowed value of bin-param")

# Layout of the grid of panels
parser.add_argument('--split-param-one', default='eta',
                    help="Parameter for splitting plot grid in y-direction, "
                         "default 'eta'")
parser.add_argument('--split-param-two', default='chi_eff',
                    help="Parameter for splitting plot grid in x-direction, "
                         "default 'chi_eff'")
parser.add_argument('--split-one-nbins', default=3, type=int,
                    help="Number of split plots over split-param-one, default 3")
parser.add_argument('--split-two-nbins', default=4, type=int,
                    help="Number of split plots over split-param-two, default 4")
parser.add_argument('--split-one-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-one bin edges. "
                         "Choices=[linear, log], default=linear")
parser.add_argument('--split-two-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-two bin edges. "
                         "Choices=[linear, log], default=linear")

# Statistic, detector and plotting range
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=sngl_statistic_dict.keys(),
                    help="Function of SNR and chisq to perform fits with")
# The detector name prefixes every dataset path read from the trigger file,
# so the script cannot do anything useful without it.
parser.add_argument('--ifo', required=True, help="Detector. Required")
parser.add_argument('--plot-max-x', type=float, default=None,
                    help="Maximum stat value to plot, if not given "
                         "1.05 * largest stat value will be used")

# Vetoes
parser.add_argument("--veto-file",
                    help="File(s) in .xml format with veto segments to apply "
                         "to triggers before fitting")
parser.add_argument("--veto-segment-name",
                    help="Name(s) of veto segments to apply. Optional, if not "
                         "given all triggers for the given ifo will be used")

# Fitting
parser.add_argument("--stat-fit-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                         "threshold. Required")
# The fit function name is used both for the fit itself and in the legend
# and caption text, so there is no meaningful behaviour without it.
parser.add_argument("--fit-function", required=True,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Required")

# Pruning of the loudest triggers
parser.add_argument("--prune-number", type=int, default=0,
                    help="Number of loudest events to remove from each split "
                         "histogram, default 0")
parser.add_argument("--prune-window", type=float, default=0.1,
                    help="Time (s) to remove all triggers around a trigger "
                         "which is loudest in each split, default 0.1s")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

logging.info('Opening trigger file: %s', args.trigger_file)
# The trigger file is deliberately left open here: later stages of the
# script read further datasets from it before closing it.
trigf = h5py.File(args.trigger_file, 'r')

logging.info('Opening template file: %s', args.bank_file)
# The bank is only needed for these four arrays, which are copied into
# memory; the context manager closes the file even if a dataset is missing.
with h5py.File(args.bank_file, 'r') as bank:
    logging.info('setting up template bank parameters')
    params = {
        "mass1": bank['mass1'][:],
        "mass2": bank['mass2'][:],
        "spin1z": bank['spin1z'][:],
        "spin2z": bank['spin2z'][:],
    }

# Derived quantities are only worked out when one of the binning or
# splitting options actually asks for them
usedparams = [args.split_param_two, args.split_param_one, args.bin_param]
derivable = ['mchirp', 'eta', 'chi_eff', 'template_duration']
extparams = [name for name in derivable if name in usedparams]

for ex_p in extparams:
    if ex_p != 'template_duration':
        # Everything except the duration is computed from the bank's
        # masses and spins
        logging.info("Calculating " + ex_p + " from template parameters")
        params[ex_p] = trigs.get_param(ex_p, args, params['mass1'],
                                       params['mass2'], params['spin1z'],
                                       params['spin2z'])
        continue

    logging.info('Reading duration from trigger file')
    duration_dset = trigf[args.ifo + '/template_duration']
    duration_refs = trigf[args.ifo + '/template_duration_template'][:]
    # One value per template: the first entry of the region that each
    # reference selects. Per the original author's note, a template with no
    # triggers yields zero here owing to a quirk of h5py.
    params[ex_p] = np.array([duration_dset[ref][0] for ref in duration_refs])

# string formats for labels, logging etc.
formats = {
        "mchirp": '{:.2f}',
        "eta": '{:.3f}',
        "chi_eff": '{:.3f}',
        'template_duration': '{:.0f}'
}
# Parameters taken straight from the bank (e.g. mass1) can also be chosen
# for binning or splitting; give them a generic format so that building
# log messages and axis labels does not fail with a KeyError.
for par in usedparams:
    formats.setdefault(par, '{:.3g}')

logging.info('setting up {}'.format(usedparams))
# (minimum, maximum, number of bins) for each of the two split parameters
sp_one_bin_input = (params[args.split_param_one].min(),
                    params[args.split_param_one].max(), args.split_one_nbins)
sp_two_bin_input = (params[args.split_param_two].min(),
                    params[args.split_param_two].max(), args.split_two_nbins)

# Compare against None rather than relying on truthiness: an explicit
# maximum of 0 (meaningful for e.g. chi_eff) must not be ignored.
if args.max_bin_param is not None:
    logging.info(('setting maximum {} value: '
                  + formats[args.bin_param]).format(args.bin_param,
                                                    args.max_bin_param))
    pbin_upper_lim = args.max_bin_param
else:
    pbin_upper_lim = params[args.bin_param].max()

# For templates with no triggers, the duration will be read as zero
if args.bin_param == 'template_duration' and params[args.bin_param].min() == 0:
    # logging.warn is a deprecated alias which newer Python versions removed
    logging.warning('WARNING: Some templates do not contain triggers')
    # Use the lowest nonzero template duration as lower limit for bins
    pbin_lower_lim = params[args.bin_param][params[args.bin_param] > 0].min()
else:
    pbin_lower_lim = params[args.bin_param].min()

# (minimum, maximum, number of bins) for the within-panel binning parameter
bb_input = (pbin_lower_lim, pbin_upper_lim, args.num_bins)

def _make_bins(spacing, bin_input, param_name):
    """Build the bin_utils bins for one parameter.

    spacing is 'log' or 'linear', bin_input is a (minimum, maximum,
    number of bins) tuple and param_name is only used in the error message.
    Raises ValueError if logarithmic spacing is requested for a range whose
    lower limit is not strictly positive.
    """
    if spacing != 'log':
        return bin_utils.LinearBins(*bin_input)
    # Raise explicitly rather than assert: this is validation of user
    # input, which must still happen when Python runs with -O.
    if not (bin_input[0] > 0):
        raise ValueError("Cannot use log spacing for {}: lower limit {} is "
                         "not positive".format(param_name, bin_input[0]))
    return bin_utils.LogarithmicBins(*bin_input)


logging.info('splitting {} into bins'.format(args.bin_param))
pbins = _make_bins(args.bin_spacing, bb_input, args.bin_param)
# Use sentinel value -1 for templates outside range
# NOTE(review): the comparison is strict, so a template lying exactly on
# either limit is also given -1 -- confirm that this is intended.
pind = np.array([pbins[par] if pbin_lower_lim < par < pbin_upper_lim
                 else -1 for par in params[args.bin_param]])

sp_one_bounds = _make_bins(args.split_one_spacing, sp_one_bin_input,
                           args.split_param_one)
sp_two_bounds = _make_bins(args.split_two_spacing, sp_two_bin_input,
                           args.split_param_two)

logging.info('assigning template ids to different splits')
# For each split bin, collect the indices of the templates whose parameter
# value lies in the half-open interval (lower edge, upper edge]
split_one_vals = params[args.split_param_one]
id_in_bin1 = [
    np.flatnonzero((split_one_vals > edge_lo) & (split_one_vals <= edge_hi))
    for _, edge_lo, edge_hi in zip(range(args.split_one_nbins),
                                   sp_one_bounds.lower(),
                                   sp_one_bounds.upper())
]

split_two_vals = params[args.split_param_two]
id_in_bin2 = [
    np.flatnonzero((split_two_vals > edge_lo) & (split_two_vals <= edge_hi))
    for _, edge_lo, edge_hi in zip(range(args.split_two_nbins),
                                   sp_two_bounds.lower(),
                                   sp_two_bounds.upper())
]

logging.info('getting template boundaries from trigger file')
# boundaries[i] is used below as the index of template i's first trigger
boundaries = trigf[args.ifo + '/template_boundaries'][:]
max_boundary_id = np.argmax(boundaries)
sorted_boundary_list = np.sort(boundaries)

logging.info('Processing template boundaries')
# The end of each template's block of triggers is the next boundary up in
# sorted order. A single searchsorted call finds the position of every
# boundary in the sorted list at once, instead of scanning the whole sorted
# array once per template.
next_position = np.searchsorted(sorted_boundary_list, boundaries,
                                side='left') + 1
has_next = next_position < sorted_boundary_list.size
where_idx_end = np.zeros_like(boundaries)
where_idx_end[has_next] = sorted_boundary_list[next_position[has_next]]
# The template with the largest boundary runs to the end of the trigger list
where_idx_end[max_boundary_id] = trigf[args.ifo + '/end_time'].size

logging.info('calculating single stat values from trigger file')
stat = get_stat(args.sngl_stat, trigf[args.ifo])

# Trigger end times are used by the veto step here and by the pruning step
# later in the script. Load them whenever either is requested: previously
# they were only read when a veto file was given, so pruning without a veto
# file failed with a NameError.
if args.veto_file or args.prune_number:
    time = trigf[args.ifo + '/end_time'][:]

if args.veto_file:
    logging.info('applying DQ vetoes')
    remove, _ = events.veto.indices_within_segments(time, [args.veto_file],
                         ifo=args.ifo, segment_name=args.veto_segment_name)
    # Set stat to zero for triggers being vetoed: given that the fit threshold is
    # >0 these will not be fitted or plotted.  Avoids complications from changing
    # the number of triggers, ie changes of template boundary.
    stat[remove] = 0
    time[remove] = 0
    logging.info('{} out of {} trigs removed after vetoing with {} from {}'.format(
                      remove.size, stat.size, args.veto_segment_name, args.veto_file))

# Optionally remove the loudest triggers in each split, together with all
# triggers (in any split) within prune_window seconds of them.
if not args.prune_number:
    logging.info('Not performing any pruning')
else:
    logging.info('Applying pruning around loudest triggers in each split')
    for x in range(args.split_one_nbins):
        id_bin1 = id_in_bin1[x]
        for y in range(args.split_two_nbins):
            id_bin2 = id_in_bin2[y]
            # Finding ids of templates in split
            id_in_both = np.intersect1d(id_bin1, id_bin2)
            if len(id_in_both) == 0:
                continue

            # Getting triggers that are in these templates. concatenate
            # returns copies, so zeroing the in-split arrays below does not
            # touch the global stat and time arrays.
            vals_inbin = np.concatenate(
                [stat[boundaries[idx]:where_idx_end[idx]]
                 for idx in id_in_both])
            time_inbin = np.concatenate(
                [time[boundaries[idx]:where_idx_end[idx]]
                 for idx in id_in_both])

            logging.info('Pruning in split {}-{} {}-{}'.format(
                         args.split_param_one, x, args.split_param_two, y))
            if vals_inbin.size == 0:
                # The templates in this split have no triggers at all, so
                # there is no loudest trigger to prune around
                continue

            for count_pruned in range(args.prune_number):
                # Time of the loudest remaining trigger in this split
                max_val_arg = vals_inbin.argmax()
                prune_time = time_inbin[max_val_arg]

                remove = np.nonzero(abs(prune_time - time)
                                    < args.prune_window)[0]
                # Remove from inbin triggers as well in case there
                # are more pruning iterations
                remove_inbin = np.nonzero(abs(prune_time - time_inbin)
                                          < args.prune_window)[0]
                # Report prune_time: max_val_arg indexes the in-split
                # arrays, not the global time array
                logging.info('Prune {}: removing {} triggers around time {:.2f},'
                             ' {} in this split'.format(count_pruned, remove.size,
                                                        prune_time,
                                                        remove_inbin.size))
                # Set pruned triggers' stat values to zero, as above for vetoes
                vals_inbin[remove_inbin] = 0
                time_inbin[remove_inbin] = 0
                stat[remove] = 0
                time[remove] = 0

# Everything needed from the trigger file has been read by now
trigf.close()

logging.info('setting up plotting and fitting limit values')
# Zero-valued entries are skipped when looking for the smallest statistic;
# vetoed and pruned triggers were given a statistic of zero earlier on
lowest_nonzero_stat = stat[stat != 0].min()
# Plot from at most one unit below the fit threshold
minplot = max(lowest_nonzero_stat, args.stat_fit_threshold - 1)
min_fit = max(minplot, args.stat_fit_threshold)
max_fit = 1.05 * stat.max()
# An explicit --plot-max-x takes precedence over the fit range's upper end
maxplot = args.plot_max_x if args.plot_max_x else max_fit
# Statistic values at which the fitted curves are evaluated
fitrange = np.linspace(min_fit, max_fit, 100)

logging.info('setting up plotting variables')
base_colors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
# histcolors is indexed by bin number; repeat the palette when more bins
# than base colours are requested so the lookup cannot run off the end
histcolors = [base_colors[i % len(base_colors)]
              for i in range(max(args.num_bins, len(base_colors)))]
# One panel per (split-param-one bin, split-param-two bin); squeeze=False
# keeps axes two-dimensional even for a single row or column. The figure is
# one panel-width wider than the grid.
fig, axes = plt.subplots(args.split_one_nbins, args.split_two_nbins,
                         sharex=True, sharey=True, squeeze=False,
                         figsize=(3 * (args.split_two_nbins + 1),
                                  3 * args.split_one_nbins))

# setting up overall legend outside the split-up plots
# Degenerate lines from (0,0) to (0,0) are drawn purely to obtain legend
# handles with the right colours and styles
lines = []
labels = []
for i, lower, upper in zip(range(args.num_bins), pbins.lower(), pbins.upper()):
    binlabel = r"%.3g - %.3g" % (lower, upper)
    line, = axes[0,0].plot([0,0], [0,0], linewidth=2,
                     color=histcolors[i], alpha=0.6)
    lines.append(line)
    labels.append(binlabel)

line_fit, = axes[0,0].plot([0,0], [0,0], linestyle='--',
                           color='k', alpha=0.6)
lines.append(line_fit)
labels.append(args.fit_function + ' fit to counts')
fig.legend(lines, labels, labelspacing=0.2,
           loc='upper left', title=args.bin_param)

# Template indices belonging to each bin of the binning parameter
pidx = [np.flatnonzero(pind == i) for i in range(args.num_bins)]

logging.info('starting bin, histogram and plot loop')
# Largest cumulative count over every plotted histogram
maxyval = 0
for x in range(args.split_one_nbins):
    id_bin1 = id_in_bin1[x]
    for y in range(args.split_two_nbins):
        id_bin2 = id_in_bin2[y]
        # Templates belonging to this panel of the grid
        id_in_both = np.intersect1d(id_bin1, id_bin2)
        logging.info('split {}: {}, {}: {}'.format(args.split_param_one, x,
                                                   args.split_param_two, y))
        ax = axes[x,y]
        for i, lower, upper in zip(range(args.num_bins), pbins.lower(),
                                   pbins.upper()):
            # Templates in this panel which also fall in this bin
            indices_all_conditions = np.intersect1d(pidx[i], id_in_both)
            logging.info('{} split {}-{}'.format(args.bin_param, lower, upper))
            if len(indices_all_conditions) == 0:
                continue

            # Statistic values of all triggers from those templates
            vals_inbin = np.concatenate(
                [stat[boundaries[idx]:where_idx_end[idx]]
                 for idx in indices_all_conditions])
            vals_above_thresh = vals_inbin[vals_inbin >= args.stat_fit_threshold]
            if not len(vals_above_thresh):
                logging.info('No triggers above threshold')
                continue
            logging.info('{} triggers out of {} above threshold'.format(
                         len(vals_above_thresh), len(vals_inbin)))

            # sig_alpha is returned by the fit but not used: upper and lower
            # 1-sigma bounds on the fit are not currently plotted. They
            # could be obtained by evaluating cum_fit at alpha +/- sig_alpha.
            alpha, sig_alpha = trstats.fit_above_thresh(args.fit_function,
                                 vals_above_thresh, args.stat_fit_threshold)
            fitted_cum_counts = len(vals_above_thresh) * \
                                trstats.cum_fit(args.fit_function, fitrange,
                                                alpha, args.stat_fit_threshold)

            # Histogram all triggers in the bin, then reverse-accumulate to
            # get the number of triggers at or above each left bin edge
            histcounts, edges = np.histogram(vals_inbin, bins=50)
            cum_counts = histcounts[::-1].cumsum()[::-1]
            # Track the maximum here, where cum_counts is known to exist
            # and belongs to the curve just computed. Doing this after the
            # bin loop used a stale value, or none at all when every bin of
            # the first panel had been skipped.
            maxyval = max(maxyval, cum_counts.max())

            # plot the lines!
            ax.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=histcolors[i], alpha=0.6)
            ax.semilogy(fitrange, fitted_cum_counts, "--", color=histcolors[i],
                     label=r"$\alpha = $%.2f" % alpha)
            ax.legend(fontsize='small', framealpha=0.5)

            # Release the trigger arrays before the next bin is gathered
            del vals_inbin, vals_above_thresh

        ax.grid()

logging.info('setting up labels')
# Faint dotted vertical line at the fit threshold in every panel
threshold_x = [args.stat_fit_threshold, args.stat_fit_threshold]
y_top = 5 * maxyval
for panel in axes.flat:
    panel.semilogy(threshold_x, [1, y_top], 'k', linestyle=':', alpha=0.2)
# Set the limits on the first panel only; the grid was created with
# sharex and sharey
axes[0,0].set_ylim(1, y_top)
axes[0,0].set_xlim(minplot, maxplot)

# Bottom row: name of the statistic along the x-axis
bottom_row = args.split_one_nbins - 1
for j in range(args.split_two_nbins):
    axes[bottom_row, j].set_xlabel(args.sngl_stat, size="large")

# Top row: range of split-param-two covered by each column, unless there is
# only one column
if args.split_two_nbins != 1:
    range_fmt_two = (formats[args.split_param_two] + ' to ' +
                     formats[args.split_param_two])
    for j in range(args.split_two_nbins):
        column_range = range_fmt_two.format(sp_two_bounds.lower()[j],
                                            sp_two_bounds.upper()[j])
        axes[0, j].set_xlabel(args.split_param_two + ': ' + column_range,
                              size="large")
        axes[0, j].xaxis.set_label_position("top")

# First column: range of split-param-one covered by each row, or just the
# plain axis title when there is only one row
if args.split_one_nbins == 1:
    axes[0, 0].set_ylabel('cumulative number', size='large')
else:
    range_fmt_one = (formats[args.split_param_one] + ' to ' +
                     formats[args.split_param_one])
    for i in range(args.split_one_nbins):
        row_range = range_fmt_one.format(sp_one_bounds.lower()[i],
                                         sp_one_bounds.upper()[i])
        axes[i, 0].set_ylabel(args.split_param_one + ': ' + row_range +
                              '\ncumulative number', size="large")

# Confine the panels to the right-hand part of the figure, leaving the
# left-most 1/(ncols+1) of the width free (presumably for the figure-level
# legend, which is placed at the upper left)
left_margin = 1./(args.split_two_nbins+1)
fig.tight_layout(rect=(left_margin, 0, 1, 1))

logging.info('saving to file ' + args.output_file)
# Title, caption and generating command line are stored with the image
fig_title = ("{}: {} histogram of single detector triggers split by"
             " {} and {}").format(args.ifo, args.sngl_stat,
                                  args.split_param_one, args.split_param_two)
fig_caption = (r"Histogram of {} single detector {} values binned by {}, split by "
               "{} and {}, with fitted {} distribution parameterized by"
               " &alpha;").format(args.ifo, args.sngl_stat, args.bin_param,
                                  args.split_param_one, args.split_param_two,
                                  args.fit_function)
results.save_fig_with_metadata(fig, args.output_file, title=fig_title,
                               caption=fig_caption, cmd=" ".join(sys.argv))
logging.info('Done!')
