#!/usr/bin/env python

# Copyright 2023 Jam Sadiq, Praveen Kumar
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import numpy, h5py, operator, argparse, logging
from pycbc import init_logging
import pycbc.conversions as convert
from pycbc import libutils
from pycbc.events import triggers
akde = libutils.import_optional('awkde')
kf = libutils.import_optional('sklearn.model_selection')


# Command line interface.  Exactly one of --make-signal-kde and
# --make-template-kde is expected; the KDE is always evaluated at the
# template points and written to --output-file
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--signal-file', help='File with parameters of GW signals '
                    'for KDE calculation')
parser.add_argument('--template-file', required=True, help='Hdf5 file with '
                    'template masses and spins')
parser.add_argument('--min-mass', type=float, default=None,
                    help='Used only on signal masses: remove all '
                         'signal events with mass2 < min_mass')
parser.add_argument('--nfold-signal', type=int,
                    help='Number of k-folds for signal KDE cross validation')
parser.add_argument('--nfold-template', type=int,
                    help='Number of k-folds for template KDE cross validation')
parser.add_argument('--fit-param', nargs='+', required=True,
                    help='Parameters over which KDE is calculated')
parser.add_argument('--log-param', nargs='+', choices=['True', 'False'],
                    required=True,
                    help="One entry per --fit-param value: 'True' to use the "
                         "log of that parameter in the KDE, 'False' to use "
                         "the parameter itself")
parser.add_argument('--output-file', required=True, help='Name of .hdf output')
parser.add_argument('--make-signal-kde', action='store_true',
                    help='Fit the KDE to the events in --signal-file and '
                         'evaluate it at the template points')
parser.add_argument('--make-template-kde', action='store_true',
                    help='Fit the KDE to the template parameters and '
                         'evaluate it at the template points')
parser.add_argument('--fom-plot', help='Make a FOM plot for cross-validation'
                    ' and save it as this file')
parser.add_argument('--alpha-grid', type=float, nargs="+",
                    help='Grid of choices of sensitivity parameter alpha for'
                         ' local bandwidth')
parser.add_argument('--bw-grid', type=float, nargs='+',
                    help='Grid of choices of global bandwidth')
parser.add_argument('--verbose', action='count')
args = parser.parse_args()
init_logging(args.verbose)

# Validate with parser.error rather than assert: asserts vanish under
# `python -O`, and parser.error prints the usage message before exiting
if len(args.fit_param) != len(args.log_param):
    parser.error("--fit-param and --log-param must have the same number "
                 "of entries")
if args.make_signal_kde and args.make_template_kde:
    parser.error("Choose only one option out of --make-signal-kde and "
                 "--make-template-kde")
# The options below are only read when a KDE is made; without them the
# script would otherwise fail later with a traceback
if args.make_signal_kde and not args.signal_file:
    parser.error("--make-signal-kde requires --signal-file")
if (args.make_signal_kde or args.make_template_kde) and \
        (args.alpha_grid is None or args.bw_grid is None):
    parser.error("--alpha-grid and --bw-grid are required to make a KDE")


def kde_awkde(x, x_grid, alp=0.5, gl_bandwidth=None, ret_kde=False):
    """
    Fit an awkde GaussianKDE to the samples x and evaluate it at x_grid

    Parameters
    ----------
    x : array
        Samples the KDE is fitted to (passed to GaussianKDE.fit)
    x_grid : array
        Points at which the fitted KDE is evaluated (passed to
        GaussianKDE.predict)
    alp : float
        Value given to the KDE as its `alpha` argument
    gl_bandwidth : float or None
        Value given to the KDE as its `glob_bw` argument.  If None the
        argument is left out so that awkde applies its own default
    ret_kde : bool
        If true, also return the fitted KDE object

    Returns
    -------
    y, or (kde, y) if ret_kde is true, where y is the output of
    kde.predict(x_grid)
    """
    kde_args = {'alpha': alp, 'diag_cov': True}
    if gl_bandwidth is not None:
        kde_args['glob_bw'] = gl_bandwidth
    kde = akde.GaussianKDE(**kde_args)

    kde.fit(x)
    y = kde.predict(x_grid)

    if ret_kde:
        return kde, y
    return y


