#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# =============================================================================
# Preamble
# =============================================================================

from __future__ import division

import sys
import glob
import os
import copy
import logging
from glue.ligolw import lsctables
import pycbc.version
from pycbc.results.pygrb_plotting_utils import *
plt.switch_backend('Agg')

__author__  = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_snr_timeseries"


# =============================================================================
# Main script starts here
# =============================================================================

# Build the help text from __program__ so that the description always
# carries the actual executable name.
description = __program__ + ' produces SNR timeseries plots.'
usage = __program__ + ' [--options]'
opts = pygrb_plot_opts_parser(usage=usage, description=description,
                              version=__version__)

# Log at INFO level when opts.verbose is set; otherwise only report warnings
# and anything more severe.
log_level = logging.INFO if opts.verbose else logging.WARNING
logging.basicConfig(level=log_level,
                    format="%(asctime)s:%(levelname)s : %(message)s")

# Resolve the input files to absolute paths; the injection file is optional
trig_file = os.path.abspath(opts.trig_file)
inj_file = None
if opts.inj_file:
    inj_file = os.path.abspath(opts.inj_file)
outfile = opts.output_file
seg_dir = opts.segment_dir

# Collect the veto files for categories 2 up to opts.veto_category.
# The category digits are joined without a separator: inside a glob
# character class every character is a literal alternative, so joining
# them with ',' would also make ',' an acceptable character after "CAT".
veto_files = []
if opts.veto_directory:
    veto_string = ''.join(str(i) for i in range(2, opts.veto_category + 1))
    veto_files = glob.glob(opts.veto_directory +
                           '/*CAT[%s]*.xml' % veto_string)
# Select which SNR to plot and, for single-detector plots, which IFO
snr_type = opts.variable
ifo = opts.ifo
if snr_type == 'single' and ifo is None:
    # A single IFO plot cannot be made without knowing the IFO: report the
    # problem and exit with a non-zero status so that the failure is visible
    # to whatever launched the script (a bare sys.exit() reports success).
    logging.error("Please specify an interferometer for a single IFO plot")
    sys.exit(1)

# Default plot titles, keyed by the requested SNR type
default_titles = {
    'coherent': "Coherent SNR vs Time",
    'single': "%s SNR vs Time" % ifo,
    'null': "Null SNR vs Time",
    'reweighted': "Reweighted SNR vs Time",
}

# Fall back on the defaults only for what was not provided via the options
if opts.plot_title is None:
    opts.plot_title = default_titles[snr_type]
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: Background triggers\n"
                         "Red crosses: Injections triggers")

logging.info("Imported and ready to go.")

# Make sure the directory that will host the output file exists
outdir = os.path.dirname(os.path.abspath(outfile))
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract the IFOs from the trigger file
ifos = extract_ifos(trig_file)

# Extract the vetoes from the trigger file
vetoes = extract_vetoes(trig_file, ifos)

# Load the triggers and extract their data
trigs = load_triggers(trig_file, vetoes, ifos)
trig_data = PygrbFilterOutput(
    trigs, ifos, lsctables.MultiInspiralTable.loadcolumns, "triggers", opts)

# Load the injections (only if an injection file was given) and extract
# their data; without an injection file None is passed on instead
injs = load_injections(inj_file, vetoes) if inj_file else None
inj_data = PygrbFilterOutput(
    injs, ifos, lsctables.MultiInspiralTable.loadcolumns, "injections", opts)

# Generate plots
logging.info("Plotting...")

# Reset trigger and injection times
grb_time, start, end, trig_data, inj_data = reset_times(
    seg_dir, trig_data, inj_data, inj_file)


def _snr_by_type(filter_output):
    """Return a dict mapping each SNR type to the data in filter_output.

    The 'coherent', 'null' and 'reweighted' entries are always present;
    the 'single' entry is added only when an IFO was requested.
    """
    snr_by_type = {'coherent': filter_output.snr,
                   'null': filter_output.null_stat,
                   'reweighted': filter_output.reweighted_snr}
    if ifo:
        snr_by_type['single'] = filter_output.ifo_snr[ifo]
    return snr_by_type


# Determine what goes on the vertical axis: first the label...
y_label = {
    'coherent': "Coherent SNR",
    'single': "%s SNR" % ifo,
    'null': "Null SNR",
    'reweighted': "Reweighted SNR",
}[snr_type]

# ...then the data: always for the triggers, and for the injections only
# when an injection file was given
trig_snr_data = _snr_by_type(trig_data)[snr_type]
inj_snr_data = _snr_by_type(inj_data)[snr_type] if inj_file else None

# Plot the selected SNR versus time, restricting the horizontal axis to the
# [start, end] interval
pygrb_shared_plot_setups()
x_label = "Time since %s" % str(grb_time)
pygrb_plotter(trig_data.time, trig_snr_data,
              inj_data.time, inj_snr_data, inj_file,
              x_label, y_label, outfile,
              xlims=[start, end],
              use_logs=False,
              cmd=' '.join(sys.argv),
              plot_title=opts.plot_title,
              plot_caption=opts.plot_caption)
