#!/usr/bin/python

# Copyright 2016 Thomas Dent, Alex Nitz, Gareth Cabourn Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.


import sys, h5py, argparse, logging, pycbc.version, numpy
from scipy.stats import norm
from pycbc.events import triggers


def dist(i1, i2, parvals, smoothing_width):
    """
    Distance in fit-parameter space between the template(s) at
    index/indices i1 and the template(s) at index/indices i2.

    parvals is a sequence with one array of values per parameter, and
    smoothing_width gives one width per parameter. The difference in each
    parameter is scaled by its width and the scaled differences are added
    in quadrature, so a distance of 1 corresponds to one smoothing width.
    """
    scaled_sq_sum = 0
    for values, width in zip(parvals, smoothing_width):
        delta = values[i2] - values[i1]
        scaled_sq_sum += delta ** 2.0 / width ** 2.0
    return scaled_sq_sum ** 0.5


def smooth_templates(nabove, invalphan, ntotal, template_idx,
                     weights=None):
    """
    Average the fit quantities over a chosen set of templates.

    The max likelihood fit for 1/alpha is linear in the trigger
    statistic values, so the quantity averaged is (n_above / alpha);
    the smoothed alpha is then the smoothed n_above divided by that
    average.

    Parameters
    ----------
    nabove: ndarray
        Count of triggers above threshold, for all templates
    invalphan: ndarray
        n_above / alpha values, for all templates
    ntotal: ndarray
        Count of all triggers in the template (above and below
        threshold), for all templates
    template_idx: ndarray of ints
        Indices of the templates that contribute to the average

    Optional Parameters
    -------------------
    weights: ndarray
        Weight of each template selected by template_idx. If not given,
        all selected templates are weighted equally.

    Returns
    -------
    tuple: 3 floats
        First float: the smoothed count above threshold value
        Second float: the smoothed fit coefficient (alpha) value
        Third float: the smoothed total count in template value

    """
    if weights is None:
        weights = numpy.ones_like(template_idx)

    def _weighted_mean(values):
        # Average of the selected templates only
        return numpy.average(values[template_idx], weights=weights)

    smoothed_nabove = _weighted_mean(nabove)
    smoothed_ntotal = _weighted_mean(ntotal)
    mean_invalphan = _weighted_mean(invalphan)

    return (smoothed_nabove,
            smoothed_nabove / mean_invalphan,
            smoothed_ntotal)


def smooth_tophat(nabove, invalphan, ntotal, dists):
    """
    Tophat smoothing: every template whose entry in dists is less than
    one contributes with equal weight, all others are ignored.
    """
    within_unit_dist = dists < 1.
    selected = numpy.flatnonzero(within_unit_dist)
    return smooth_templates(nabove, invalphan, ntotal, selected)


def smooth_n_closest(nabove, invalphan, ntotal, dists, total_trigs=500):
    """
    Smooth over the closest templates, adding templates in order of
    increasing distance until they contain at least total_trigs triggers
    above threshold. No weighting is applied.

    Parameters
    ----------
    nabove, invalphan, ntotal: ndarray
        As described in the smooth_templates docstring
    dists: ndarray
        Distance of every template from the template of interest
    total_trigs: int or str
        Number of above-threshold triggers to gather. A str is accepted
        because values supplied through --smoothing-keywords are not
        converted before being passed in.

    Returns
    -------
    tuple: 3 floats
        The output of smooth_templates for the selected templates
    """
    # Convert once, outside the loop; may be a str from the command line
    total_trigs = int(total_trigs)
    dist_sort = numpy.argsort(dists)
    n_available = len(dist_sort)
    templates_required = 0
    n_triggers = 0
    # Walk outwards from the closest template. The position in dist_sort is
    # the number of templates used so far, NOT the number of triggers
    # gathered. If every template has been used and there are still too few
    # triggers, stop and smooth over all of them rather than indexing past
    # the end of the array.
    while n_triggers < total_trigs and templates_required < n_available:
        n_triggers += nabove[dist_sort[templates_required]]
        templates_required += 1
    logging.debug("%d templates required to obtain %d triggers",
                  templates_required, n_triggers)
    idx_to_smooth = dist_sort[:templates_required]
    return smooth_templates(nabove, invalphan, ntotal, idx_to_smooth)


def smooth_distance_weighted(nabove, invalphan, ntotal, dists):
    """
    Gaussian-weighted smoothing: templates with dists below three are
    weighted by a unit-width normal distribution evaluated at their
    distance; templates further away are not used.
    """
    nearby = numpy.flatnonzero(dists < 3.)
    gaussian_weights = norm.pdf(dists[nearby])
    return smooth_templates(nabove, invalphan, ntotal, nearby,
                            weights=gaussian_weights)

# Smoothing functions selectable by name; the keys are the values accepted
# by the --smoothing-method option
_smooth_dist_func = dict(
    smooth_tophat=smooth_tophat,
    n_closest=smooth_n_closest,
    distance_weighted=smooth_distance_weighted,
)


