#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt

import copy, numpy as np
from scipy import stats as scistats

from pycbc import io, events, bin_utils, results
from pycbc.events import triggers
from pycbc.events.stat import sngl_statistic_dict
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def get_stat(statchoice, trigs):
    """Return the single-detector statistic values for a set of triggers.

    statchoice is a key of sngl_statistic_dict selecting the statistic
    class; trigs is handed unchanged to that class's single() method.
    """
    # The statistic class is constructed with an empty file list. Statistics
    # which need auxiliary files would have to be given them here.
    stat_class = sngl_statistic_dict[statchoice]
    return stat_class([]).single(trigs)

#### MAIN ####

# Command line interface. Options the script cannot run without are marked
# required so that argparse reports a missing one before any file is opened.
parser = argparse.ArgumentParser(usage="", description="Plot trigger rates")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
# nargs='*' combined with action='append' yields a list of lists, which is
# flattened after parsing
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[], action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--gps-start-time", type=float,
                    help="Time from which to load and plot triggers")
parser.add_argument("--gps-end-time", type=float,
                    help="End time up to which to load/plot triggers")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=sngl_statistic_dict.keys(),
                    help="Function of SNR and chisq to threshold on")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
# The threshold is applied unconditionally to the statistic values, so it
# has to be supplied
parser.add_argument("--stat-threshold", type=float, required=True,
                    help="Only plot triggers with statistic value above this "
                    "threshold. Required")
parser.add_argument("--thorne-limit", action="store_true",
                    help="Remove triggers from templates with one or both "
                    "spins above 0.998")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                    "duration; if not given, duration will be read from "
                    "single trigger files")
parser.add_argument("--bin-param", required=True,
                    help="Parameter over which to bin. Required. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--bin-spacing", choices=["linear", "log", "irreg"],
                    help="How to space parameter bin edges")
# Regular and irregular bin specifications cannot be combined
binspec = parser.add_mutually_exclusive_group()
binspec.add_argument("--num-bins", type=int,
                     help="Number of regularly spaced bins to use over the "
                     " parameter")
binspec.add_argument("--irregular-bins", type=float, nargs="*",
                     help="Boundaries of irregular bins")
parser.add_argument("--bin-param-units",
                    help="String to display units of the binning parameter")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
# type=float is needed: the value is added to the array of template
# durations, which fails for the string argparse would otherwise pass on
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")
parser.add_argument("--kde-bandwidth", type=float, default=1.,
                    help="Width of the smoothing kernel as compared to total "
                    "plot duration. 0.02 usually works OK")
parser.add_argument("--raw-rate", action="store_true",
                    help="Plot rates without normalizing to template count")
parser.add_argument("--log-y", action="store_true")
parser.add_argument("--output-file",
                    help="Name of file to save plot")

args = parser.parse_args()

# Each occurrence of --veto-file / --veto-segment-name is collected as its
# own list (action='append' with nargs='*'): flatten them into simple lists
args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

# Veto files and segment names are paired one-to-one when vetoing
if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files must match number of veto "
                       "segment names")

# A start or end time is only meaningful if the other one is given too
if (args.gps_start_time or args.gps_end_time) and not (args.gps_start_time
                                                       and args.gps_end_time):
    raise RuntimeError("I need both gps start time and end time!")

# Informational messages are only shown when --verbose is given
log_level = logging.DEBUG if args.verbose else logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Human-readable names of the statistic and of the binning parameter
if args.sngl_stat == "new_snr":
    statname = "reweighted SNR"
else:
    statname = args.sngl_stat.replace("_", " ").replace("snr", "SNR")
paramname = " ".join(args.bin_param.split("_"))
paramtag = "".join(args.bin_param.split("_"))

