#!/usr/bin/python

# Copyright (C) 2023 Veronica Villa-Ortega
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""
Plot histograms of triggers around gated times.
"""

import h5py
import numpy
import logging
import argparse
from pycbc.events import ranking
from matplotlib import use; use('Agg')
import matplotlib.pyplot as plt


# Command-line interface. Options without an explicit default are None when
# omitted; the code below the parser checks them as needed.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--verbose', action='store_true',
                    help='Log progress messages at INFO level.')
# Required: the script iterates over this list, so leaving it unset would
# otherwise fail later with an unhelpful TypeError on None.
parser.add_argument('--single-trigger-files', nargs='+', required=True,
                    help='HDF format single detector merged trigger file(s) path. '
                         'All triggers are combined in the histogram')
parser.add_argument('--ifo-tag', default='',
                    help='String to label plots, eg H1, H1L1')
parser.add_argument('--gating-type', choices=['auto', 'detchar'],
                    help='Type of gate to analyze.')
parser.add_argument('--window', default=5, type=float,
                    help='Time in seconds around the gate to plot (default 5s).')
parser.add_argument('--snr-cut-type', choices=['snr', 'newsnr'],
                    help='Type of snr threshold to apply.')
parser.add_argument('--snr-cut-vals', nargs='+', type=float,
                    help='Threshold snr/newsnr value(s) for histogram(s).')
# Values are kept as raw strings here: the first is the GT/LT keyword and the
# second is converted to float after parsing.
parser.add_argument('--template-duration', nargs='+',
                    help='Triggers with a template duration greater than '
                         '(GT) or less than (LT) some time value in seconds. '
                         'Usage: --template-duration GT/LT time value')
parser.add_argument('--bin-duration', type=float, default=0.25,
                    help='Time in seconds represented by each bin in the histogram '
                         '(default 0.25s).')
parser.add_argument('--log-y', action='store_true',
                    help='Set y-axis scale logarithmic.')
parser.add_argument('--output-file',
                    help='Plot file name (optional).')
args = parser.parse_args()

# Progress messages in this script are emitted at INFO level, so they are
# only shown with --verbose; otherwise the threshold is WARN.
log_level = logging.INFO if args.verbose else logging.WARN
logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')

# Translate the command-line gate type into the name of the matching group
# under 'gating' in the trigger files. argparse restricts --gating-type to
# the two keys below, so the lookup only sees those values.
if not args.gating_type:
    raise ValueError('A gating type argument must be provided.')
gating_type = {'auto': 'auto', 'detchar': 'file'}[args.gating_type]

# Build the list of thresholds, one histogram per entry.
if not args.snr_cut_type:
    # No cut requested: a single placeholder entry makes the per-threshold
    # loops run once; no SNR cut is applied when the cut type is unset.
    snr_cut_vals = numpy.array([0])
elif args.snr_cut_vals:
    # Ascending order, so later code can apply the cuts cumulatively.
    snr_cut_vals = sorted(args.snr_cut_vals)
else:
    raise ValueError('At least one SNR cut value must be provided!')

# Split --template-duration into its comparison keyword (args.duration_ltgt)
# and numeric cut in seconds (args.duration_cut). Both attributes are only
# set when the option was given.
if args.template_duration:
    # Validate explicitly rather than with assert: assert statements are
    # removed when Python runs with -O, and a lone token would otherwise
    # surface as a bare IndexError.
    if len(args.template_duration) < 2:
        raise ValueError('--template-duration needs two values: GT or LT '
                         'followed by a time value in seconds.')
    args.duration_ltgt = args.template_duration[0]
    if args.duration_ltgt not in ('LT', 'GT'):
        raise ValueError('--template-duration must start with GT or LT, '
                         'got %r.' % args.duration_ltgt)
    # float() raises ValueError itself if the time value is not numeric.
    args.duration_cut = float(args.template_duration[1])

# Binning along trigger time
if args.bin_duration <= 0:
    raise ValueError('--bin-duration must be a positive number of seconds.')
# nbins is the number of bin edges spanning [-window, +window]. Round the
# ratio before converting: plain int() truncates, so a ratio that lands just
# below an integer in floating point (e.g. 2 * 0.3 / 0.1 = 5.999...) would
# silently lose one bin.
nbins = int(round(2 * args.window / args.bin_duration)) + 1

# One list of gate-relative trigger times per SNR threshold, accumulated
# over all input files, plus the running total of gates found.
ifo_trigger_times = [[] for _ in snr_cut_vals]
n_gates = 0

# For every trigger file: read the gate times of the requested type, apply
# the optional template-duration and SNR cuts to the triggers, and collect
# the times of triggers within +/- window of each gate, relative to the gate.
for single_file in args.single_trigger_files:
    logging.info("Reading trigger file %s...", single_file.split('/')[-1])
    # The context manager closes the file on every path, including the early
    # `continue` taken for files without the requested gates.
    with h5py.File(single_file, "r") as file_:
        # The first top-level key of the file is used as the detector group.
        ifo = str(list(file_.keys())[0])
        data = file_[ifo]

        # Skip files with no gating group at all, or none of this type.
        if 'gating' not in data or gating_type not in data['gating']:
            logging.info('Trigger file does not contain %s gates. '
                         'Skipping file...', args.gating_type)
            continue

        logging.info("Extracting gated times...")
        gate_time = numpy.unique(data[f'gating/{gating_type}/time'][:])
        n_gates += len(gate_time)
        logging.info("%i %s gates found", len(gate_time), args.gating_type)

        logging.info("Extracting trigger times...")
        times = data['end_time'][:]

        # Boolean mask for the template-duration cut; None means no cut.
        duration_mask = None
        if args.template_duration:
            logging.info("Cutting on template duration...")
            template_duration = data['template_duration'][:]
            if args.duration_ltgt == 'LT':
                duration_mask = template_duration <= args.duration_cut
            else:
                duration_mask = template_duration >= args.duration_cut
            times = times[duration_mask]

        if args.snr_cut_type:
            logging.info("Cutting on %s...", args.snr_cut_type)
            snr_val = data['snr'][:]
            if args.snr_cut_type == 'newsnr':
                rchisq = data['chisq'][:] / (2 * data['chisq_dof'][:] - 2)
                if len(rchisq) > 0:
                    snr_val = ranking.newsnr(snr_val, rchisq)
                else:
                    # No triggers: skip the newsnr call on empty input.
                    snr_val = numpy.array([])
                del rchisq
            # Keep the statistic aligned with the duration-cut times.
            if duration_mask is not None:
                snr_val = snr_val[duration_mask]

        logging.info("Starting loop over gates...")
        for i, thr in enumerate(snr_cut_vals):
            if args.snr_cut_type:
                # Thresholds are in ascending order, so each cut is applied
                # on top of the previous one.
                above_thr = snr_val >= thr
                times = times[above_thr]
                snr_val = snr_val[above_thr]
            for gt in gate_time:
                near_gate = numpy.logical_and(times < gt + args.window,
                                              times > gt - args.window)
                ifo_trigger_times[i].extend(times[near_gate] - gt)

def _name_token(value, text):
    """Return the file-name form of a numeric option.

    If ``text`` contains a decimal point, ``value`` is formatted with one
    decimal place and the point is replaced by a dash (5.0 -> '5-0');
    otherwise ``text`` is returned unchanged.
    """
    if '.' in text:
        return "{:.1f}".format(value).replace('.', '-')
    return text


# Use the file name given on the command line, or assemble a default one
# from the options in effect.
if args.output_file:
    out_name = args.output_file
else:
    name_parts = ['%s_%s' % (str(args.gating_type).upper(), args.ifo_tag)]

    if args.template_duration:
        duration_token = _name_token(args.duration_cut,
                                     args.template_duration[1])
        name_parts.append('TD-%s-%s' % (args.duration_ltgt, duration_token))

    if args.snr_cut_type:
        cut_tokens = '_'.join(
            _name_token(cut_val, str(cut_val)) for cut_val in snr_cut_vals
        )
        name_parts.append('%s-%s' % (args.snr_cut_type.upper(), cut_tokens))

    name_parts.append('WINDOW-%s' % _name_token(args.window, str(args.window)))

    if args.log_y:
        name_parts.append('LOG')

    # Filenames via command line may have any extension; the default is PNG.
    out_name = '_'.join(name_parts) + '.png'

# Draw one histogram per SNR threshold, or a placeholder figure if no gates
# were found in any file, then write the figure to out_name.
if n_gates > 0:
    # All histograms share the same edges spanning [-window, +window].
    bin_edges = numpy.linspace(-args.window, args.window, nbins)
    for trigger_times, thr in zip(ifo_trigger_times, snr_cut_vals):
        if args.snr_cut_type:
            legend_label = '%s > %.1f' % (args.snr_cut_type, thr)
        else:
            legend_label = ''
        plt.hist(trigger_times, bins=bin_edges, label=[legend_label])

    if args.log_y:
        plt.yscale('log')
    plt.grid(True)
    plt.xlabel('Time relative to gate centre (s)')
    plt.ylabel('Number of triggers')

    # Optional title pieces: number of input files (only if more than one)
    # and the template-duration cut (only if one was requested).
    n_files = len(args.single_trigger_files)
    chunks_note = 'IN %i CHUNKS' % n_files if n_files > 1 else ''
    if args.template_duration:
        comparison = '<' if args.duration_ltgt == 'LT' else '>'
        duration_note = ('\n Triggers with template duration %s %s s'
                         % (comparison, str(args.duration_cut)))
    else:
        duration_note = ''
    plt.title('%s - %i %s GATES %s %s'
              % (args.ifo_tag, n_gates, str(args.gating_type).upper(),
                 chunks_note, duration_note))

    if args.snr_cut_type:
        plt.legend()
else:
    # Nothing to plot, make an empty figure carrying a short message
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.text(0.5, 0.5, "Sorry, no gates to plot!",
            horizontalalignment='center', verticalalignment='center')

logging.info('Saving histogram...')
plt.savefig(out_name)
plt.close()

logging.info("Done")
