#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

import numpy as np

from pycbc.events import triggers
import pycbc.version

# Command-line interface.  Options that are unconditionally required are
# enforced by argparse, so a missing one is reported as a usage error instead
# of surfacing as a failure further down the script.
parser = argparse.ArgumentParser(usage="",
    description="Smooth (regress) the dependence of coefficients describing "
                "single-ifo background trigger distributions on a template "
                "parameter, to suppress random noise in the resulting "
                "background model.")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--template-fit-file", required=True,
                    help="Input hdf5 file containing fit coefficients for each"
                         " individual template. Required")
# --bank-file is only conditionally required, so it cannot be enforced here
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters. Required "
                         "unless reading param from template fit file")
parser.add_argument("--output", required=True,
                    help="Location for output file containing smoothed fit "
                         "coefficients.  Required")
parser.add_argument("--use-template-fit-param", action="store_true",
                    help="Use parameter values stored in the template fit file "
                         "as template_param for smoothing.", default=False)
parser.add_argument("--fit-param", required=True,
                    help="Parameter over which to regress the background "
                         "fit coefficients. Required. Either read from "
                         "template fit file or choose from mchirp, mtotal, "
                         "tau_0, tau_3, template_duration, a frequency "
                         "cutoff in pnutils or a frequency function in "
                         "LALSimulation.")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                         "duration, if not reading from the template fit file")
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                         "values of template_duration: add to duration values"
                         " before fitting. Units seconds.")
parser.add_argument("--log-param", action='store_true',
                    help="Take the log of the fit param before smoothing.")
parser.add_argument("--regression-method", required=True,
                    choices=["nn", "tricube"],
                    help="Method of smoothing over the chosen fit param. "
                         "Required.")
# the default of -1 means "not given"; the nn method needs a positive value
parser.add_argument("--num-neighbors", type=int, default=-1,
                    help="Number of neighbors used in nn method. Try 2500, or "
                         "1/10th the total number of templates if that is "
                         "smaller.")
parser.add_argument("--smoothing-width", type=float, required=True,
                    help="Distance in the space of fit param values (or the "
                         "logs of them) to smooth over. Required. For log "
                         "template duration, try 0.2")

args = parser.parse_args()

# the nearest-neighbour method cannot run without a positive neighbour count
if args.num_neighbors < 1 and args.regression_method == 'nn':
    raise RuntimeError("Need to give a positive number of nearest neighbors!")

# --verbose lowers the logging threshold from WARN to DEBUG
log_level = logging.DEBUG if args.verbose else logging.WARN
logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')

# open the template-level fit file; it is kept open because its datasets and
# attributes are read again further down the script
fits = h5py.File(args.template_fit_file, 'r')

# get the ifo from the template-level fit
ifo = fits.attrs['ifo']

# get template id and template parameter values
tid = fits['template_id'][:]
if args.use_template_fit_param:
    logging.info('Reading %s values from the template fit file',
                 args.fit_param)
    # check we're asking for the right param name; raise explicitly because
    # an assert statement is skipped when python runs with -O
    saved_param = fits.attrs['save_trig_param']
    if not (saved_param == args.fit_param):
        raise RuntimeError("Template fit file stores values of %s, but "
                           "--fit-param is %s" % (saved_param, args.fit_param))
    parvals = fits['template_param'][:]

else:
    # without stored values the parameter must be computed from the bank
    if args.bank_file is None:
        raise RuntimeError("Need to give --bank-file unless reading the "
                           "param from the template fit file!")
    logging.info('Calculating template parameter values')
    # the bank is only needed for this calculation, so close it right after
    with h5py.File(args.bank_file, 'r') as bank:
        m1, m2, s1z, s2z = triggers.get_mass_spin(bank, tid)
        parvals = triggers.get_param(args.fit_param, args, m1, m2, s1z, s2z)

# the total trigger count per template is a recently introduced extra
# dataset, so only use it when the input file actually has it
tcount = 'count_in_template' in fits.keys()

# trigger counts above threshold, rescaled by their mean
nabove = fits['count_above_thresh'][:]
nabove = nabove / np.mean(nabove)
if tcount:
    ntotal = fits['count_in_template'][:]
# for an exponential fit 1/alpha is linear in the trigger statistic values
# so taking weighted sums/averages of 1/alpha is appropriate
invalpha = 1. / fits['fit_coeff'][:]

# reorder every per-template array by ascending parameter value
order = np.argsort(parvals)
tid, parvals, nabove, invalpha = (
    arr[order] for arr in (tid, parvals, nabove, invalpha))
if tcount:
    ntotal = ntotal[order]

# optionally smooth over the log of the parameter rather than its raw value
if not args.log_param:
    logging.info('Using %s to perform smoothing' % args.fit_param)
else:
    logging.info('Using log %s to perform smoothing' % args.fit_param)
    parvals = np.log(parvals)