# Open the input files read-only
logging.info('Opening trigger file: %s', args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s', args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

# Durations are taken from the trigger file only when binning over
# template_duration with no starting frequency given
trig_dur = args.bin_param == 'template_duration' and not args.f_lower
if trig_dur:
    logging.info('Using template duration from the trigger file')

# Evaluate the statistic and keep only triggers at or above the threshold,
# to reduce trigger numbers
stat = get_stat(args.sngl_stat, trigf[args.ifo])
abovethresh = stat >= args.stat_threshold
tid = trigf[args.ifo + '/template_id'][:][abovethresh]
time = trigf[args.ifo + '/end_time'][:][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo + '/template_duration'][:][abovethresh]
logging.info('%i trigs left after thresholding at %f',
             len(stat[abovethresh]), args.stat_threshold)
# the statistic values themselves are not needed any further
del stat

# Optionally restrict to the half-open interval [gps start, gps end)
if args.gps_start_time and args.gps_end_time:
    inside = (time >= args.gps_start_time) & (time < args.gps_end_time)
    tid, time = tid[inside], time[inside]
    if trig_dur:
        tdur = tdur[inside]
    logging.info('%i trigs left after restricting gps times', len(time))

# now do vetoing: each veto file is paired with one segment name
for vfile, vsegmentname in zip(args.veto_file, args.veto_segment_name):
    retain, _ = events.veto.indices_outside_segments(time, [vfile],
                                       ifo=args.ifo, segment_name=vsegmentname)
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    # report the file that was just applied rather than the whole list
    logging.info('%i trigs left after vetoing with %s' % (len(time), vfile))

# get a minimum time for plotting purposes
if not args.gps_start_time:
    args.gps_start_time = int(time.min())

# optionally remove triggers from templates with either spin magnitude at
# or above 0.998
if args.thorne_limit:
    m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, tid)
    inside = np.logical_and(abs(s1z) < 0.998, abs(s2z) < 0.998)
    tid = tid[inside]
    time = time[inside]
    if trig_dur:
        tdur = tdur[inside]

# Times relative to the start of the plot. This must be computed after the
# last cut on the triggers so that it stays aligned element by element with
# tid (and tdur), which are indexed together with it further on.
plottime = time - args.gps_start_time

def get_pars(args, tag, m1, m2, s1z, s2z):
    """Return values of the parameter named by the '<tag>_param' option.

    The parameter name is read from the attribute tag+'_param' of args
    (e.g. tag 'bin' reads args.bin_param) and evaluated by
    triggers.get_param for the given component masses and spins, which may
    be sequences or single floats.
    """
    # used for binning params
    paramarg = getattr(args, tag+'_param')
    try:
        ntrig = len(m1)
    except TypeError:
        # m1 is a float rather than a sequence: there is no count to report
        pass
    else:
        logging.info('Getting %s values for %i triggers', paramarg, ntrig)
    return triggers.get_param(paramarg, args, m1, m2, s1z, s2z)

# get binning params: either the durations read from the trigger file or a
# parameter computed from the template masses and spins
if trig_dur:
    binpars = tdur + args.min_duration
else:
    m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, tid)
    binpars = get_pars(args, 'bin', m1, m2, s1z, s2z)
# find the extreme values once rather than on every use
binpar_min = np.min(binpars)
binpar_max = np.max(binpars)
logging.info("Parameter range of triggers: %f - %f" %
                                                  (binpar_min, binpar_max))

# get the bins
# The outer bin edges are found by scaling the extreme values by factors
# just below and above one, which only encloses them if they are not
# negative. Raise explicitly: an assert would vanish under python -O.
if binpar_min < 0:
    raise ValueError("Binning parameter values must not be negative, "
                     "got a minimum of %f" % binpar_min)
pmin = 0.999 * binpar_min
pmax = 1.001 * binpar_max
# plot colours, one per bin
bincolors = ['r',(1.0,0.65,0),#'y',
             'g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
if args.bin_spacing in ("linear", "log"):
    if args.num_bins is None:
        raise RuntimeError("--num-bins is needed for %s bin spacing"
                           % args.bin_spacing)
    if args.bin_spacing == "linear":
        pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins)
    else:
        pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "irreg":
    if args.irregular_bins is None or len(args.irregular_bins) < 2:
        raise RuntimeError("--irregular-bins with at least two boundaries "
                           "is needed for irreg bin spacing")
    # allow bins in reverse order!
    if args.irregular_bins[1] < args.irregular_bins[0]:
        args.irregular_bins = args.irregular_bins[::-1]
    pbins = bin_utils.IrregularBins(args.irregular_bins)
