#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Produces signal consistency plots of the form standard/bank/auto chi-square vs coherent SNR.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy
from matplotlib import pyplot as plt
from matplotlib import rc
import pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_chisq_veto"


# =============================================================================
# Functions
# =============================================================================
# Turn single detector chi-square data into numpy arrays free of null entries
def format_single_chisqs(trig_ifo_cs, ifos):
    """Convert single IFO chi-square data to numpy arrays and replace zeros.

    Parameters
    ----------
    trig_ifo_cs : dict
        Maps IFO names to sequences of chi-square values.  The dictionary is
        updated in place.
    ifos : iterable
        Keys of trig_ifo_cs to process; any other entry is left untouched.

    Returns
    -------
    dict
        The very same trig_ifo_cs object.  Each processed entry is a numpy
        array in which the values exactly equal to 0 are set to 0.005.
    """

    for ifo in ifos:
        chisq = trig_ifo_cs[ifo] = numpy.asarray(trig_ifo_cs[ifo])
        # Only exact zeros are replaced: any other value is kept as it is
        chisq[chisq == 0] = 0.005

    return trig_ifo_cs


# Function to load the SNR and chi-squared data to be plotted from a
# trigger/injection file
def load_data(input_file, vetoes, ifos, opts, injections=False):
    """Build a dictionary with the SNR and chi-squared data extracted from a
    trigger/injection file.

    Parameters
    ----------
    input_file : str or None
        File to read.  If it evaluates to False nothing is read and all the
        entries of the returned dictionary are None.
    vetoes :
        Passed on to ppu.load_triggers / ppu.load_injections.
    ifos : list
        IFO names used to fetch and format single IFO chi-squares.
    opts :
        Parsed command line options.  This function reads use_sngl_ifo_snr,
        y_variable, use_sngl_ifo_veto, ifo and the options handed over to
        ppu.get_bestnrs.
    injections : bool, optional
        If True input_file is read with ppu.load_injections, otherwise with
        ppu.load_triggers.

    Returns
    -------
    dict
        Keys: 'snr' or 'snr_sngl' (x-axis data), opts.y_variable with an
        optional '_sngl' suffix (y-axis data), 'chisq_dof',
        'bank_chisq_dof', 'cont_chisq_dof' and, for the standard chi-square
        only, 'reweighted_snr'.
    """

    # Names under which the plotted quantities are stored
    snr_type = 'snr_sngl' if opts.use_sngl_ifo_snr else 'snr'
    veto_type = opts.y_variable
    if opts.use_sngl_ifo_veto:
        veto_type = veto_type + '_sngl'

    # Start from a dictionary of placeholders
    data = {snr_type: None,
            veto_type: None,
            'chisq_dof': None,
            'bank_chisq_dof': None,
            'cont_chisq_dof': None}

    # Without a file there is nothing to fill the dictionary with
    if not input_file:
        return data

    loader = ppu.load_injections if injections else ppu.load_triggers
    events = loader(input_file, vetoes)

    # SNR: single IFO or coherent
    if snr_type == 'snr_sngl':
        data[snr_type] = events.get_sngl_snr(opts.ifo)
    else:
        data[snr_type] = numpy.asarray(events.get_column('snr'))

    # Chi-squared: network column or single IFO values
    prefixes = {'standard': '', 'bank': 'bank_', 'auto': 'cont_'}
    if '_sngl' not in veto_type:
        column = prefixes[veto_type] + 'chisq'
        data[veto_type] = numpy.asarray(events.get_column(column))
    elif veto_type == 'standard_sngl':
        data[veto_type] = events.get_sngl_chisqs(ifos)

    # Single IFO chi-squared data: format it and keep the IFO of interest only
    if opts.use_sngl_ifo_veto:
        sngl_chisqs = format_single_chisqs(data[veto_type], ifos)
        data[veto_type] = sngl_chisqs[opts.ifo]

    # Standard chi-square: drop the points with BestNR = 0
    if 'standard' in veto_type:
        null_thresh = [float(val)
                       for val in opts.null_snr_threshold.split(',')]
        data['reweighted_snr'] = ppu.get_bestnrs(
            events,
            q=opts.chisq_index,
            n=opts.chisq_nhigh,
            null_thresh=null_thresh,
            snr_threshold=opts.snr_threshold,
            sngl_snr_threshold=opts.sngl_snr_threshold,
            chisq_threshold=opts.newsnr_threshold,
            null_grad_thresh=opts.null_grad_thresh,
            null_grad_val=opts.null_grad_val)
        keep = data['reweighted_snr'] != 0
        data[snr_type] = data[snr_type][keep]
        data[veto_type] = data[veto_type][keep]

    # Distinct values of the degrees of freedom of each chi-squared test
    for prefix in ('', 'bank_', 'cont_'):
        dof_key = prefix + 'chisq_dof'
        data[dof_key] = numpy.unique(events.get_column(dof_key))

    return data


