#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

import numpy as np
import pycbc
from pycbc import bin_utils
from pycbc.events import triggers, trigger_fits as trstats
from pycbc.io import DictArray
from pycbc.events import ranking
import argparse, logging, operator, os, sys, re, h5py
import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt

# Supported comparison strings and the operator-module function each one names
_INEQUALITY_FUNCTIONS = {
    "==": operator.eq,
    "<=": operator.le,
    "<": operator.lt,
    ">=": operator.ge,
    ">": operator.gt,
    "!=": operator.ne,
}


def inequality_string_to_function(inequ):
    """Return the comparison function named by an inequality string.

    Parameters
    ----------
    inequ : str, or one-element numpy array of str
        One of '==', '<=', '<', '>=', '>' or '!='. A one-element array is
        unwrapped first, because indexing an array of inequality strings
        with an index array (as this script does) yields such an array.

    Returns
    -------
    callable
        The matching two-argument function from the ``operator`` module,
        e.g. ``operator.ge`` for '>='.

    Raises
    ------
    ValueError
        If the string is not one of the supported inequalities, or if an
        array holding more than one element is given.
    """
    if isinstance(inequ, np.ndarray):
        # .item() raises ValueError unless the array holds exactly one element
        inequ = inequ.item()
    if inequ not in _INEQUALITY_FUNCTIONS:
        raise ValueError(
            "Unrecognised inequality {!r}, choose from {}".format(
                inequ, ' '.join(sorted(_INEQUALITY_FUNCTIONS))))
    return _INEQUALITY_FUNCTIONS[inequ]

# Command line interface. Each --cut-* option takes a list; the lists are
# matched up element by element, so the i-th parameter is compared to the
# i-th threshold using the i-th inequality.
parser = argparse.ArgumentParser(usage="",
    description="Plot histograms of triggers split over various parameters")
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Which ifo are we fitting the triggers for? "
                         "Required")
parser.add_argument("--top-directory", metavar='PATH', required=True,
                    help="Directory containing trigger files, top directory, "
                         "will contain subdirectories for each day of data. "
                         "Required.")
parser.add_argument("--analysis-date", required=True,
                    help="Date of the analysis, format YYYY_MM_DD. Required")
parser.add_argument("--file-identifier", default="H1L1V1-Live",
                    help="String required in filename to be considered for "
                         "analysis. Default: 'H1L1V1-Live'.")
parser.add_argument("--fit-function", default="exponential",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Choose from exponential, rayleigh or power. "
                         "Default: exponential")
# The help text previously advertised new_snr as the default, which did not
# match the actual default of snr
parser.add_argument("--fit-param", default="snr",
                    choices=["new_snr", "snr"],
                    help="Which parameter to use for fitting. "
                         "Choose from new_snr or snr. "
                         "Default: snr")
parser.add_argument("--cut-params", nargs="+", default=None,
                    help="Parameters for cuts. Default = None.")
parser.add_argument("--cut-thresholds", nargs="+", default=None, type=float,
                    help="Thresholds for cuts. Default = None.")
parser.add_argument("--cut-inequalities", nargs="+", default=None,
                    help='Operators for cuts, usage: keep triggers where '
                         'param (op) thresh, '
                         'choose from > >= < <= == !=. Default: None',
                         choices='> >= < <= == !='.split(' '))
parser.add_argument("--duration-bin-start", type=int, default=5,
                    help="Shortest duration to use for duration bins. "
                         "Default = 5")
parser.add_argument("--duration-bin-end", type=int, default=152,
                    help="Longest duration to use for duration bins. "
                         "Default = 152 (current longest in template bank)")
parser.add_argument("--num-duration-bins", type=int, default=6,
                    help="How many template duration bins to split the bank into "
                         " before fitting. Default = 6")
parser.add_argument("--fit-threshold", type=float, default=5,
                    help="Lower threshold used in fitting the triggers. "
                         "Default 5.")
parser.add_argument("--duration-bin-spacing", default='log',
                    choices=['linear','log'],
                    help="How to set spacing for bank split "
                         " before fitting. Default = log")
parser.add_argument("--cluster", default=None,
                    help="which parameter to use maximum when clustering. default=None")
parser.add_argument("--output-directory", required=True,
                    help="Directory in which to save the output file.")

args = parser.parse_args()

# Input sanitisation: the cut options are zipped together later on, which
# would silently ignore any surplus values, so insist that all three were
# given the same number of values (an option that was not given counts as 0)
cut_option_lengths = [0 if opt is None else len(opt)
                      for opt in (args.cut_params, args.cut_inequalities,
                                  args.cut_thresholds)]
if len(set(cut_option_lengths)) != 1:
    parser.error("--cut-params, --cut-inequalities and --cut-thresholds "
                 "must all be given the same number of values")

pycbc.init_logging(args.verbose)

files = [f for f in os.listdir(os.path.join(args.top_directory, args.analysis_date))
         if args.file_identifier in f]

logging.info("{} files found".format(len(files)))