def smooth(nabove, invalphan, ntotal, dists, smoothing_method, **kwargs):
    """
    Dispatch to the smoothing function registered under smoothing_method
    in _smooth_dist_func and return its result.

    nabove, invalphan, ntotal are as described in the smooth_templates
    docstring.

    dists is an array of the distances of the templates from the
    template of interest.

    Any extra keyword arguments are handed on to the smoothing function
    unchanged.
    """
    smoothing_func = _smooth_dist_func[smoothing_method]
    return smoothing_func(nabove, invalphan, ntotal, dists, **kwargs)


# Command-line interface. Options whose help text says "Required" are
# enforced by argparse so that a missing option is reported as a usage error
# rather than as a failure later in the script.
parser = argparse.ArgumentParser(usage="",
    description="Smooth (regress) the dependence of coefficients describing "
                "single-ifo background trigger distributions on a template "
                "parameter, to suppress random noise in the resulting "
                "background model.")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--template-fit-file", required=True,
                    help="hdf5 file containing fit coefficients for each"
                         " individual template. Required")
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters. Required "
                         "unless reading param from template fit file")
parser.add_argument("--output", required=True,
                    help="Location for output file containing smoothed fit "
                         "coefficients.  Required")
parser.add_argument("--use-template-fit-param", action="store_true",
                    help="Use parameter values stored in the template fit "
                         "file as template_param for smoothing.")
parser.add_argument("--fit-param", nargs='+', required=True,
                    help="Parameter(s) over which to regress the background "
                         "fit coefficients. Required. Either read from "
                         "template fit file or choose from mchirp, mtotal, "
                         "chi_eff, eta, tau_0, tau_3, template_duration, "
                         "a frequency cutoff in pnutils or a frequency "
                         "function in LALSimulation. To regress the "
                         "background over multiple parameters, provide them "
                         "as a list.")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                         "duration, if not reading from the template fit file")
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                         "values of template_duration: add to duration values"
                         " before fitting. Units seconds.")
parser.add_argument("--log-param", nargs='+',
                    help="Take the log of the fit param before smoothing.")
parser.add_argument("--smoothing-width", type=float, nargs='+', required=True,
                    help="Distance in the space of fit param values (or the "
                         "logs of them) to smooth over. Required. "
                         "This must be a list corresponding to the smoothing "
                         "parameters.")
parser.add_argument("--smoothing-method", default="smooth_tophat",
                    choices=list(_smooth_dist_func),
                    help="Method used to smooth the fit parameters; "
                         "'smooth_tophat' (default) finds all templates within "
                         "unit distance from the template of interest "
                         "(distance normalised by --smoothing-width). "
                         "'n_closest' adds the closest templates to "
                         "the smoothing until 500 triggers are reached. "
                         "'distance_weighted' weights the closest templates "
                         "with a normal distribution of width smoothing-width "
                         "truncated at three smoothing-widths.")
parser.add_argument("--smoothing-keywords", nargs='*',
                    help="Keywords for the smoothing function, supplied "
                         "as key:value pairs, e.g. total_trigs:500 to define "
                         "the number of triggers gathered in the n_closest "
                         "smoothing method")
parser.add_argument("--output-fits-by-template", action='store_true',
                    help="If given, will output the input file fits to "
                         "fit_by_template group")
args = parser.parse_args()

# Turn the optional KEY:VALUE strings given to --smoothing-keywords into a
# dict of keyword arguments for the smoothing function. Values are kept as
# strings here.
smooth_kwargs = args.smoothing_keywords or []

kwarg_dict = {}
for inputstr in smooth_kwargs:
    try:
        key, value = inputstr.split(':')
    except ValueError:
        # Raised by the unpacking when there is not exactly one ':'
        err_txt = ("--smoothing-keywords must take input in the "
                   "form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... "
                   "Received {}".format(' '.join(args.smoothing_keywords)))
        raise ValueError(err_txt)
    kwarg_dict[key] = value

# Check the per-parameter options before touching any files. Explicit checks
# are used instead of assert so that they still run under "python -O", and so
# that an omitted option gives a readable message rather than a TypeError
# from len(None).
if args.fit_param is None or args.log_param is None:
    raise ValueError("--fit-param and --log-param must both be given")
if not (len(args.log_param) == len(args.fit_param)
        == len(args.smoothing_width)):
    raise ValueError("--fit-param, --log-param and --smoothing-width must "
                     "all be given the same number of values")

# Accepted spellings of the --log-param flags
_log_false = ['false', 'False', 'FALSE']
_log_true = ['true', 'True', 'TRUE']
# Reject bad flags now rather than after the parameter calculation
for slog in args.log_param:
    if slog not in _log_false and slog not in _log_true:
        raise ValueError("invalid log param argument, use 'true', or 'false'")

# The template bank is always opened below, so it has to be supplied
if args.bank_file is None:
    raise ValueError("--bank-file must be given")

pycbc.init_logging(args.verbose)

fits = h5py.File(args.template_fit_file, 'r')

# get the ifo from the template-level fit
ifo = fits.attrs['ifo']

# get template id and template parameter values
tid = fits['template_id'][:]