# Smooth the per-template fit quantities over the sorted parameter values
# with the requested method.  Either branch defines nabove_smoothed and
# invalpha_smoothed, plus ntotal_smoothed when the input has total counts.
if args.regression_method == 'nn':
    # do nearest-neighbours regression
    # use Gaussian weight over fitting parameter
    # only import scikit-learn if and when needed
    from sklearn import neighbors
    weights = lambda d:np.exp(-0.5 * (d/args.smoothing_width)**2)
    knn = neighbors.KNeighborsRegressor(args.num_neighbors, weights=weights)

    # NOTE(review): each predict call below is given a single query point and
    # presumably returns a length-1 array, so the nn outputs would have shape
    # (N, 1) while the tricube outputs are 1-d -- confirm downstream readers
    # cope with both.
    logging.info('Smoothing nabove data')
    nabove_knn = knn.fit(parvals[:, np.newaxis], nabove)
    logging.info('Evaluating smoothed nabove')
    nabove_smoothed = [nabove_knn.predict([[p]]) for p in parvals]
    del nabove_knn

    if tcount:
        logging.info('Smoothing ntotal data')
        ntotal_knn = knn.fit(parvals[:, np.newaxis], ntotal)
        logging.info('Evaluating smoothed ntotal')
        ntotal_smoothed = [ntotal_knn.predict([[p]]) for p in parvals]
        del ntotal_knn

    logging.info('Smoothing invalpha data')
    # smooth the inverse alpha values times trig count above threshold
    invalphan = invalpha * nabove
    invalphan_knn = knn.fit(parvals[:, np.newaxis], invalphan)
    logging.info('Evaluating smoothed invalpha')
    # loop over parameter values to avoid memory error in knn predict
    invalphan_smoothed = [invalphan_knn.predict([[p]]) for p in parvals]
    del invalphan_knn
    # divide out by smoothed trig count
    invalpha_smoothed = np.array(invalphan_smoothed) / np.array(nabove_smoothed)

elif args.regression_method == 'tricube':
    # do Nadaraya-Watson kernel regression
    # i.e. weighted average with parameter-dependent weights
    invalphan = invalpha * nabove
    invalpha_smoothed = []
    nabove_smoothed = []
    if tcount:
        ntotal_smoothed = []
    logging.info('Evaluating smoothed invalpha and n_above')
    for i, p in enumerate(parvals):
        if not i % 100:
            logging.info('Smoothing template %i', i)
        # find parameter values that go into average
        par_diff = parvals - p
        in_kernel = abs(par_diff) < args.smoothing_width
        # tri-cube kernel (1 - u^3)^3, with u the distance measured in units
        # of the smoothing width so the weight falls from 1 at the template
        # itself towards 0 at the edge of the window and is never negative
        scaled_diff = abs(par_diff[in_kernel]) / args.smoothing_width
        weights = (1 - scaled_diff**3)**3
        norm = weights.sum()
        # weighted averages; inverse alpha is averaged weighted by trigger
        # count and the smoothed count is then divided back out
        n_sm = (nabove[in_kernel] * weights).sum() / norm
        invalphan_sm = (invalphan[in_kernel] * weights).sum() / norm
        invalpha_smoothed.append(invalphan_sm / n_sm)
        nabove_smoothed.append(n_sm)
        # total counts exist only in newer fit files, so guard both the
        # average and the append
        if tcount:
            nt_sm = (ntotal[in_kernel] * weights).sum() / norm
            ntotal_smoothed.append(nt_sm)

# store template-dependent fit output; the context manager closes the output
# file even if one of the writes below raises
with h5py.File(args.output, 'w') as outfile:
    outfile.create_dataset('template_id', data=tid)
    outfile.create_dataset('count_above_thresh', data=nabove_smoothed)
    if tcount:
        outfile.create_dataset('count_in_template', data=ntotal_smoothed)
    # the smoothing was done on 1/alpha, so invert back to alpha
    outfile.create_dataset('fit_coeff', data=1. / np.array(invalpha_smoothed))
    # undo the log so the stored parameter is always the raw value
    outfile.create_dataset('param_val', data=(np.exp(parvals)
        if args.log_param else parvals))

    # add metadata, some is inherited from template level fit
    outfile.attrs.create('ifo', data=ifo)
    outfile.attrs.create('stat_threshold', data=fits.attrs['stat_threshold'])
    outfile.attrs.create('fit_param', data=args.fit_param)
    outfile.attrs.create('regression_method', data=args.regression_method)
    if args.num_neighbors > 0:
        outfile.attrs.create('n_neighbors', data=args.num_neighbors)
    outfile.attrs.create('smoothing_width', data=args.smoothing_width)
    if 'analysis_time' in fits.attrs:
        outfile.attrs['analysis_time'] = fits.attrs['analysis_time']

    # add a magic file attribute so that coinc_findtrigs can parse it
    outfile.attrs.create('stat', data=ifo+'-fit_coeffs')

# nothing more is read from the template-level fit file
fits.close()
logging.info('Done!')