# The cut options default to None when not given on the command line;
# turn them into lists so that the duration cuts can always be added
args.cut_params = list(args.cut_params or [])
args.cut_inequalities = list(args.cut_inequalities or [])
args.cut_thresholds = list(args.cut_thresholds or [])

# Only keep triggers whose template duration lies between the edges of
# the duration bins
args.cut_params += ['template_duration', 'template_duration']
args.cut_inequalities += ['<', '>']
args.cut_thresholds += [args.duration_bin_end, args.duration_bin_start]

cut_inequalities = np.array(args.cut_inequalities)
cut_params = np.array(args.cut_params)
cut_thresholds = np.array(args.cut_thresholds, dtype=float)

# Choose the bin class matching the requested spacing; argparse restricts
# --duration-bin-spacing to these two values
bincreator = {'log': bin_utils.LogarithmicBins,
              'linear': bin_utils.LinearBins}[args.duration_bin_spacing]

tbins = bincreator(args.duration_bin_start,
                   args.duration_bin_end,
                   args.num_duration_bins)


# newsnr needs special treatment as they are not
# in the saved files. Others are either in the saved files or can be
# calculated in the get_param function

newsnr_cuts = cut_params == 'new_snr'
cutting_newsnr = bool(newsnr_cuts.any())
if cutting_newsnr:
    # Remember the requested new_snr cuts; they are applied in full once
    # new_snr has been calculated
    nsineq = cut_inequalities[newsnr_cuts]
    nsthresh = cut_thresholds[newsnr_cuts]
    # To save on chisq & newsnr calculation costs, and memory use, cut on
    # snr which, by definition, is greater than or equal to newsnr
    cut_params[newsnr_cuts] = 'snr'
    # That shortcut only holds for lower bounds: a trigger failing
    # snr > X (or >= X) must also fail the same cut on new_snr, but a
    # trigger failing e.g. snr < X can still have new_snr < X. So drop the
    # snr stand-in for any new_snr cut which is not a lower bound.
    usable = ~newsnr_cuts | np.isin(cut_inequalities, ['>', '>='])
    cut_params = cut_params[usable]
    cut_inequalities = cut_inequalities[usable]
    cut_thresholds = cut_thresholds[usable]

cut_params = list(cut_params)
cut_inequalities = list(cut_inequalities)
cut_thresholds = list(cut_thresholds)

# The following datasets are required for calculation of cut parameters
required_datasets = ['snr', 'chisq', 'chisq_dof', 'template_duration',
                     'mass1', 'mass2', 'spin1z', 'spin2z', 'template_id']
required_datasets = list(set(required_datasets + cut_params))

# Make some strings for logging /saving setting for attributes
# ineq_str describes the cuts applied while reading the files
ineq_str = ', '.join('{} {} {}'.format(p, i, v) for p, i, v in
                     zip(cut_params, cut_inequalities, cut_thresholds))

# ineq_str_all describes the cuts as requested, including any on new_snr
ineq_str_all = ', '.join('{} {} {}'.format(p, i, v) for p, i, v in
                         zip(args.cut_params, args.cut_inequalities,
                             args.cut_thresholds))
logging.info('Criteria set up as ' + ineq_str_all)

# Also calculate live time so that this fitting can be used in rate estimation
# Live time is not immediately obvious - get an approximation with 8 second
# granularity by adding 8 seconds per 'valid' file

live_time = {ifo: 0 for ifo in args.ifos}

logging.info("Getting events which meet criteria: %s", ineq_str)

# Loop through files - add events which meet the immediately gettable
# criteria
# Events are stored as a DictArray - nice as it has the .select() method
date_directory = os.path.join(args.top_directory, args.analysis_date)
events = {i: DictArray(data={rp: np.array([])
                       for rp in required_datasets}) for i in args.ifos}
files = [f for f in os.listdir(date_directory)
         if args.file_identifier in f]
