#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot null statistic or coincident SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import h5py
from matplotlib import rc
import matplotlib.pyplot as plt
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_null_stats"


# =============================================================================
# Functions
# =============================================================================
# Load the coherent SNR and the requested y-axis statistic from a file
def load_data(input_file, vetoes, ifos, opts, injections=False):
    """Read the quantities to be plotted from a trigger or injection file.

    Parameters
    ----------
    input_file : str or None
        Path handed to ``ppu.load_triggers``.  When it is empty or None
        nothing is read.
    vetoes : object
        Passed through unchanged to ``ppu.load_triggers``.
    ifos : object
        Passed through unchanged to ``ppu.get_coinc_snr`` (only used when
        the y-axis variable is 'coincident').
    opts : object
        Parsed options; ``opts.y_variable`` names the y-axis statistic
        ('null', 'single' or 'coincident'); ``opts.ifo`` is read only in
        the 'single' case.
    injections : bool, optional
        Only selects the wording of the log messages: the file is read with
        ``ppu.load_triggers`` in both cases.

    Returns
    -------
    dict
        Two entries, 'coherent' and ``opts.y_variable``.  Both are None when
        no file is given; the second one also stays None when
        ``opts.y_variable`` is not one of the statistics handled here.
    """

    stat_key = opts.y_variable

    # Placeholders returned as they are when there is nothing to read
    data = {'coherent': None, stat_key: None}
    if not input_file:
        return data

    # Pick the log messages matching the kind of file being read
    if injections:
        loading_msg = "Loading injections..."
        found_msg = "%d injections found."
    else:
        loading_msg = "Loading triggers..."
        found_msg = "%d triggers found."

    logging.info(loading_msg)
    # This will eventually become ppu.load_injections() for injections
    events = ppu.load_triggers(input_file, vetoes)

    num_events = len(events['network/end_time_gc'][:])

    # Coherent SNR is always used
    data['coherent'] = events['network/coherent_snr'][:]

    # Statistic for the vertical axis
    if stat_key == 'null':
        data[stat_key] = events['network/%s_snr' % stat_key][:]
    elif stat_key == 'single':
        att = opts.ifo[0].lower()
        data[stat_key] = events['%s/snr_%s' % (opts.ifo, att)][:]
    elif stat_key == 'coincident':
        data[stat_key] = ppu.get_coinc_snr(events, ifos)

    logging.info(found_msg, num_events)

    return data

