#!/usr/bin/env python

import numpy, h5py, operator, argparse, logging
from pycbc import init_logging
import pycbc.conversions as convert
from pycbc import libutils
akde = libutils.import_optional('awkde')
kf = libutils.import_optional('sklearn.model_selection')

# Command-line interface.  The module has no docstring of its own, so fall
# back to an explicit description for the --help output.
parser = argparse.ArgumentParser(
    description=__doc__ or (
        'Evaluate an awkde Gaussian KDE over (log mchirp, eta, chi_eff), '
        'fitted either to the template bank or to a set of signal samples, '
        'at the template points and write the result to an HDF5 file.'))
parser.add_argument('--signal-file',
                    help='Comma-separated text file with a header row naming '
                         'the columns mass1, mass2, eta, chi_eff and mchirp. '
                         'Needed with --make-signal-kde.')
parser.add_argument('--template-file', required=True,
                    help='HDF5 file containing the mass1, mass2, spin1z and '
                         'spin2z datasets of the templates.')
# Optional: the signal branch already copes with no cut being requested, and
# the template branch never reads this value.
parser.add_argument('--min-mass', type=float, default=None,
                    help='Keep only signal samples with mass2 above this '
                         'value. Only used with --make-signal-kde.')
parser.add_argument('--output-file', required=True,
                    help='HDF5 file the KDE values and parameters are '
                         'written to.')
parser.add_argument('--make-signal-kde', action='store_true',
                    help='Fit the KDE to the signal samples and evaluate it '
                         'at the template points.')
parser.add_argument('--make-template-kde', action='store_true',
                    help='Fit the KDE to the templates and evaluate it at '
                         'the template points.')
parser.add_argument('--verbose', action='count',
                    help='Verbosity level passed to pycbc.init_logging; '
                         'may be repeated.')
args = parser.parse_args()
# Fail with a usage message here rather than with an obscure error later,
# when the signal file is actually read.
if args.make_signal_kde and args.signal_file is None:
    parser.error('--make-signal-kde requires --signal-file')
init_logging(args.verbose)

def kde_awkde(x, x_grid, alp=0.5, gl_bandwidth='silverman', ret_kde=False):
    """Fit an awkde ``GaussianKDE`` to ``x`` and evaluate it at ``x_grid``.

    Parameters
    ----------
    x : array-like
        Samples the KDE is fitted to (passed to ``GaussianKDE.fit``).
    x_grid : array-like
        Points at which the fitted KDE is evaluated (passed to
        ``GaussianKDE.predict``).
    alp : float, optional
        Forwarded as the ``alpha`` argument of ``GaussianKDE``.
    gl_bandwidth : str or float, optional
        Forwarded as the ``glob_bw`` argument of ``GaussianKDE``.
    ret_kde : bool, optional
        If true, return the fitted KDE object together with the values.

    Returns
    -------
    y, or (kde, y) when ``ret_kde`` is set
        ``y`` is whatever ``GaussianKDE.predict`` returns for ``x_grid``.
    """
    kde = akde.GaussianKDE(glob_bw=gl_bandwidth, alpha=alp, diag_cov=True)
    kde.fit(x)
    # The evaluation is the same call whatever container type x_grid is, so
    # no type dispatch is needed here.
    y = kde.predict(x_grid)
    if ret_kde:
        return kde, y
    return y

def kfcv_awkde(sample, bwchoice, alphachoice, k=2):
    """Return the mean k-fold cross-validation figure of merit of a KDE.

    ``sample`` is split into ``k`` shuffled folds.  For every split the KDE
    is fitted to the training part with ``kde_awkde`` and evaluated at the
    held-out part; the figure of merit of that split is the sum of the log
    of the evaluated values.  The mean over all splits is returned.

    ``bwchoice`` is used as given when it is 'silverman' or 'scott' and is
    converted to ``float`` otherwise.  ``alphachoice`` is forwarded as
    ``alp``.  The shuffle is not seeded, so repeated calls may use different
    splits.
    """
    named_rules = ('silverman', 'scott')
    global_bw = bwchoice if bwchoice in named_rules else float(bwchoice)
    splitter = kf.KFold(n_splits=k, shuffle=True, random_state=None)
    fold_scores = [
        numpy.sum(numpy.log(kde_awkde(sample[fit_idx], sample[held_idx],
                                      alp=alphachoice,
                                      gl_bandwidth=global_bw)))
        for fit_idx, held_idx in splitter.split(sample)
    ]
    return numpy.mean(fold_scores)

def optimizedparam(sampleval, nfold=2):
    """Grid-search the KDE bandwidth and alpha by cross-validation.

    Every (bandwidth, alpha) pair on the grid is scored with ``kfcv_awkde``
    using ``nfold`` folds, and the pair with the largest score is returned.

    Parameters
    ----------
    sampleval : array-like
        Samples handed to ``kfcv_awkde`` for every grid point.
    nfold : int, optional
        Number of cross-validation folds.

    Returns
    -------
    optbw : str or float
        Best bandwidth choice: 'scott', 'silverman' or one of the ten
        log-spaced values between 1e-2 and 1.
    optalpha : int
        Best alpha; the grid currently holds the single value 1.
    """
    bwgrid = ['scott', 'silverman'] + numpy.logspace(-2, 0, 10).tolist()
    alphagrid = [1]
    fom = {}
    for gbw in bwgrid:
        for alphaval in alphagrid:
            fom[(gbw, alphaval)] = kfcv_awkde(sampleval, gbw, alphaval,
                                              k=nfold)
    # Key of the entry with the largest figure of merit; on ties the first
    # grid point in insertion order wins.
    optbw, optalpha = max(fom.items(), key=operator.itemgetter(1))[0]
    return optbw, optalpha