# Function that produces the contours to be plotted
def calculate_contours(trig_data, opts, veto_type, new_snrs=None):
    """Compute the contours of constant new SNR for the veto plots.

    Parameters
    ----------
    trig_data : dict
        Must hold the '<prefix>chisq_dof' entry matching veto_type: its
        first element is used as the number of degrees of freedom.
    opts :
        Parsed command line options.  This function reads newsnr_threshold,
        chisq_index and chisq_nhigh.
    veto_type : {'standard', 'bank', 'auto'}
        Chi-squared test being plotted.  Any other value raises a KeyError.
    new_snrs : list, optional
        New SNR values to draw contours for, by default
        [5.5, 6, 6.5, 7, 8, 9, 10, 11].  When opts.newsnr_threshold is not
        among them, it is appended to this very list (in place).

    Returns
    -------
    contours : numpy.ndarray
        Array with one row per new SNR value and one column per entry of
        snr_vals, filled with the output of plu.new_snr_chisq.
    snr_vals : numpy.ndarray
        SNR grid: steps of 0.1 from 4 up to 30, then unit steps up to 500.
    cont_value : int
        Index of the new SNR threshold contour: its position in new_snrs if
        it was already listed, -1 (the last one) if it had to be appended.
    """

    # Degrees of freedom of the chi-squared test under consideration
    prefixes = {'standard': '', 'bank': 'bank_', 'auto': 'cont_'}
    dof_key = prefixes[veto_type] + 'chisq_dof'
    dof = trig_data[dof_key][0]

    # Make sure the new SNR threshold is among the contours and locate it
    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    threshold = opts.newsnr_threshold
    if threshold in new_snrs:
        cont_value = new_snrs.index(threshold)
    else:
        new_snrs.append(threshold)
        cont_value = -1

    # SNR grid: fine at low SNR, coarse at high SNR
    fine_grid = numpy.arange(4, 30, 0.1)
    coarse_grid = numpy.arange(30, 500, 1)
    snr_vals = numpy.concatenate((fine_grid, coarse_grid))

    # One row per contour, one column per SNR value
    contours = numpy.zeros((len(new_snrs), len(snr_vals)),
                           dtype=numpy.float64)
    for col, snr in enumerate(snr_vals):
        for row, new_snr in enumerate(new_snrs):
            contours[row, col] = plu.new_snr_chisq(snr, new_snr, dof,
                                                   opts.chisq_index,
                                                   opts.chisq_nhigh)

    return contours, snr_vals, cont_value


# Function that assigns a line style to each contour to be plotted
def contour_colors(opts, new_snrs=None):
    """Define the colours with which contours are plotted.

    Parameters
    ----------
    opts :
        Parsed command line options.  Only newsnr_threshold is read.
    new_snrs : list, optional
        New SNR values of the contours, by default
        [5.5, 6, 6.5, 7, 8, 9, 10, 11].  The list is not modified.

    Returns
    -------
    list of str
        One matplotlib format string per contour: "k-" for the new SNR
        threshold, "y-" for integer new SNR values and "y--" for the others.
        If the threshold is not among new_snrs its style is appended at the
        end, mirroring the contour that calculate_contours appends in that
        case, so that contours and colours always have the same length.
    """

    new_snr_thresh = opts.newsnr_threshold

    # Work on a copy so that the list of the caller is never altered
    if new_snrs is None:
        new_snrs = [5.5, 6, 6.5, 7, 8, 9, 10, 11]
    else:
        new_snrs = list(new_snrs)

    # Add the new SNR threshold contour to the list if necessary: without
    # this there would be one contour more than there are colours
    if new_snr_thresh not in new_snrs:
        new_snrs.append(new_snr_thresh)

    colors = ["k-" if snr == new_snr_thresh else
              "y-" if snr == int(snr) else
              "y--" for snr in new_snrs]

    return colors


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the parser shared by the PyGRB plotting scripts and add the
# options specific to this one
parser = ppu.pygrb_initialize_plot_parser(
    description=__doc__,
    version=__version__,
)
parser.add_argument(
    "-t", "--trig-file",
    action="store",
    default=None,
    required=True,
    help="The location of the trigger file",
)
parser.add_argument(
    "-z", "--zoom-in",
    action="store_true",
    default=False,
    help="Output file a zoomed in version of the plot.",
)
parser.add_argument(
    "-y", "--y-variable",
    choices=['standard', 'bank', 'auto'],
    required=True,
    help="Quantity to plot on the vertical axis.",
)
parser.add_argument(
    "--use-sngl-ifo-snr",
    action="store_true",
    default=False,
    help="Plots are vs single IFO "
         "SNR, rather than coherent SNR",
)
parser.add_argument(
    "--use-sngl-ifo-veto",
    action="store_true",
    default=False,
    help="Single IFO veto values "
         "plotted, rather than network ones",
)
# Options handed over to the BestNR calculation
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
# The file with found injections is optional
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
zoom_in = opts.zoom_in
veto_type = opts.y_variable
ifo = opts.ifo
snr_type = 'snr'
# If this is false, coherent SNR is used on the horizontal axis
# otherwise the single IFO SNR is used
use_sngl_ifo_snr = opts.use_sngl_ifo_snr
if use_sngl_ifo_snr:
    if ifo is None:
        err_msg = "--ifo must be given when using --use-sngl-ifo-snr"
        parser.error(err_msg)
    # Same key used by load_data to store the single IFO SNR
    snr_type += '_sngl'
