#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO/coherent/reweighted/null SNR timeseries for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import h5py
import matplotlib.pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_snr_timeseries"


# =============================================================================
# Functions
# =============================================================================
# Load trigger data
def load_data(input_file, vetoes, injections=False):
    """Load the content of a trigger or injection file, if one is given.

    Parameters
    ----------
    input_file :
        Location of the file to read. A falsy value (e.g. None) means
        there is nothing to load.
    vetoes :
        Handed over unchanged to ppu.load_triggers.
    injections : bool
        Only selects which message is logged.

    Returns
    -------
    The object returned by ppu.load_triggers, or None when no input file
    is provided.
    """

    if not input_file:
        return None

    # Injection files are read with the trigger loader as well for the time
    # being; a dedicated injection loader is expected to replace it
    message = "Loading injections..." if injections else "Loading triggers..."
    logging.info(message)

    return ppu.load_triggers(input_file, vetoes)


# Find start and end times of trigger/injection data relative to a given time
def get_start_end_times(data_time, central_time):
    """Return padded (start, end) times of the data, relative to central_time.

    The smallest and largest values in data_time are truncated with int(),
    shifted by central_time, and then moved outwards by 5% of the
    interval they span.
    """

    rel_first = int(min(data_time)) - central_time
    rel_last = int(max(data_time)) - central_time

    # Pad both edges by the same margin
    margin = (rel_last - rel_first)*0.05

    return rel_first - margin, rel_last + margin


# Reset times so that t=0 corresponds to the given trigger time
def reset_times(data_time, trig_time):
    """Shift every time in data_time so that trig_time becomes t=0.

    The input is left untouched: a new list holding the shifted values is
    returned.
    """

    return [stamp - trig_time for stamp in data_time]


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser: the options shared by the PyGRB plotting
# scripts plus the ones specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument('--central-time', type=float, default=None,
                    help="Center plot at the given GPS time. If omitted, "+
                    "use the GRB trigger time")
# The y variable selects the HDF dataset to read and the axis label, so it
# must be provided: without it the script would only fail with a TypeError
# when building the dataset key, after the triggers have been loaded
parser.add_argument("-y", "--y-variable", default=None, required=True,
                    choices=['coherent', 'single', 'reweighted', 'null'],
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
snr_type = opts.y_variable
ifo = opts.ifo
# A single-detector SNR plot needs to know which detector to read
if snr_type == 'single' and ifo is None:
    err_msg = "Please specify an interferometer for a single IFO plot"
    parser.error(err_msg)

logging.info("Imported and ready to go.")

# Set output directories
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Load trigger data
trig_data = load_data(trig_file, vetoes)

# Load (or initialize) injection data: None if no found file was given
inj_data = load_data(found_file, vetoes, injections=True)

# Specify HDF file keys for x quantity (time) and y quantity (SNR)
if snr_type == 'single':
    x_key = ifo + '/' + 'end_time'
    y_key = ifo + '/' + 'snr_' + ifo.lower()
else:
    x_key = 'network/end_time_gc'
    y_key = 'network/' + snr_type + '_snr'

# Obtain times
trig_data_time = trig_data[x_key][:]
inj_data_time = inj_data[x_key][:] if found_file else None

# Obtain SNRs
trig_data_snr = trig_data[y_key][:]
inj_data_snr = inj_data[y_key][:] if found_file else None

# Determine the central time (t=0): default is the GRB trigger time
if opts.central_time is not None:
    central_time = opts.central_time
else:
    central_time = ppu.get_grb_time(opts.seg_files)

# Determine trigger data start and end times relative to the central time
start, end = get_start_end_times(trig_data_time, central_time)

# Reset trigger and injection times
trig_data_time = reset_times(trig_data_time, central_time)
if found_file:
    inj_data_time = reset_times(inj_data_time, central_time)

# Generate plots
logging.info("Plotting...")

# Determine what goes on the vertical axis
y_labels = {'coherent': "Coherent SNR",
            'single': "%s SNR" % ifo,
            'null': "Null SNR",
            'reweighted': "Reweighted SNR"}
y_label = y_labels[snr_type]

# Determine title and caption
if opts.plot_title is None:
    opts.plot_title = y_label + " vs Time"
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers.  ")
    if found_file:
        opts.plot_caption += ("Red crosses: injections triggers.")

# SNR versus time plot
# Horizontal range: the padded span of the trigger data, unless
# comma-separated limits were given on the command line. The values are
# converted eagerly so that xlims is always a concrete list of floats
xlims = [start, end]
if opts.x_lims:
    xlims = [float(lim) for lim in opts.x_lims.split(',')]
# NOTE(review): xlims is not handed to the plotter below, which only receives
# opts -- confirm whether the default range is meant to be applied
plu.pygrb_plotter([trig_data_time, trig_data_snr],
                  [inj_data_time, inj_data_snr],
                  "Time since %.3f (s)" % (central_time), y_label,
                  opts, cmd=' '.join(sys.argv))
