#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO/coherent/reweighted/null SNR timeseries for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import matplotlib.pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_snr_timeseries"


# =============================================================================
# Functions
# =============================================================================
# Function to load time and necessary SNR data from a trigger/injection file
def load_data(input_file, vetoes, opts, injections=False, sim_table=False):
    """Collect event times and the requested SNR values in a dictionary.

    Parameters
    ----------
    input_file : str or None
        Trigger or injection file to read.  When falsy, nothing is loaded.
    vetoes :
        Vetoes forwarded unchanged to the pygrb_postprocessing_utils loader.
    opts :
        Parsed options.  ``opts.y_variable`` selects which SNR is extracted;
        the reweighted SNR additionally reads the thresholds and chi-square
        settings stored on ``opts``, and the single-detector SNR reads
        ``opts.ifo``.
    injections : bool
        If True, read the file with ``ppu.load_injections`` rather than
        ``ppu.load_triggers``.
    sim_table : bool
        Forwarded to ``ppu.load_injections`` (only used for injections).

    Returns
    -------
    dict
        Has the keys ``'time'`` and ``opts.y_variable``.  Both values are
        None when ``input_file`` is falsy.
    """

    snr_kind = opts.y_variable

    # Initialize the dictionary with placeholders
    data = {'time': None, snr_kind: None}

    # Without a file there is nothing to extract
    if not input_file:
        return data

    # Read the events from file
    if injections:
        events = ppu.load_injections(input_file, vetoes, sim_table=sim_table)
    else:
        events = ppu.load_triggers(input_file, vetoes)

    data['time'] = numpy.asarray(events.get_end())

    # Extract the SNR requested for the vertical axis
    if snr_kind == 'coherent':
        coherent_snr = events.get_column('snr')
        data['coherent'] = numpy.asarray(coherent_snr)
    elif snr_kind == 'null':
        null_snr = events.get_column('null_statistic')
        data['null'] = numpy.asarray(null_snr)
    elif snr_kind == 'reweighted':
        # The null SNR thresholds arrive as a comma-separated string
        null_thresh = [float(val)
                       for val in opts.null_snr_threshold.split(',')]
        data['reweighted'] = ppu.get_bestnrs(
            events,
            q=opts.chisq_index,
            n=opts.chisq_nhigh,
            null_thresh=null_thresh,
            snr_threshold=opts.snr_threshold,
            sngl_snr_threshold=opts.sngl_snr_threshold,
            chisq_threshold=opts.newsnr_threshold,
            null_grad_thresh=opts.null_grad_thresh,
            null_grad_val=opts.null_grad_val)
    elif snr_kind == 'single':
        data['single'] = events.get_sngl_snr(opts.ifo)

    return data


# Find start and end times of trigger/injection data relative to a given time
def get_start_end_times(data, central_time):
    """Return the padded time span of data, measured from central_time.

    The smallest and largest entries of ``data['time']`` are converted with
    ``int()`` and shifted by ``central_time``; the resulting interval is
    then widened by 5% of its length on either side.

    Parameters
    ----------
    data : dict
        Must provide a non-empty sequence of times under the key ``'time'``.
    central_time : int or float
        Reference time subtracted from both ends of the interval.

    Returns
    -------
    tuple
        ``(start, end)`` of the padded interval relative to central_time.
    """

    times = data['time']
    earliest = int(min(times)) - central_time
    latest = int(max(times)) - central_time

    # Pad both ends by 5% of the unpadded duration
    padding = (latest - earliest)*0.05

    return earliest - padding, latest + padding


# Reset times so that t=0 corresponds to the given trigger time
def reset_times(data, trig_time):
    """Shift the times stored in data so that trig_time becomes t=0.

    The entry ``data['time']`` is replaced in place by a new list holding
    every original time minus ``trig_time``; the same dictionary object is
    returned.
    """

    shifted_times = []
    for time in data['time']:
        shifted_times.append(time - trig_time)
    data['time'] = shifted_times

    return data


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser: start from the common PyGRB plot parser and
# add the options specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument('--central-time', type=float, default=None,
                    help="Center plot at the given GPS time. If omitted, "+
                    "use the GRB trigger time")
# The vertical-axis quantity is mandatory: without it the label lookup
# further down would fail with a KeyError after all data has been loaded
parser.add_argument("-y", "--y-variable", default=None, required=True,
                    choices=['coherent', 'single', 'reweighted', 'null'],
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
snr_type = opts.y_variable
ifo = opts.ifo
# A single-detector plot is meaningless without knowing which detector
if snr_type == 'single' and ifo is None:
    err_msg = "Please specify an interferometer for a single IFO plot"
    parser.error(err_msg)

logging.info("Imported and ready to go.")

# Set output directories
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes (only the vetoes are needed here)
_, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                        opts.veto_category)

# Load trigger data
trig_data = load_data(trig_file, vetoes, opts)

# Load (or initialize) injection data: when no found-injection file is given
# this yields a dictionary whose entries are all None
inj_data = load_data(found_file, vetoes, opts, injections=True, sim_table=False)

# Determine the central time (t=0): default is the GRB trigger time
central_time = opts.central_time if opts.central_time is not None else\
    ppu.get_grb_time(opts.seg_files)

# Determine trigger data start and end times relative to the central time
start, end = get_start_end_times(trig_data, central_time)

# Reset trigger and injection times
trig_data = reset_times(trig_data, central_time)
if found_file:
    inj_data = reset_times(inj_data, central_time)

# Generate plots
logging.info("Plotting...")

# Determine what goes on the vertical axis
y_labels = {'coherent': "Coherent SNR",
            'single': "%s SNR" % ifo,
            'null': "Null SNR",
            'reweighted': "Reweighted SNR"}
y_label = y_labels[snr_type]

# Determine title and caption
if opts.plot_title is None:
    opts.plot_title = y_label + " vs Time"
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers.  ")
    if found_file:
        opts.plot_caption += ("Red crosses: injections triggers.")

# SNR versus time plot
# Horizontal axis limits: default to the padded span of the trigger data,
# unless comma-separated limits were given on the command line.  The limits
# are materialized as a list so that both branches yield the same type.
# NOTE(review): xlims is not handed to pygrb_plotter below; presumably the
# plotter reads opts.x_lims itself -- confirm against pygrb_plotting_utils.
xlims = [start, end]
if opts.x_lims:
    xlims = list(map(float, opts.x_lims.split(',')))
plu.pygrb_plotter([trig_data['time'], trig_data[snr_type]],
                  [inj_data['time'], inj_data[snr_type]],
                  "Time since %.3f (s)" % (central_time), y_label,
                  opts, cmd=' '.join(sys.argv))
