#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Produces signal consistency plots of the form standard/bank/auto
chi-square vs single IFO/coherent/reweighted/null/coinc SNR.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
import h5py
from matplotlib import pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_chisq_veto"


# =============================================================================
# Functions
# =============================================================================
# Functions to load trigger data
def load_data(input_file, vetoes, ifos, opts, injections=False):
    """Collect the SNR, chi-square and dof values needed for the plot.

    Parameters
    ----------
    input_file : str or None
        File handed to ``ppu.load_triggers``. When it is falsy nothing is
        read and every entry of the returned dictionary is None.
    vetoes :
        Passed through unchanged to ``ppu.load_triggers``.
    ifos : list
        Passed to ``ppu.get_coinc_snr``; only used when ``opts.snr_type``
        is 'coincident'.
    opts : object
        Must provide the attributes ``snr_type``, ``y_variable`` and
        ``ifo``.
    injections : bool, optional
        Only selects which progress message is logged: both cases are
        currently read with ``ppu.load_triggers``.

    Returns
    -------
    dict
        Dictionary with keys ``opts.snr_type``, ``opts.y_variable`` and
        'dof'.
    """
    x_key = opts.snr_type
    y_key = opts.y_variable

    # All keys exist from the start so that callers can test them for None
    data = {x_key: None, y_key: None, 'dof': None}

    # Nothing to read: hand back the placeholders
    if not input_file:
        return data

    logging.info("Loading injections..." if injections
                 else "Loading triggers...")
    # Injections go through the trigger loader until a dedicated
    # load_injections is available
    events = ppu.load_triggers(input_file, vetoes)

    num_events = len(events['network/end_time_gc'][:])

    # SNR to be plotted on the horizontal axis
    if x_key == 'single':
        att = opts.ifo[0].lower()
        data[x_key] = events['%s/snr_%s' % (opts.ifo, att)][:]
    elif x_key == 'coincident':
        # Coincident SNR is computed rather than read from a dataset
        data[x_key] = ppu.get_coinc_snr(events, ifos)
    elif x_key in ('coherent', 'null', 'reweighted'):
        data[x_key] = events['network/%s_snr' % x_key][:]

    # Name of the single IFO dataset holding the requested veto
    chisq_tag = {'standard': 'chisq',
                 'bank': 'bank_chisq',
                 'auto': 'cont_chisq'}[y_key]
    chisq_path = '%s/%s' % (opts.ifo, chisq_tag)

    # Values that are exactly zero are replaced by 0.005
    chisq = events[chisq_path][:]
    numpy.putmask(chisq, chisq == 0, 0.005)
    data[y_key] = chisq

    # Distinct degrees of freedom stored alongside the veto values
    data['dof'] = numpy.unique(events[chisq_path + '_dof'][:])

    logging.info("%d triggers found.", num_events)

    return data


# Function that produces the contours to be plotted
def calculate_contours(trig_data, opts, veto_type, new_snrs=None):
    """Generate the contours for the veto plots.

    Parameters
    ----------
    trig_data : dict
        Trigger data; only ``trig_data['dof'][0]`` is read here.
    opts : object
        Must provide ``newsnr_threshold``, ``chisq_index`` and
        ``chisq_nhigh``.
    veto_type : str
        Not used by this function; kept so that existing calls keep
        working.
    new_snrs : sequence of float, optional
        New SNR values for which a contour is computed. Defaults to
        [5.5, 6, 6.5, 7, 8, 9, 10, 11]. The sequence passed in is never
        modified.

    Returns
    -------
    contours : numpy.ndarray
        Float64 array with one row per contour and one column per entry
        of ``snr_vals``, filled with the output of ``plu.new_snr_chisq``.
    snr_vals : numpy.ndarray
        SNR values at which the contours are evaluated: 4 to 30 in steps
        of 0.1, followed by 30 to 500 in steps of 1 (upper ends excluded).
    cont_value : int
        Row of ``contours`` associated with ``opts.newsnr_threshold``;
        -1 (i.e., the last row) when the threshold had to be appended.
    """

    # Read off the degrees of freedom
    dof = trig_data['dof'][0]

    # Work on a private list: appending the threshold below must not
    # alter a list owned by the caller
    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    else:
        new_snrs = list(new_snrs)

    # Add the new SNR threshold contour to the list if necessary
    # and keep track of where it is
    try:
        cont_value = new_snrs.index(opts.newsnr_threshold)
    except ValueError:
        new_snrs.append(opts.newsnr_threshold)
        cont_value = -1

    # Get SNR values for contours: fine sampling at low SNR, coarse above
    snr_vals = numpy.concatenate((numpy.arange(4, 30, 0.1),
                                  numpy.arange(30, 500, 1)))

    # Initialise contours
    contours = numpy.zeros([len(new_snrs), len(snr_vals)],
                           dtype=numpy.float64)
    # Loop over SNR values and calculate chisq variable needed
    for j, snr in enumerate(snr_vals):
        for i, new_snr in enumerate(new_snrs):
            contours[i][j] = plu.new_snr_chisq(snr, new_snr, dof,
                                               opts.chisq_index,
                                               opts.chisq_nhigh)

    return contours, snr_vals, cont_value