# Both modes write to args.output_file, so at most one of them may be chosen
if all((args.make_signal_kde, args.make_template_kde)):
    raise Exception("Choose only one option out of --make-signal-kde and --make-template-kde")

# Read the template parameters (needed by both modes, since the KDE is always
# evaluated at the template points) and derive chi_eff, mchirp and eta from
# them with pycbc.conversions
with h5py.File(args.template_file, 'r') as template_file:
    # [:] copies each dataset into memory, so the arrays stay valid after
    # the file is closed
    mass1_n = template_file['mass1'][:]
    mass2_n = template_file['mass2'][:]
    spin1z_n = template_file['spin1z'][:]
    spin2z_n = template_file['spin2z'][:]
chi_eff_n = convert.chi_eff(mass1_n, mass2_n, spin1z_n, spin2z_n)
mchirp_n = convert.mchirp_from_mass1_mass2(mass1_n, mass2_n)
eta_n = convert.eta_from_mass1_mass2(mass1_n, mass2_n)

# Template KDE: fit the KDE to the templates, evaluate it at those same
# points and store the values together with the template parameters
if args.make_template_kde:
    # Columns of the sample array are (log mchirp, eta, chi_eff)
    sample_n = numpy.vstack((numpy.log(mchirp_n), eta_n, chi_eff_n)).T
    grid_pts_n = sample_n
    logging.info('Starting optimization of template KDE parameters')
    optbw_n, optalpha_n = optimizedparam(sample_n, nfold=2)
    logging.info('Evaluating template KDE')
    kde_data_n = kde_awkde(sample_n, grid_pts_n, alp=optalpha_n,
                           gl_bandwidth=optbw_n)

    # The context manager closes the output file even if a write fails
    with h5py.File(args.output_file, 'w') as hdf:
        hdf.create_dataset('kde_template', data=kde_data_n)
        # Attribute naming the dataset that holds the KDE values
        hdf.attrs['bank_kde'] = 'kde_template'
        hdf.create_dataset('mass1', data=mass1_n)
        hdf.create_dataset('mass2', data=mass2_n)
        hdf.create_dataset('spin1z', data=spin1z_n)
        hdf.create_dataset('spin2z', data=spin2z_n)
        hdf.create_dataset('mchirp', data=mchirp_n)
        hdf.create_dataset('eta', data=eta_n)
        hdf.create_dataset('chi_eff', data=chi_eff_n)
        hdf.create_dataset('grid_pts', data=grid_pts_n)

# Signal KDE: fit the KDE to the signal samples, evaluate it at the template
# points and store the values together with the signal parameters
if args.make_signal_kde:
    # Comma-separated table; the header row supplies the column names
    data_signal = numpy.genfromtxt(args.signal_file, dtype=float,
                                   delimiter=',', names=True)
    mass2 = data_signal['mass2']
    N_original = len(mass2)
    if args.min_mass is not None:
        # Keep only the samples whose mass2 exceeds the requested minimum
        idx = mass2 > args.min_mass
        mass2 = mass2[idx]
        logging.info('%i triggers out of %i with MASS2 > %s',
                     len(mass2), N_original, str(args.min_mass))
    else:
        # No cut requested: select every sample
        idx = numpy.full(N_original, True)
    mass2_s = mass2
    mass1_s = data_signal['mass1'][idx]
    # Explicit checks instead of an assert, which is stripped under -O
    if mass2_s.size == 0:
        raise ValueError('No signal samples available to build the KDE')
    if not numpy.all(mass1_s > mass2_s):
        raise ValueError('Signal samples must satisfy mass1 > mass2')
    eta_s = data_signal['eta'][idx]
    chi_eff_s = data_signal['chi_eff'][idx]
    mchirp_s = data_signal['mchirp'][idx]

    # Columns of both arrays are (log mchirp, eta, chi_eff); the KDE is
    # fitted to the signal samples and evaluated at the template points
    sample_s = numpy.vstack((numpy.log(mchirp_s), eta_s, chi_eff_s)).T
    grid_pts_s = numpy.array((numpy.log(mchirp_n), eta_n, chi_eff_n)).T
    logging.info('Starting optimization of signal KDE parameters')
    optbw_s, optalpha_s = optimizedparam(sample_s, nfold=2)
    logging.info('Evaluating signal KDE')
    kde_data_s = kde_awkde(sample_s, grid_pts_s, alp=optalpha_s,
                           gl_bandwidth=optbw_s)

    # The context manager closes the output file even if a write fails
    with h5py.File(args.output_file, 'w') as hdf:
        hdf.create_dataset('kde_sig', data=kde_data_s)
        # Attribute naming the dataset that holds the KDE values
        hdf.attrs['signal_kde'] = 'kde_sig'
        hdf.create_dataset('mass1', data=mass1_s)
        hdf.create_dataset('mass2', data=mass2_s)
        hdf.create_dataset('mchirp', data=mchirp_s)
        hdf.create_dataset('eta', data=eta_s)
        hdf.create_dataset('chi_eff', data=chi_eff_s)
        hdf.create_dataset('grid_pts', data=grid_pts_s)