# If this is false, network vetoes are used on the vertical axis
# otherwise the veto is intended as a single IFO quantity
use_sngl_ifo_veto = opts.use_sngl_ifo_veto
if use_sngl_ifo_veto:
    if ifo is None:
        # Quote the option exactly as it is defined in the parser
        err_msg = "--ifo must be given when using --use-sngl-ifo-veto"
        parser.error(err_msg)
    elif veto_type != 'standard':
        err_msg = "Single IFO values are available only for the standard chi-square"
        parser.error(err_msg)
    else:
        # Same key used by load_data to store the single IFO chi-square
        veto_type += '_sngl'

# Default plot title and caption, used only when not set by the user
veto_labels = {'standard': "Chi Square Veto",
               'bank': "Bank Veto",
               'auto': "Auto Veto"}
if opts.plot_title is None:
    # Name of the veto, preceded by the IFO for single IFO vetoes...
    opts.plot_title = veto_labels[veto_type.replace('_sngl', '')]
    if use_sngl_ifo_veto:
        opts.plot_title = "%s %s" % (ifo, opts.plot_title)
    # ...followed by the kind of SNR on the horizontal axis
    opts.plot_title += ((" vs %s SNR" % (ifo)) if use_sngl_ifo_snr
                        else " vs Coherent SNR")
if opts.plot_caption is None:
    # Injections are mentioned only if a file with found injections is given
    opts.plot_caption = "".join([
        "Red crosses: injections triggers.  " if found_file else "",
        "Blue crosses: background triggers.  ",
        "Black line: veto line.  ",
        "Gray shaded region: Vetoed area.  ",
        "Yellow lines: contours of new SNR.",
    ])

logging.info("Imported and ready to go.")

# Make sure the directory that will host the output file exists
outdir = os.path.dirname(os.path.abspath(opts.output_file))
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category)

# Stop with a clear message if the requested IFO is not available
if ifo and ifo not in ifos:
    raise RuntimeError("The IFO selected with --ifo is unavailable in the data.")

# Extract trigger data first, then extract (or initialize) injection data
trig_data = load_data(trig_file, vetoes, ifos, opts)
inj_data = load_data(found_file, vetoes, ifos, opts, injections=True)

# Sanity checks: each axis needs data from triggers or injections
for data_key, err_msg in (
        (snr_type, "No data to be plotted on the x-axis was found"),
        (veto_type, "No data to be plotted on the y-axis was found")):
    if trig_data[data_key] is None and inj_data[data_key] is None:
        raise RuntimeError(err_msg)

# Generate plots
logging.info("Plotting...")

# Determine x-axis values of triggers and injections
# Default is coherent SNR
x_label = "Coherent SNR"
# Case with single SNR
if use_sngl_ifo_snr:
    x_label = "%s SNR" % ifo

# Determine the minimum and maximum SNR value we are dealing with
x_min = opts.sngl_snr_threshold
x_max = 1.1*plu.axis_max_value(trig_data[snr_type], inj_data[snr_type], found_file)

# Determine y-axis minimum value and label
y_label = veto_labels[veto_type.replace('_sngl', '')]
y_min = 1
if opts.use_sngl_ifo_veto:
    y_label = "%s Single %s" % (ifo, y_label)
    # Lower than the 0.005 that format_single_chisqs assigns to null
    # single IFO chi-squares
    y_min = 0.001

# Determine the maximum veto value we are dealing with
y_max = plu.axis_max_value(trig_data[veto_type], inj_data[veto_type], found_file)

# Determine contours for plots
if opts.use_sngl_ifo_veto:
    # No new SNR contours for single IFO vetoes: calculate_contours only
    # knows the network veto types and raises a KeyError on 'standard_sngl'
    # NOTE(review): assumes plu.pygrb_plotter draws no contours and no
    # shading when conts and shade_cont_value are None -- confirm
    conts = None
    snr_vals = None
    cont_value = None
else:
    conts, snr_vals, cont_value = calculate_contours(trig_data, opts,
                                                     veto_type, new_snrs=None)

# Initialise chisq contour values and colours
colors = contour_colors(opts, new_snrs=None)

# Produce the veto vs. SNR plot
# Axes limits are set here only if the user did not specify --x-lims
if not opts.x_lims:
    if zoom_in:
        opts.x_lims = str(x_min)+',50'
        opts.y_lims = str(y_min)+',20000'
    else:
        opts.x_lims = str(x_min)+','+str(x_max)
        opts.y_lims = str(y_min)+','+str(10*y_max)
trigs = [trig_data[snr_type], trig_data[veto_type]]
injs = [inj_data[snr_type], inj_data[veto_type]]
plu.pygrb_plotter(trigs, injs, x_label, y_label, opts,
                  snr_vals=snr_vals, conts=conts, colors=colors,
                  shade_cont_value=cont_value, vert_spike=True,
                  cmd=' '.join(sys.argv))