# Function that produces the contours to be plotted
def calculate_contours(opts, new_snrs=None):
    """Generate the null SNR contour to plot.

    The contour is flat at the null SNR threshold up to the coherent SNR
    ``opts.null_grad_thresh`` and grows linearly with slope
    ``opts.null_grad_val`` above it.

    Parameters
    ----------
    opts : object
        Parsed options.  The attributes read here are
        ``null_snr_threshold`` (comma-separated string of floats, of which
        the last one is used), ``null_grad_thresh`` and ``null_grad_val``.
    new_snrs : list, optional
        Accepted for backward compatibility only: it does not enter the
        returned values and it is left untouched (it used to be modified
        in place).
        NOTE(review): presumably meant for additional newSNR contours that
        were never implemented -- confirm before removing the argument.

    Returns
    -------
    null_cont : numpy.ndarray
        Contour values, one per element of ``snr_vals``.
    snr_vals : numpy.ndarray
        SNR grid: from 4 up to (excluding) 30 in steps of 0.1, followed by
        30 to 499 in steps of 1.
    """

    # SNR values at which the contour is evaluated
    snr_vals = numpy.concatenate((numpy.arange(4, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # Every entry is converted, so a malformed value is still reported,
    # but only the last threshold defines the contour
    null_thresh = [float(val)
                   for val in opts.null_snr_threshold.split(',')][-1]
    null_grad_snr = opts.null_grad_thresh
    null_grad_val = opts.null_grad_val

    # Flat up to null_grad_snr (included), linear growth above it
    null_cont = numpy.where(
        snr_vals > null_grad_snr,
        null_thresh + (snr_vals - null_grad_snr) * null_grad_val,
        null_thresh)

    return null_cont, snr_vals


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser: common PyGRB plotting options plus the ones
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
# The y-axis variable selects labels, caption and contours further down, so
# it is mandatory: without it the script used to fail with a KeyError
parser.add_argument("-y", "--y-variable", default=None, required=True,
                    choices=['coincident', 'null'], #TODO: overwhitened?
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
zoom_in = opts.zoom_in
null_stat_type = opts.y_variable

# Prepare plot title and caption (user-provided ones take precedence)
y_labels = {'null': "Null SNR",
            'coincident': "Coincident SNR"} #TODO: overwhitened
if opts.plot_title is None:
    opts.plot_title = y_labels[null_stat_type] + " vs Coherent SNR"
if opts.plot_caption is None:
    caption = "Blue crosses: background triggers.  "
    # Injections are mentioned only when an injection file is given
    if found_file:
        caption += "Red crosses: injections triggers.  "
    if null_stat_type == 'coincident':
        caption += "Green line: coincident SNR = coherent SNR."
    else:
        caption += ("Black line: veto line.  "
                    "Green line: above this triggers have reduced "
                    "detection statistic.  "
                    "Magenta line: in this line the statistic is "
                    "reduced by a factor of two.")
    opts.plot_caption = caption

logging.info("Imported and ready to go.")

# Set output directories
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Extract trigger data
trig_data = load_data(trig_file, vetoes, ifos, opts)

# Extract (or initialize) injection data: without an injection file the
# entries of inj_data are all None
inj_data = load_data(found_file, vetoes, ifos, opts, injections=True)

# Generate plots
logging.info("Plotting...")

# Contours
snr_vals = None
cont_colors = None
shade_cont_value = None
x_max = None
# Coincident SNR plot case: we want a coinc=coh diagonal line on the plot
if null_stat_type == 'coincident':
    cont_colors = ['g-']
    x_max = plu.axis_max_value(trig_data['coherent'], inj_data['coherent'],
                               found_file)
    # Two points are enough to draw the straight line
    snr_vals = [4, x_max]
    null_stat_conts = [[4, x_max]]
# Overwhitened null stat (null SNR) and null stat cases: newSNR contours
else:
    cont_colors = ['k-', 'g-', 'm-']
    null_cont, snr_vals = calculate_contours(opts, new_snrs=None)
    null_stat_conts = [null_cont]
    if zoom_in:
        null_thresh = list(map(float, opts.null_snr_threshold.split(',')))
        # The additional contours are offsets of the main one by the gap
        # between the first two thresholds, so two values are required
        if len(null_thresh) > 1:
            null_thresh_width = null_thresh[1] - null_thresh[0]
            null_stat_conts.append(numpy.asarray(null_cont) -
                                   null_thresh_width)
            if null_thresh_width > 1:
                null_stat_conts.append(numpy.asarray(null_cont) -
                                       null_thresh_width + 1)
        else:
            logging.warning("Only one null SNR threshold given: "
                            "skipping the additional zoom-in contours.")
    shade_cont_value = 0

# Overwhitened null stat (null SNR), null stat or coincident SNR vs
# Coherent SNR plot
# Default axis ranges of the zoomed-in plot, unless set by the user
if not opts.x_lims and zoom_in:
    opts.x_lims = '6,30'
if not opts.y_lims and zoom_in:
    opts.y_lims = '0,30'
# Set the font size used in the plot
rc('font', size=14)
# Determine y-axis values of triggers and injections
y_label = y_labels[null_stat_type]
trigs = [trig_data['coherent'], trig_data[null_stat_type]]
injs = [inj_data['coherent'], inj_data[null_stat_type]]
plu.pygrb_plotter(trigs, injs, "Coherent SNR", y_label, opts,
                  snr_vals=snr_vals, conts=null_stat_conts,
                  shade_cont_value=shade_cont_value,
                  colors=cont_colors, vert_spike=True,
                  cmd=' '.join(sys.argv))