logging.info('Calculating template parameter values')
bank = h5py.File(args.bank_file, 'r')
m1, m2, s1z, s2z = triggers.get_mass_spin(bank, tid)

# One array per fit parameter, holding its value (or the log of its value)
# for every template in the fit file
parvals = []

for param, slog in zip(args.fit_param, args.log_param):
    data = triggers.get_param(param, args, m1, m2, s1z, s2z)
    if slog in _log_false:
        logging.info('Using param: %s', param)
        parvals.append(data)
    elif slog in _log_true:
        logging.info('Using log param: %s', param)
        parvals.append(numpy.log(data))
    else:
        raise ValueError("invalid log param argument, use 'true', or 'false'")

nabove = fits['count_above_thresh'][:]
ntotal = fits['count_in_template'][:]
# For an exponential fit 1/alpha is linear in the trigger statistic values
# so calculating weighted sums or averages of 1/alpha is appropriate
invalpha = 1. / fits['fit_coeff'][:]
invalphan = invalpha * nabove

# Smoothed values, one entry per template, in the same order as the input
nabove_smoothed = []
ntotal_smoothed = []
alpha_smoothed = []
rang = numpy.arange(0, len(nabove))

logging.info("Smoothing ...")

# Handle the one-dimensional case of tophat smoothing separately
# as it is easier to optimize computational performance.
if len(parvals) == 1 and args.smoothing_method == 'smooth_tophat':
    logging.info("Using efficient 1D tophat smoothing")
    width = args.smoothing_width[0]
    sort = parvals[0].argsort()
    parvals_0 = parvals[0][sort]

    # For each template (in the original order) find the half-open range
    # [left, right) of positions in the SORTED parameter array that lie
    # strictly within one smoothing width of it. This is the same selection
    # as the "dists < 1" test in smooth_tophat, and for a positive width it
    # always contains the template itself.
    left = numpy.searchsorted(parvals_0, parvals[0] - width, side='right')
    right = numpy.searchsorted(parvals_0, parvals[0] + width, side='left')

    del parvals_0
    # Precompute cumulative sums over the templates in sorted order, so that
    # they line up with left and right. A leading zero makes
    # csum[right] - csum[left] the sum over exactly the positions
    # left .. right - 1, including the first one in the window.
    ntsum = numpy.concatenate(([0], ntotal[sort].cumsum()))
    nasum = numpy.concatenate(([0], nabove[sort].cumsum()))
    invsum = numpy.concatenate(([0.], invalphan[sort].cumsum()))
    # Number of templates in each window, as float so that the averages
    # below never use integer division
    num = (right - left).astype(float)

    ntotal_smoothed = (ntsum[right] - ntsum[left]) / num
    nabove_smoothed = (nasum[right] - nasum[left]) / num
    invmean = (invsum[right] - invsum[left]) / num
    alpha_smoothed = nabove_smoothed / invmean

else:
    # General case: compute the distance from each template to all others
    # and smooth with the chosen method
    for i in range(len(nabove)):
        d = dist(i, rang, parvals, args.smoothing_width)
        smoothed_tuple = smooth(nabove, invalphan, ntotal, d,
                                args.smoothing_method, **kwarg_dict)
        nabove_smoothed.append(smoothed_tuple[0])
        alpha_smoothed.append(smoothed_tuple[1])
        ntotal_smoothed.append(smoothed_tuple[2])

logging.info("Writing output")
# The context manager closes the output file, flushing it to disk, even if
# one of the writes below fails
with h5py.File(args.output, 'w') as outfile:
    outfile['template_id'] = tid
    outfile['count_above_thresh'] = nabove_smoothed
    outfile['fit_coeff'] = alpha_smoothed
    outfile['count_in_template'] = ntotal_smoothed
    # median_sigma is copied across only if the input file has it
    try:
        median_sigma = fits['median_sigma'][:]
    except KeyError:
        logging.info('Median_sigma dataset not present in input file')
    else:
        outfile['median_sigma'] = median_sigma

    # Store the parameter values used for the smoothing; the log is undone
    # for parameters that were smoothed in log space
    for param, vals, slog in zip(args.fit_param, parvals, args.log_param):
        if slog in ['false', 'False', 'FALSE']:
            outfile[param] = vals
        elif slog in ['true', 'True', 'TRUE']:
            outfile[param] = numpy.exp(vals)

    if args.output_fits_by_template:
        # Keep a copy of the unsmoothed, per-template input values
        outfile.create_group('fit_by_template')
        for k in ['count_above_thresh', 'fit_coeff', 'count_in_template']:
            outfile['fit_by_template'][k] = fits[k][:]

    # Add metadata, some is inherited from template level fit
    outfile.attrs['ifo'] = ifo
    outfile.attrs['stat_threshold'] = fits.attrs['stat_threshold']
    if 'analysis_time' in fits.attrs:
        outfile.attrs['analysis_time'] = fits.attrs['analysis_time']

    # Add a magic file attribute so that coinc_findtrigs can parse it
    outfile.attrs['stat'] = ifo + '-fit_coeffs'

# Nothing reads the input files after this point, so release them
fits.close()
bank.close()
logging.info('Done!')