for counter, filename in enumerate(files, 1):
    if counter % 1000 == 0:
        logging.info("Processed %d files", counter)
        for ifo in args.ifos:
            logging.info("{}: {} triggers in {}s".format(ifo,
                events[ifo].data['snr'].size, live_time[ifo]))
    f = os.path.join(date_directory, filename)
    # If there is an IOError opening the file, don't fail, just carry on.
    # The file is opened once only, and closed by the with statement below
    try:
        fin = h5py.File(f, 'r')
    except IOError:
        logging.info('IOError with file ' + f)
        continue
    with fin:
        for ifo in args.ifos:
            # Does the file have the ifo group and snr dataset?
            if not (ifo in fin and 'snr' in fin[ifo]):
                continue
            live_time[ifo] += 8
            # all the required datasets must be present - this should
            # not happen though as if snr is present then the others should be
            if any(rd not in fin[ifo] for rd in required_datasets):
                logging.info('some datasets not present in ' + f +', skipping')
                continue
            # Skip if there are no triggers
            if not fin[ifo + '/' + required_datasets[0]].size:
                continue
            # unpack this file into a dictarray
            file_array = DictArray(data={rd : np.array(fin[ifo + '/' + rd][:])
                                         for rd in required_datasets})
            # Tracked separately for every ifo, so that one ifo having no
            # surviving triggers does not cause the other ifos in the same
            # file to be dropped
            passes_cuts = True
            for cparam, cineq, cthresh in zip(cut_params, cut_inequalities,
                                              cut_thresholds):
                # calculate parameter if not already in file
                if cparam not in file_array.data:
                    file_vals = triggers.get_param(cparam, args,
                        file_array.data['mass1'], file_array.data['mass2'],
                        file_array.data['spin1z'], file_array.data['spin2z'])
                    file_array.data[cparam] = file_vals
                # Apply cuts to array for this file
                cineq_fn = inequality_string_to_function(cineq)
                idx_keep = np.nonzero(cineq_fn(file_array.data[cparam],
                                               cthresh))[0]
                # if nothing remains after cuts, skip this ifo in this file
                if not len(idx_keep):
                    passes_cuts = False
                    break
                file_array = file_array.select(idx_keep)
            if not passes_cuts:
                continue
            if args.cluster:
                if args.cluster == 'new_snr':
                    # Calculating new_snr for clustering
                    file_array.data['new_snr'] = ranking.newsnr(
                                                     file_array.data['snr'],
                                                     file_array.data['chisq'])
                # Keep only the trigger with the largest value of the
                # clustering parameter in this file
                max_idx = file_array.data[args.cluster].argmax()
                file_array = file_array.select([max_idx])
            events[ifo] = events[ifo] + file_array

logging.info("All events processed")


# calculate new_snr and cut if requested
# new_snr is needed if it is being cut on, fitted, or is missing for any ifo.
# Every ifo is checked explicitly rather than relying on a loop variable
# left over from earlier in the script.
need_newsnr = (cutting_newsnr or args.fit_param == 'new_snr' or
               any('new_snr' not in events[i].data for i in args.ifos))
if need_newsnr:
    logging.info("Calculating new_snr")
    for ifo in args.ifos:
        events[ifo].data['new_snr'] = ranking.newsnr(events[ifo].data['snr'],
                                                     events[ifo].data['chisq'])
#Note: according to Alex, the stored chisq here is already the reduced chisq

if cutting_newsnr:
    # nsineq and nsthresh hold one entry per requested new_snr cut, so
    # apply the cuts one after another
    for ns_op, ns_val in zip(np.atleast_1d(nsineq), np.atleast_1d(nsthresh)):
        ns_op, ns_val = str(ns_op), float(ns_val)
        logging.info("Keeping events with new_snr %s %.3f" % (ns_op, ns_val))
        ineq_fn = inequality_string_to_function(ns_op)
        for ifo in args.ifos:
            idx_keep = np.nonzero(ineq_fn(events[ifo].data['new_snr'],
                                          ns_val))[0]
            if len(idx_keep):
                events[ifo] = events[ifo].select(idx_keep)
            else:
                # NOTE(review): when nothing passes, the events for this ifo
                # are left as they were rather than emptied - confirm that
                # this is the intended behaviour
                logging.info("Too many cuts on the data - removed everything from " + ifo)

logging.info("Number of events which meet all criteria:")
for ifo in args.ifos:
    logging.info("{}: {} in {}s".format(ifo, len(events[ifo].data['snr']),
                                        live_time[ifo]))

# split the events into bins by template duration
logging.info('Sorting events into template duration bins')

oput_file = os.path.join(args.output_directory,
                         args.analysis_date + "-TRIGGER-FITS.hdf")

# Fit the triggers within each bin
alphas = {i: np.zeros(args.num_duration_bins, dtype=np.float32) for i in args.ifos}
counts = {i: np.zeros(args.num_duration_bins, dtype=np.float32) for i in args.ifos}

# The with statement closes the output file even if a fit or write fails
with h5py.File(oput_file, 'w') as fout:
    for ifo in args.ifos:
        # Save the triggers which went into the fits
        for key in events[ifo].data:
            fout['{}/triggers/{}'.format(ifo, key)] = events[ifo].data[key]
        # Duration bin index of each event
        event_bin = np.array([tbins[d]
                              for d in events[ifo].data['template_duration']])
        for bin_num in range(args.num_duration_bins):
            inbin = event_bin == bin_num
            counts[ifo][bin_num] = np.count_nonzero(inbin)
            # Only the first value returned by the fit is kept
            alphas[ifo][bin_num], _ = trstats.fit_above_thresh(
                                          args.fit_function,
                                          events[ifo].data[args.fit_param][inbin],
                                          args.fit_threshold)
        fout['{}/fit_coeff'.format(ifo)] = alphas[ifo]
        fout['{}/counts'.format(ifo)] = counts[ifo]
        fout[ifo].attrs['live_time'] = live_time[ifo]

    fout['bins_upper'] = tbins.upper()
    fout['bins_lower'] = tbins.lower()

    # Record the settings used, so that the file describes itself
    fout.attrs['analysis_date'] = args.analysis_date
    fout.attrs['input'] = sys.argv
    fout.attrs['cuts'] = ineq_str
    fout.attrs['fit_function'] = args.fit_function
    fout.attrs['fit_threshold'] = args.fit_threshold

logging.info("Done")
