#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt

import copy, numpy as np

from pycbc import events, bin_utils, results
from pycbc.events import triggers
from pycbc.events import trigger_fits as trstats
from pycbc.events.stat import sngl_statistic_dict
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def get_stat(statchoice, trigs):
    """Evaluate the chosen single-detector statistic on a set of triggers.

    The statistic class registered under ``statchoice`` in
    ``sngl_statistic_dict`` is instantiated with an empty file list and the
    result of its ``single`` method applied to ``trigs`` is returned.
    """
    # No auxiliary files are passed to the statistic here; statistics that
    # need input files could be given a non-empty list instead.
    stat_class = sngl_statistic_dict[statchoice]
    return stat_class([]).single(trigs)

#### MAIN ####

# Command-line interface.  Options whose absence makes the script fail
# further down (input files, thresholds, bin spacing, an output destination)
# are enforced by argparse so that the user gets an immediate, readable error.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
# Each occurrence of the two veto options appends a (possibly empty) list of
# values; the lists are flattened after parsing.
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[], action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=sngl_statistic_dict.keys(),
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", nargs="+", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold.  Required.  Typical "
                    "values 6.25 6.5 6.75")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-prune-param", action='store_true',
                    help="Bin in the log of prune-param")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                    "duration; if not given, duration will be read from "
                    "single trigger files")
# FIXME : allow choice of SEOBNRv2/v4 or PhenD duration formula ?
parser.add_argument("--bin-param", required=True,
                    help="Parameter over which to bin when fitting. Required. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
# Without a spacing choice no bins can be constructed at all
parser.add_argument("--bin-spacing", choices=["linear", "log", "irregular"],
                    required=True,
                    help="How to space parameter bin edges")
binopt = parser.add_mutually_exclusive_group(required=True)
binopt.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    " parameter")
binopt.add_argument("--irregular-bins",
                    help="Comma-separated list of parameter bin edges. "
                    "Required if --bin-spacing = irregular")
parser.add_argument("--bin-param-units",
                    help="String to display units of the binning parameter")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
# type=float so that a user-supplied value can be added to the durations
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")
# Exactly one output destination is needed: with neither there is nowhere
# to save any of the plots.
outputchoice = parser.add_mutually_exclusive_group(required=True)
outputchoice.add_argument("--plot-dir",
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save to the specified directory.")
outputchoice.add_argument("--output-file",
                    help="Output a plot of hists and fits made for a single "
                    "threshold value.")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plot files")

args = parser.parse_args()

# Each occurrence of --veto-file / --veto-segment-name contributes its own
# list of values; flatten them into one list per option.
args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

# Veto files and segment names are paired up one-to-one further down
if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto segment names must match number of "
                       "veto files")

# Pruning needs all three of its options, or none of them
prune_opts = (args.prune_param, args.prune_bins, args.prune_number)
if any(prune_opts) and not all(prune_opts):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

if args.output_file is not None and len(args.stat_threshold) > 1:
    raise RuntimeError("Cannot plot more than one threshold in a single "
                       "output file!")

log_level = logging.DEBUG if args.verbose else logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Human-readable names used in plot labels and output file names
statname = "reweighted SNR" if args.sngl_stat == "new_snr" else \
           args.sngl_stat.replace("_", " ").replace("snr", "SNR")
paramname = args.bin_param.replace("_", " ")
paramtag = args.bin_param.replace("_", "")
if args.plot_dir:
    if not args.plot_dir.endswith('/'):
        args.plot_dir += '/'
    # common prefix of every plot written to the plot directory
    plotbase = args.plot_dir + args.ifo + "-" + args.user_tag

## Check option logic
if args.bin_spacing == "irregular":
    if args.irregular_bins is None:
        raise RuntimeError("Must specify a list of irregular bin edges!")
    args.bin_edges = [float(b) for b in args.irregular_bins.split(',')]
elif args.num_bins is None:
    # regularly spaced bins cannot be built without a bin count
    raise RuntimeError("Must specify --num-bins when --bin-spacing is not "
                       "irregular!")