# Function that defines the colours of the contours to be plotted
def contour_colors(opts, new_snrs=None):
    """Define the colours with which contours are plotted.

    Parameters
    ----------
    opts : object
        Must provide ``newsnr_threshold``.
    new_snrs : sequence of float, optional
        New SNR values of the contours. Defaults to
        [5.5, 6, 6.5, 7, 8, 9, 10, 11]. The sequence passed in is never
        modified.

    Returns
    -------
    list of str
        One line format string per contour: "k-" for the contour at the
        new SNR threshold, "y-" for integer-valued new SNRs, and "y--"
        for the remaining ones.
    """

    new_snr_thresh = opts.newsnr_threshold

    # Work on a private list so that the caller's one is left untouched
    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    else:
        new_snrs = list(new_snrs)

    # Add the new SNR threshold contour to the list if necessary.
    # calculate_contours appends a contour for the threshold when it is
    # not among the contour values: without the matching entry here there
    # would be one colour less than contours and no black veto line.
    if new_snr_thresh not in new_snrs:
        new_snrs.append(new_snr_thresh)

    colors = ["k-" if snr == new_snr_thresh else
              "y-" if snr == int(snr) else
              "y--" for snr in new_snrs]

    return colors


# =============================================================================
# Main script starts here
# =============================================================================
# Command line interface: the generic PyGRB plot options come from ppu,
# the options specific to this plot are added below
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
parser.add_argument("-y", "--y-variable", required=True,
                    choices=['standard', 'bank', 'auto'],
                    help="Quantity to plot on the vertical axis.")
# Coherent SNR is the default: with no default at all, omitting the option
# left no x-axis data to plot and the script always stopped with an error
parser.add_argument("--snr-type", default='coherent',
                    choices=['coherent', 'single', 'coincident', 'null',
                             'reweighted'],
                    help="SNR value to plot on x-axis. Can be "
                         "coherent, coincident, null, reweighted, or "
                         "single IFO SNR")
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
zoom_in = opts.zoom_in
veto_type = opts.y_variable
ifo = opts.ifo
snr_type = opts.snr_type
# The single IFO SNR can be plotted only if the IFO is specified
if snr_type == 'single':
    if ifo is None:
        err_msg = "--ifo must be given when using --snr-type single"
        parser.error(err_msg)

# Veto is intended as a single IFO quantity. Network chisq will be obsolete.
if ifo is None:
    err_msg = "--ifo must be given to plot single IFO SNR veto"
    parser.error(err_msg)

# Human readable names of the quantities that can be plotted
veto_labels = {'standard': "Chi Square Veto",
               'bank': "Bank Veto",
               'auto': "Auto Veto"}
snr_labels = {'coherent': "Coherent SNR",
              'coincident': "Coincident SNR",
              'null': "Null SNR",
              'reweighted': "Reweighted SNR",
              'single': "%s SNR" % ifo}

# x-axis label: it names the SNR that is actually plotted
x_label = snr_labels[snr_type]

# Prepare plot title and caption
if opts.plot_title is None:
    opts.plot_title = "%s %s vs %s" % (ifo, veto_labels[veto_type], x_label)
if opts.plot_caption is None:
    opts.plot_caption = ("Blue crosses: background triggers.  " +
                         "Black line: veto line.  " +
                         "Gray shaded region: Vetoed area.  " +
                         "Yellow lines: contours of new SNR.")
    if found_file:
        opts.plot_caption = ("Red crosses: injections triggers.  ") +\
                            opts.plot_caption

logging.info("Imported and ready to go.")

# Set output directory
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Stop if the requested IFO is not available
if ifo and ifo not in ifos:
    err_msg = "The IFO selected with --ifo is unavailable in the data."
    raise RuntimeError(err_msg)

# Extract trigger data
trig_data = load_data(trig_file, vetoes, ifos, opts)

# Extract (or initialize) injection data
inj_data = load_data(found_file, vetoes, ifos, opts, injections=True)

# Sanity checks
if trig_data[snr_type] is None and inj_data[snr_type] is None:
    err_msg = "No data to be plotted on the x-axis was found"
    raise RuntimeError(err_msg)
if trig_data[veto_type] is None and inj_data[veto_type] is None:
    err_msg = "No data to be plotted on the y-axis was found"
    raise RuntimeError(err_msg)

# Generate plots
logging.info("Plotting...")

# Determine the minimum and maximum SNR value we are dealing with
x_min = opts.sngl_snr_threshold
x_max = 1.1*plu.axis_max_value(trig_data[snr_type], inj_data[snr_type],
                               found_file)

# Determine y-axis minimum value and label
y_label = "%s Single %s" % (ifo, veto_labels[veto_type])
y_min = 1

# Determine the maximum veto value we are dealing with
y_max = plu.axis_max_value(trig_data[veto_type], inj_data[veto_type],
                           found_file)

# Determine contours for plots
conts, snr_vals, cont_value = calculate_contours(trig_data, opts,
                                                 veto_type, new_snrs=None)

# Initialise chisq contour values and colours
colors = contour_colors(opts, new_snrs=None)

# Default axes limits, used when no horizontal limits were requested
if not opts.x_lims:
    if zoom_in:
        default_x_lims = str(x_min)+',50'
        default_y_lims = str(y_min)+',20000'
    else:
        default_x_lims = str(x_min)+','+str(x_max)
        default_y_lims = str(y_min)+','+str(10*y_max)
    opts.x_lims = default_x_lims
    # Vertical limits explicitly requested by the user are not discarded
    if not getattr(opts, 'y_lims', None):
        opts.y_lims = default_y_lims

# Produce the veto vs. SNR plot
trigs = [trig_data[snr_type], trig_data[veto_type]]
injs = [inj_data[snr_type], inj_data[veto_type]]
plu.pygrb_plotter(trigs, injs, x_label, y_label, opts,
                  snr_vals=snr_vals, conts=conts, colors=colors,
                  shade_cont_value=cont_value, vert_spike=True,
                  cmd=' '.join(sys.argv))