else:
    # without this branch pbins would be left undefined
    raise RuntimeError("--bin-spacing must be given to define the bins")

# list of bin indices
binind = [pbins[c] for c in pbins.centres()]
logging.info("Assigning trigger param values to bins")
# This is slow!! Either find a better way of using pylal.rate or write faster binning routine
pind = np.array([pbins[par] for par in binpars])

# initialize result storage: number of triggers and number of distinct
# templates producing them, keyed by bin index
bincounts = {}
bintemplates = {}
maxtime = int(plottime.max()) + 1

# running limits for the y axis, updated as each curve is added
minrate = 10000
maxrate = 0.0001

# times at which the rate curves are evaluated; the same for every bin
xplot = np.linspace(0, maxtime, int(10 / args.kde_bandwidth))

fig = plt.figure()
for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
    inbin = pind == i
    # determine number of templates generating the triggers involved
    # using the template id
    tid_inbin = tid[inbin]
    bintemplates[i] = len(set(tid_inbin))
    times_inbin = plottime[inbin]
    bincounts[i] = len(times_inbin)
    logging.debug("%i trigs in bin %f-%f", bincounts[i], lower, upper)
    if bincounts[i] == 0:
        logging.info("No trigs in bin %f-%f", lower, upper)
        continue
    if bincounts[i] < 2:
        # gaussian_kde cannot be built from a single data point
        logging.info("Too few trigs for a KDE in bin %f-%f", lower, upper)
        continue
    # add KDE to plot
    kd = scistats.gaussian_kde(times_inbin, bw_method=args.kde_bandwidth)
    # the KDE is a normalized PDF, need to rescale to turn it into a rate
    yplot = kd(xplot) * bincounts[i]
    if not args.raw_rate:
        yplot = yplot / bintemplates[i]
    # lower y limit is a fixed factor below the lowest curve maximum
    minrate = min(minrate, yplot.max()/5e2)
    logging.debug("Minimum plotted rate is now %g", minrate)
    maxrate = max(maxrate, yplot.max())
    binlabel = r"%.3g - %.3g" % (lower, upper)
    # cycle through the colours if there are more bins than colours
    plt.plot(xplot, yplot, '-', c=bincolors[i % len(bincolors)],
             label=binlabel)

# finish the plot: the legend is titled with the binning parameter and,
# if given, its units
leg = plt.legend(labelspacing=0.2, loc='lower center')
unitstring = " (%s)" % args.bin_param_units if \
                                       args.bin_param_units is not None else ""
leg.set_title(paramname+unitstring)
plt.setp(leg.get_texts(), fontsize=11)
if args.log_y:
    plt.yscale('log')
plt.ylim(minrate, 1.4*maxrate)
plt.xlim(0, maxtime)
plt.grid()
plt.xlabel("GPS time after %i" % args.gps_start_time, size="large")
# the option is --raw-rate, stored by argparse as args.raw_rate
if args.raw_rate:
    plt.ylabel(r"Trigger rate (s$^{-1}$)", size="large")
else:
    plt.ylabel(r"Rate per template (s$^{-1}$)", size="large")
logging.info("Saving to %s" % args.output_file)
# save the figure, recording the command line with it
results.save_fig_with_metadata(
    fig, args.output_file,
    title="%s: rate of triggers above a %s threshold %f" % (args.ifo,
              statname, args.stat_threshold),
    caption=(r"Rate of %s single detector triggers thresholded on %s" \
             % (args.ifo, statname)),
    cmd=" ".join(sys.argv)
    )
plt.close()

logging.info('Done!')