logging.info('Opening trigger file: %s', args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s', args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

# get the stat values
stat = get_stat(args.sngl_stat, trigf[args.ifo])

# Durations are read from the trigger file only when binning by
# template_duration and no starting frequency was given
if args.bin_param == 'template_duration' and not args.f_lower:
    logging.info('Using template duration from the trigger file')
    trig_dur = True
else:
    trig_dur = False

# stat threshold to reduce trigger numbers: nothing below the smallest
# requested threshold is ever fitted
minth = min(args.stat_threshold)
abovethresh = stat >= minth
stat = stat[abovethresh]
tid = trigf[args.ifo+'/template_id'][:][abovethresh]
time = trigf[args.ifo+'/end_time'][:][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo+'/template_duration'][:][abovethresh]
logging.info('%i trigs left after thresholding at %f', len(stat), minth)

# now do vetoing, one (file, segment name) pair at a time
for veto_file, veto_segment_name in zip(args.veto_file, args.veto_segment_name):
    # only the indices of triggers to keep are needed; the second return
    # value is not used
    retain, _ = events.veto.indices_outside_segments(time, [veto_file],
                             ifo=args.ifo, segment_name=veto_segment_name)
    stat = stat[retain]
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    # report the file that was just applied rather than the whole list
    logging.info('%i trigs left after vetoing with %s', len(stat), veto_file)

### Functions for doing the pruning (removal of trigs at loudest times)

def get_pars(args, tag, m1, m2, s1z, s2z):
    """Return values of the parameter named by the ``<tag>_param`` option.

    Used for both the pruning (``tag='prune'``) and binning (``tag='bin'``)
    parameters.  The parameter name is read from ``args`` and handed, with
    the masses and spins, to ``triggers.get_param``.

    ``m1`` may be a sequence or a single number; the informational log
    message with the number of triggers is only emitted for sequences.
    """
    paramarg = getattr(args, tag+'_param')
    try:
        ntrigs = len(m1)
    except TypeError:
        # m1 is a single number rather than a sequence: nothing to count
        pass
    else:
        logging.info('Getting %s values for %i triggers' % (paramarg, ntrigs))
    return triggers.get_param(paramarg, args, m1, m2, s1z, s2z)

# Pruning: repeatedly take the loudest remaining trigger, find which
# prune-param bin its template falls in and, unless that bin has already had
# its quota of events pruned, remove all triggers within a time window around
# it from the set to be fitted.
if args.prune_param:
    logging.info('Getting min and max param values')
    prpars = get_pars(args, 'prune',
                      templatef['mass1'][:], templatef['mass2'][:],
                      templatef['spin1z'][:], templatef['spin2z'][:])
    minprpar = min(prpars)
    maxprpar = max(prpars)
    del prpars
    logging.info('prune param range %f %f' % (minprpar, maxprpar))

    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # total number of events to prune once every bin is full
    prune_target = args.prune_bins * args.prune_number
    # times of the events pruned so far, keyed by bin index
    prunedtimes = {i: [] for i in range(args.prune_bins)}

    # keep a record of the triggers if all successive loudest events were to
    # be pruned
    statpruneall = copy.deepcopy(stat)
    tidpruneall = copy.deepcopy(tid)
    timepruneall = copy.deepcopy(time)

    # many trials may be required to prune in 'quieter' bins
    for j in range(1000):
        # are all the bins full already?
        numpruned = sum(len(pruned) for pruned in prunedtimes.values())
        if numpruned == prune_target:
            logging.info('Finished pruning!')
            break
        if numpruned > prune_target:
            msg = ('Uh-oh, we pruned too many things .. %i, to be '
                   'precise' % numpruned)
            logging.error(msg)
            raise RuntimeError(msg)
        if len(statpruneall) == 0:
            # bins that never receive a loud trigger cannot be filled; stop
            # rather than asking for the loudest of no triggers
            logging.warning('No triggers left to consider for pruning; '
                            'pruned %i of %i events', numpruned, prune_target)
            break
        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, ltid)
        lbin = trstats.which_bin(get_pars(args, 'prune', m1, m2, s1z, s2z),
                                 minprpar, maxprpar,
                                 args.prune_bins, log=args.log_prune_param)
        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at time '
                         '%.3f' % (j, lbin, lstat, ltime))
        else:
            logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                         (lstat, ltime, lbin))
            # now do the pruning
            retain = abs(time - ltime) > args.prune_window
            logging.info('%i trigs before pruning' % len(stat))
            stat = stat[retain]
            logging.info('%i trigs remain' % len(stat))
            tid = tid[retain]
            time = time[retain]
            if trig_dur:
                tdur = tdur[retain]
            # record the time
            prunedtimes[lbin].append(ltime)
        # In either case remove the window around this event from the
        # reference arrays so the next trial moves on to the next loudest
        retain = abs(timepruneall - ltime) > args.prune_window
        statpruneall = statpruneall[retain]
        tidpruneall = tidpruneall[retain]
        timepruneall = timepruneall[retain]
        del retain
    else:
        # ran out of trials: say so if some bins are still short of events
        numpruned = sum(len(pruned) for pruned in prunedtimes.values())
        if numpruned < prune_target:
            logging.warning('Stopped pruning after 1000 trials with %i of %i '
                            'events pruned', numpruned, prune_target)
    del statpruneall
    del tidpruneall
    del timepruneall