def kfcv_awkde(sample, bwchoice, alphachoice, k=2):
    """
    K-fold cross validated log likelihood of an awKDE with the given global
    bandwidth (bwchoice) and sensitivity parameter (alphachoice)

    The sample is split into k shuffled folds.  For each split a KDE is
    fitted to the training part and evaluated at the held-out part; the
    figure of merit is the sum over all splits of the summed log of those
    held-out KDE values.  The shuffle is not seeded (random_state=None).
    """
    splitter = kf.KFold(n_splits=k, shuffle=True, random_state=None)
    # One entry per split: log likelihood of the held-out points under the
    # KDE fitted to the remaining points
    fold_loglikes = [
        numpy.sum(numpy.log(kde_awkde(sample[fit_idx], sample[held_idx],
                                      alp=alphachoice,
                                      gl_bandwidth=bwchoice)))
        for fit_idx, held_idx in splitter.split(sample)
    ]
    return numpy.sum(fold_loglikes)


def _plot_fom(fom, bwgrid, alphagrid, optbw, optalpha, npoints, plot_file):
    """
    Plot the cross-validation figure of merit against alpha, with one curve
    per global bandwidth and a marker at the optimum, and save to plot_file
    """
    import matplotlib.pyplot as plt
    max_fom = fom[(optbw, optalpha)]
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111)
    for bw in bwgrid:
        ax.plot(alphagrid, [fom[(bw, al)] for al in alphagrid],
                label='{0:.3f}'.format(bw))
    ax.plot(optalpha, max_fom, 'ko', linewidth=10,
            label=r'$\alpha={0:.3f},bw={1:.3f}$'.format(optalpha, optbw))
    ax.set_xlabel(r'$\alpha$', fontsize=15)
    ax.set_ylabel(r'$FOM$', fontsize=15)
    # Guess at a suitable range of FOM values to plot
    ax.set_ylim(max_fom - 0.5 * npoints, max_fom + 0.2 * npoints)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.135), ncol=8)
    fig.savefig(plot_file)
    # Close this figure explicitly rather than whichever one is current
    plt.close(fig)


def optimizedparam(sampleval, bwgrid, alphagrid, nfold=2):
    """
    Grid search for the global bandwidth and alpha maximizing the K-fold
    cross validated log likelihood of the KDE

    Parameters
    ----------
    sampleval : array
        Samples to fit; the first axis indexes the points
    bwgrid : iterable of float
        Choices of global bandwidth
    alphagrid : iterable of float
        Choices of the sensitivity parameter alpha
    nfold : int
        Number of folds used in the cross validation

    Returns
    -------
    optbw, optalpha : the grid values with the largest figure of merit.  On
        a tie the pair evaluated first (bandwidth outer loop, alpha inner
        loop) is returned

    Raises
    ------
    ValueError
        If no grid point yields a figure of merit that is not NaN

    If the --fom-plot option was given, the figure of merit over the whole
    grid is also plotted to that file.
    """
    npoints = sampleval.shape[0]
    FOM = {}
    for gbw in bwgrid:
        for alphaval in alphagrid:
            FOM[(gbw, alphaval)] = kfcv_awkde(sampleval, gbw, alphaval,
                                              k=nfold)

    # Comparisons with NaN are always false, so max() over values containing
    # NaN depends on the grid order; leave such grid points out of the search
    usable = {pair: val for pair, val in FOM.items() if not numpy.isnan(val)}
    if not usable:
        raise ValueError('No (bandwidth, alpha) grid point gave a valid '
                         'figure of merit')
    optbw, optalpha = max(usable, key=usable.get)

    # Plotting FOM parameters
    if args.fom_plot:
        _plot_fom(FOM, bwgrid, alphagrid, optbw, optalpha, npoints,
                  args.fom_plot)

    return optbw, optalpha


# Exactly one of the two KDEs has to be requested.  Check this before the
# output file is opened so that a bad invocation cannot leave a partially
# written file behind
if args.make_signal_kde == args.make_template_kde:
    parser.error('Choose exactly one option out of --make-signal-kde and '
                 '--make-template-kde')


