#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot null statistic or coincident SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import rc
import matplotlib.pyplot as plt
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_null_stats"


# =============================================================================
# Functions
# =============================================================================
# Function to load necessary SNR data from a trigger/injection file
def load_data(input_file, vetoes, opts, injections=False, sim_table=False):
    """Build a dictionary containing SNR data extracted from a
    trigger/injection file.

    Parameters
    ----------
    input_file : str or None
        Location of the trigger or injection file.  When it evaluates to
        False nothing is loaded and both dictionary entries stay None.
    vetoes :
        Vetoes, handed over unchanged to the loading routine.
    opts :
        Parsed command line options; only ``opts.y_variable`` is read.
    injections : bool, optional
        If True the file is read with ``ppu.load_injections``, otherwise
        with ``ppu.load_triggers``.
    sim_table : bool, optional
        Forwarded to ``ppu.load_injections``; not used for triggers.

    Returns
    -------
    dict
        Two entries: 'snr' and the value of ``opts.y_variable``.  Each one
        is a numpy array when a file was loaded, None otherwise.  The second
        entry also stays None when ``opts.y_variable`` is not one of
        'coincident', 'nullstat' or 'overwhitened'.
    """

    null_stat_type = opts.y_variable

    # Initialize the dictionary: these values are returned as they are when
    # there is no file to read
    data = {'snr': None, null_stat_type: None}
    if not input_file:
        return data

    # Load the triggers or the injections
    if injections:
        trigs_or_injs = ppu.load_injections(input_file, vetoes,
                                            sim_table=sim_table)
    else:
        trigs_or_injs = ppu.load_triggers(input_file, vetoes)

    # Fill the dictionary in with required data.  Every entry is converted
    # to a numpy array so that all y variables are handled in the same way.
    data['snr'] = numpy.asarray(trigs_or_injs.get_column('snr'))
    if null_stat_type == 'coincident':
        data[null_stat_type] = numpy.asarray(
            trigs_or_injs.get_column('coinc_snr'))
    elif null_stat_type == 'nullstat':
        data[null_stat_type] = numpy.asarray(
            trigs_or_injs.get_column('null_statistic'))
    elif null_stat_type == 'overwhitened':
        data[null_stat_type] = numpy.asarray(trigs_or_injs.get_null_snr())

    return data