# get binning params after pruning
if trig_dur:
    # float() in case the option value arrives as a string
    binpars = tdur + float(args.min_duration)
else:
    m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, tid)
    binpars = get_pars(args, 'bin', m1, m2, s1z, s2z)
if len(binpars) == 0:
    raise RuntimeError("No triggers left to fit after thresholding, vetoing "
                       "and pruning!")
logging.info("Parameter range of triggers: %f - %f" %
                                                  (min(binpars), max(binpars)))

# remove triggers outside irregular bins
if args.bin_spacing == "irregular":
    lowedge = min(args.bin_edges)
    highedge = max(args.bin_edges)
    logging.info("Removing triggers outside bin range %f - %f" %
                                                          (lowedge, highedge))
    in_range = np.logical_and(binpars >= lowedge, binpars <= highedge)
    binpars = binpars[in_range]
    stat = stat[in_range]
    tid = tid[in_range]
    logging.info("%i remain" % len(binpars))
    if len(binpars) == 0:
        raise RuntimeError("No triggers inside the irregular bin range!")

# get the bins
# we assume that parvals are all positive; written as 'not >=' so that a NaN
# minimum is rejected too
minpar = min(binpars)
maxpar = max(binpars)
if not minpar >= 0:
    raise RuntimeError("Binning parameter values must be non-negative, "
                       "found minimum %f" % minpar)
# widen the range slightly so the extreme values fall inside the end bins
pmin = 0.999 * minpar
pmax = 1.001 * maxpar
if args.bin_spacing == "linear":
    pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "log":
    pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "irregular":
    pbins = bin_utils.IrregularBins(args.bin_edges)
else:
    raise RuntimeError("Unrecognized bin spacing: %s" % args.bin_spacing)

# list of bin indices
binind = [pbins[c] for c in pbins.centres()]
logging.info("Assigning trigger param values to bins")
# FIXME: This is slow!! either find a better way of using pylal.rate
# or write faster binning routine
pind = np.array([pbins[par] for par in binpars])

logging.info("Getting max counts in bins")
# The plot limits are shared between all thresholds, so they are derived
# from the trigger counts above the smallest requested threshold only
bincounts = [sum(stat[pind == i] >= minth) for i in binind]
maxcount = max(bincounts)
# statistic values at which the fitted curves are drawn
plotrange = np.linspace(0.95 * min(stat), 1.05 * max(stat), 100)

# containers for the results
parbins, counts, templates = {}, {}, {}
alphas, stdev, ks_prob = {}, {}, {}
nabove = {}

# one colour per parameter bin
histcolors = [
    'r', (1.0, 0.6, 0), 'y', 'g', 'c',
    'b', 'm', 'k', (0.8, 0.25, 0), (0.25, 0.8, 0),
]