def _stack_fit_params(param_values):
    """
    Build the sample array given to the KDE, with one column per fit param

    param_values holds one 1-d array per entry of --fit-param, in the same
    order.  An array is replaced by its natural log where the matching
    --log-param entry is 'True'.
    """
    columns = []
    for param, slog, pvals in zip(args.fit_param, args.log_param,
                                  param_values):
        if slog == 'False':
            logging.info('Using param: %s', param)
            columns.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            columns.append(numpy.log(pvals))
        else:
            raise ValueError("invalid log param argument, use 'True', "
                             "or 'False'")
    return numpy.vstack(columns).T


def _read_signal_samples():
    """
    Read the fit params of the signal events from --signal-file

    The file is read as comma separated values with a header row naming the
    columns; it must contain mass1, mass2 and every --fit-param.  If
    --min-mass is given only events with mass2 > min_mass are kept.

    Raises ValueError if no events are left or if any kept event does not
    have mass1 > mass2.
    """
    signal_file = numpy.genfromtxt(args.signal_file, dtype=float,
                                   delimiter=',', names=True)
    mass2_sgnl = signal_file['mass2']
    n_original = len(mass2_sgnl)
    # Compare with None so that a threshold of 0 is still applied
    if args.min_mass is not None:
        idx = mass2_sgnl > args.min_mass
        mass2_sgnl = mass2_sgnl[idx]
        logging.info('%i triggers out of %i with MASS2 > %s',
                     len(mass2_sgnl), n_original, args.min_mass)
    else:
        idx = numpy.full(n_original, True)
    mass1_sgnl = signal_file['mass1'][idx]

    # Explicit checks instead of assert, which is removed under `python -O`
    if len(mass2_sgnl) == 0:
        raise ValueError('No signal events left to fit the KDE to')
    if not numpy.all(mass1_sgnl > mass2_sgnl):
        raise ValueError('Signal events must all have mass1 > mass2')

    return _stack_fit_params([signal_file[param][idx]
                              for param in args.fit_param])


# Both files are handled by context managers so that they are closed even
# if a later step fails; the template file is opened only once
with h5py.File(args.template_file, 'r') as temp_file:
    # Obtaining template parameters
    tid = numpy.arange(len(temp_file['mass1']))  # Array of template ids
    mass_spin = triggers.get_mass_spin(temp_file, tid)

    with h5py.File(args.output_file, 'w') as f_dest:
        f_dest.create_dataset("template_id", data=tid)
        template_pars = []
        for param in args.fit_param:
            pvals = triggers.get_param(param, args, *mass_spin)
            # Write the KDE param values (before any log) to output file
            f_dest.create_dataset(param, data=pvals)
            template_pars.append(pvals)
        temp_samples = _stack_fit_params(template_pars)

        # Copy standard data to output file
        f_dest.attrs['fit_param'] = args.fit_param
        f_dest.attrs['log_param'] = args.log_param
        temp_file.copy(temp_file["./"], f_dest["./"],
                       "input_template_params")

        # Choose the samples the KDE is fitted to; in both cases the KDE
        # is evaluated at the template points
        if args.make_template_kde:
            kde_label, nfold = 'template', args.nfold_template
            fit_samples = temp_samples
        else:
            # Obtaining signal parameters
            kde_label, nfold = 'signal', args.nfold_signal
            fit_samples = _read_signal_samples()

        logging.info('Starting optimization of %s KDE parameters',
                     kde_label)
        optbw, optalpha = optimizedparam(fit_samples, bwgrid=args.bw_grid,
                                         alphagrid=args.alpha_grid,
                                         nfold=nfold)
        logging.info('Bandwidth %.4f, alpha %.2f', optbw, optalpha)
        logging.info('Evaluating %s KDE', kde_label)
        data_kde = kde_awkde(fit_samples, temp_samples, alp=optalpha,
                             gl_bandwidth=optbw)
        f_dest.create_dataset("data_kde", data=data_kde)
        f_dest.attrs['stat'] = kde_label + "-kde_file"
        f_dest.attrs['alpha'] = optalpha
        f_dest.attrs['bandwidth'] = optbw

logging.info('Done!')