# Function that produces the contours to be plotted
def calculate_contours(opts, new_snrs=None):
    """Generate the contours to plot.

    Parameters
    ----------
    opts :
        Options object.  The attributes read are ``null_snr_threshold``
        (a string of comma-separated floats, the last of which is used),
        ``null_grad_thresh`` and ``null_grad_val``.
    new_snrs : list, optional
        Not used in the calculation: accepted only so that existing calls
        keep working.  The list is never modified.

    Returns
    -------
    null_cont : numpy.ndarray
        One value per entry of ``snr_vals``: equal to the last null SNR
        threshold for SNRs up to ``null_grad_thresh``, then growing
        linearly with slope ``null_grad_val``.
    snr_vals : numpy.ndarray
        SNR values at which the contour is evaluated.
    """

    # Get SNR values for contours: steps of 0.1 below 30, steps of 1 above
    snr_vals = numpy.concatenate((numpy.arange(4, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # Only the last threshold is needed, but all of them are converted so
    # that a malformed entry anywhere in the string is still reported
    null_thresh = [float(val)
                   for val in opts.null_snr_threshold.split(',')][-1]
    null_grad_snr = opts.null_grad_thresh
    null_grad_val = opts.null_grad_val

    # Determine contour: flat up to null_grad_snr, linear ramp above it
    null_cont = numpy.where(
        snr_vals > null_grad_snr,
        null_thresh + (snr_vals - null_grad_snr) * null_grad_val,
        null_thresh)

    return null_cont, snr_vals


# =============================================================================
# Main script starts here
# =============================================================================
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
# The y variable selects the data to load, the axis label and the contours:
# nothing can be plotted without it, so the parser must insist on it rather
# than letting the script fail later on a dictionary lookup
parser.add_argument("-y", "--y-variable", default=None, required=True,
                    choices=['coincident', 'nullstat', 'overwhitened'],
                    help="Quantity to plot on the vertical axis.")
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
zoom_in = opts.zoom_in
null_stat_type = opts.y_variable

# Prepare plot title and caption (only when the user did not provide them)
y_labels = {'nullstat': "Null statistic",
            'overwhitened': "Overwhitened null statistic",
            'coincident': "Coincident SNR"}
if opts.plot_title is None:
    opts.plot_title = y_labels[null_stat_type] + " vs Coherent SNR"
if opts.plot_caption is None:
    opts.plot_caption = "Blue crosses: background triggers.  "
    if found_file:
        opts.plot_caption += "Red crosses: injections triggers.  "

    if null_stat_type == 'coincident':
        opts.plot_caption += "Green line: coincident SNR = coherent SNR."
    else:
        opts.plot_caption += ("Black line: veto line.  "
                              "Green line: above this triggers have reduced "
                              "detection statistic.  "
                              "Magenta line: in this line the statistic is "
                              "reduced by a factor of two.")

logging.info("Imported and ready to go.")

# Set output directories
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes (the IFOs are not needed here)
_, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                        opts.veto_category)

# Extract trigger data
trig_data = load_data(trig_file, vetoes, opts)

# Extract (or initialize) injection data: without a found injection file
# the dictionary entries are simply None
inj_data = load_data(found_file, vetoes, opts, injections=True,
                     sim_table=False)

# Generate plots
logging.info("Plotting...")

# Contours
snr_vals = None
cont_colors = None
shade_cont_value = None
x_max = None
# Coincident SNR plot case: we want a coinc=coh diagonal line on the plot
if null_stat_type == 'coincident':
    cont_colors = ['g-']
    x_max = plu.axis_max_value(trig_data['snr'], inj_data['snr'], found_file)
    snr_vals = [4, x_max]
    null_stat_conts = [[4, x_max]]
# Overwhitened null stat (null SNR) and null stat  cases: newSNR contours
else:
    cont_colors = ['k-', 'g-', 'm-']
    null_cont, snr_vals = calculate_contours(opts, new_snrs=None)
    null_stat_conts = [null_cont]
    if zoom_in:
        null_thresh = list(map(float, opts.null_snr_threshold.split(',')))
        # The additional contours are offset by the gap between the first
        # two thresholds, so a single value cannot be used here
        if len(null_thresh) < 2:
            parser.error("Zooming in requires two comma-separated values "
                         "for the null SNR threshold.")
        null_thresh_width = null_thresh[1] - null_thresh[0]
        null_stat_conts.append(numpy.asarray(null_cont) - null_thresh_width)
        if null_thresh_width > 1:
            null_stat_conts.append(numpy.asarray(null_cont)
                                   - null_thresh_width + 1)
    shade_cont_value = 0

# Overwhitened null stat (null SNR), null stat or coincident SNR vs
# Coherent SNR plot
# Default axis limits of the zoomed in plot, unless the user set them
if not opts.x_lims and zoom_in:
    opts.x_lims = '6,30'
if not opts.y_lims and zoom_in:
    opts.y_lims = '0,30'
# Set the font size through rcParams
rc('font', size=14)
# Determine y-axis values of triggers and injections
y_label = y_labels[null_stat_type]
trigs = [trig_data['snr'], trig_data[null_stat_type]]
injs = [inj_data['snr'], inj_data[null_stat_type]]
plu.pygrb_plotter(trigs, injs, "Coherent SNR", y_label, opts,
                  snr_vals=snr_vals, conts=null_stat_conts,
                  shade_cont_value=shade_cont_value,
                  colors=cont_colors, vert_spike=True,
                  cmd=' '.join(sys.argv))