# For each threshold: fit the statistic distribution in every parameter bin
# and draw the cumulative histograms together with the fitted curves
for th in args.stat_threshold:
    logging.info("Fitting above threshold %f" % th)
    counts[th] = {}
    alphas[th] = {}
    stdev[th] = {}
    ks_prob[th] = {}

    if args.output_file:
        fig = plt.figure()
    for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
        inbin = pind == i
        # determine number of templates generating the triggers involved
        # for hdf5, use the template id; otherwise use masses
        tid_inbin = tid[inbin]
        templates[i] = len(set(tid_inbin))
        vals_inbin = stat[inbin]
        counts[th][i] = sum(vals_inbin >= th)
        if len(vals_inbin) == 0:
            # nothing to fit or plot; no fit results are stored for this bin
            logging.info("No trigs in bin %f-%f", lower, upper)
            continue
        # do the fit
        alpha, sig_alpha = trstats.fit_above_thresh(
                                             args.fit_function, vals_inbin, th)
        alphas[th][i] = alpha
        stdev[th][i] = sig_alpha
        _, ks_prob[th][i] = trstats.KS_test(
                                      args.fit_function, vals_inbin, alpha, th)
        # add histogram to plot
        histcounts, edges = np.histogram(vals_inbin, bins=50)
        # number of values at or above each histogram bin's lower edge
        cum_counts = histcounts[::-1].cumsum()[::-1]
        binlabel = r"%.3g - %.3g" % (lower, upper)
        # cycle through the colour list if there are more bins than colours
        bincolor = histcolors[i % len(histcolors)]
        # histogram of fitted values
        plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=bincolor, label=binlabel, alpha=0.6)
        # fit central value
        plt.semilogy(plotrange,
                     counts[th][i] * trstats.cum_fit(args.fit_function,
                                                     plotrange, alpha, th),
                     "--", color=bincolor,
                     label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
        # 1sigma upper deviation on alpha
        plt.semilogy(plotrange,
                     counts[th][i] * trstats.cum_fit(args.fit_function,
                                                     plotrange,
                                                     alpha + sig_alpha, th),
                     ":", alpha=0.6, color=bincolor)
        # 1sigma lower deviation
        plt.semilogy(plotrange,
                     counts[th][i] * trstats.cum_fit(args.fit_function,
                                                     plotrange,
                                                     alpha - sig_alpha, th),
                     ":", alpha=0.6, color=bincolor)
    # finish the hist plot
    leg = plt.legend(labelspacing=0.2)
    unitstring = " (%s)" % args.bin_param_units if \
                                       args.bin_param_units is not None else ""
    leg.set_title(paramname+unitstring)
    plt.setp(leg.get_texts(), fontsize=11)
    plt.ylim(0.7, 2*maxcount)
    plt.xlim(0.9*minth, 1.1*max(plotrange))
    plt.grid()
    plt.xlabel(statname, size="large")
    plt.ylabel("cumulative number", size="large")
    if args.plot_dir:
        plt.title(args.ifo + " " + statname + " distribution split by " +
                  paramname)
        dest = plotbase + "_" + args.sngl_stat + "_cdf_by_" + paramtag[0:3] + \
                                            "_fit_thresh_" + str(th) + ".png"
        logging.info("Saving cumhist to %s" % dest)
        plt.savefig(dest)
    elif args.output_file:
        logging.info("Saving cumhist to %s" % args.output_file)
        results.save_fig_with_metadata(
            fig, args.output_file,
            title="%s: %s histogram of single detector triggers" % (args.ifo,
                  statname),
            caption=(r"Histogram of single detector %s values binned by %s "
                     "with fitted %s distribution parameterized by &alpha;"\
                     % (statname, paramname, args.fit_function)),
            cmd=" ".join(sys.argv)
        )
    plt.close()

# don't make any more plots if only making the single rainbow hist
if args.output_file:
    sys.exit()

# Bins without triggers have no stored fit results and no templates; they
# are plotted as NaN, which leaves a gap instead of aborting the script.
NOFIT = float('nan')


def _per_template(value, ntemplates):
    """Return value / ntemplates, or NaN for a bin containing no templates."""
    return float(value) / ntemplates if ntemplates else NOFIT


# make plots of alpha, trig count and KS significance
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [alphas[th].get(i, NOFIT) for i in binind],
                 yerr=[stdev[th].get(i, NOFIT) for i in binind], fmt="+-",
                 label=args.ifo + " fit above %.2f" % th)
if args.bin_spacing == "log":
    plt.semilogx()
# NOTE: the horizontal axis range is hard-coded for all three plots
plt.xlim(0.03, 150)
plt.ylim(0, 10)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"fit parameter $\alpha$", size='large')
plt.savefig(plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
plt.close()

# triggers above threshold per template; error bar is sqrt(count) per template
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [_per_template(counts[th][i], templates[i])
                  for i in binind],
                 yerr=[_per_template(counts[th][i]**0.5, templates[i])
                       for i in binind],
                 fmt="+-", label=args.ifo + " trigs above %.2f" % th)
if args.bin_spacing == "log":
    plt.semilogx()
plt.xlim(0.03, 150)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"Triggers above threshold per template", size='large')
plt.savefig(plotbase + '_nabove_vs_' + paramtag[0:3] + '.png')
plt.close()

for th in args.stat_threshold:
    plt.plot(pbins.centres(), [ks_prob[th].get(i, NOFIT) for i in binind],
          '+--', label=args.ifo+' KS prob, thresh %.2f' % th)
if args.bin_spacing == 'log':
    plt.loglog()
else:
    plt.semilogy()
plt.xlim(0.03, 150)
plt.grid()
leg = plt.legend(loc='best', labelspacing=0.2)
plt.setp(leg.get_texts(), fontsize=11)
plt.xlabel(paramname, size='large')
plt.ylabel('KS test p-value')
plt.savefig(plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
plt.close()

logging.info('Done!')
