#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import collections
import operator
from matplotlib import pyplot as plt
from matplotlib import rc
import numpy
import scipy
import pycbc.version
from pycbc.detector import Detector
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_coh_ifosnr"


# =============================================================================
# Functions
# =============================================================================
# Function to load necessary SNR data from a trigger/injection file
def load_data(input_file, vetoes, ifos, opts, injections=False):
    """Build a dictionary containing SNR and sensitivity data extracted
    from a trigger/injection file.

    Parameters
    ----------
    input_file : str or None
        File handed to ``ppu.load_triggers`` (or ``ppu.load_injections``).
        If it evaluates to False nothing is loaded and the dictionary is
        returned with its initial (None) values.
    vetoes : object
        Forwarded unchanged to the ``ppu`` loader.
    ifos : iterable of str
        Interferometers used as keys of the per-IFO entries.  It is iterated
        more than once, so it must not be a one-shot iterator.
    opts : object
        Parsed command line options; only ``opts.ifo`` is read.
    injections : bool, optional
        If True, load injections instead of triggers.  In that case only
        the 'snr' and 'ifo_snr' entries are filled in.

    Returns
    -------
    dict
        Keys: 'snr' (coherent SNR array), 'ifo_snr' (single detector SNR of
        ``opts.ifo``), 'f_resp_mean' and 'sigma_mean' (dictionaries keyed
        by IFO), 'sigma_max' and 'sigma_min' (values for ``opts.ifo``).
    """

    # Initialize the dictionary
    data = {
        'snr': None,
        'ifo_snr': None,
        'f_resp_mean': dict.fromkeys(ifos),
        'sigma_mean': dict.fromkeys(ifos),
        'sigma_max': None,
        'sigma_min': None,
    }

    # Nothing to load: hand back the initialized dictionary
    if not input_file:
        return data

    if injections:
        trigs_or_injs = ppu.load_injections(input_file, vetoes)
    else:
        trigs_or_injs = ppu.load_triggers(input_file, vetoes)
        sigma = trigs_or_injs.get_sigmasqs()

        # Sky positions and times do not depend on the IFO: fetch them once
        ra = trigs_or_injs.get_column('ra')
        dec = trigs_or_injs.get_column('dec')
        end_times = numpy.asarray(trigs_or_injs.get_end())

        # Accumulate the antenna-response weighted sum of the sigmas
        sigma_tot = numpy.zeros(len(trigs_or_injs))
        for ifo in ifos:
            ifo_f_resp = ppu.get_antenna_responses(Detector(ifo), ra, dec,
                                                   end_times)
            data['f_resp_mean'][ifo] = ifo_f_resp.mean()
            sigma_tot += (sigma[ifo] * ifo_f_resp)

        # Statistics of each IFO's sigma normalised by the weighted sum
        for ifo in ifos:
            sigma_norm = sigma[ifo]/sigma_tot
            is_plotted_ifo = (ifo == opts.ifo)
            if numpy.size(sigma_norm) == 0:
                # No triggers: max()/min() would raise ValueError and mean()
                # would silently return NaN, so fall back to 0 for every IFO
                data['sigma_mean'][ifo] = 0
                if is_plotted_ifo:
                    data['sigma_max'] = 0
                    data['sigma_min'] = 0
                continue
            data['sigma_mean'][ifo] = sigma_norm.mean()
            if is_plotted_ifo:
                data['sigma_max'] = sigma_norm.max()
                data['sigma_min'] = sigma_norm.min()

    data['snr'] = numpy.asarray(trigs_or_injs.get_column('snr'))
    data['ifo_snr'] = trigs_or_injs.get_sngl_snr(opts.ifo)

    return data


# Plot lines representing deviations based on non-central chi-square
def plot_deviation(percentile, snr_grid, y, ax, style):
    """Plot deviations based on non-central chi-square.

    The curve is the square root of the `percentile` percent point of a
    non-central chi-square distribution with 2 degrees of freedom and
    non-centrality parameter `y`**2.

    Evaluating the percent point function directly over the whole grid was
    found to "saturate", so the curve is instead built by linear
    interpolation through the first half of the distinct values and
    extrapolated over the rest of the grid.

    Parameters
    ----------
    percentile : float
        Probability at which the percent point function is evaluated.
    snr_grid : numpy.ndarray
        Abscissae of the curve.
    y : numpy.ndarray
        Values whose squares are the non-centrality parameters; expected to
        hold one entry per point of `snr_grid`.
    ax : matplotlib axes
        Object whose ``plot`` method is used to draw the curve.
    style : str
        Format string forwarded to ``ax.plot``.
    """
    # Import the submodules explicitly: a bare `import scipy` is not
    # guaranteed to make `scipy.stats` and `scipy.interpolate` available
    from scipy import interpolate, stats

    # ncx2: non-central chi-squared; ppf: percent point function
    raw_vals = stats.ncx2.ppf(percentile, 2, y*y)**0.5

    # Repeated (saturated) values are collapsed; only the first half of the
    # distinct values is trusted as interpolation knots
    y_vals = numpy.unique(raw_vals)
    n_vals = len(y_vals) // 2

    if n_vals < 2:
        # Fewer than two knots cannot define a line (e.g. constant y):
        # plot the directly computed values instead of failing
        ax.plot(snr_grid, raw_vals, style)
        return

    x_vals = snr_grid[0:len(y_vals)]
    f = interpolate.interp1d(x_vals[0:n_vals], y_vals[0:n_vals],
                             bounds_error=False, fill_value="extrapolate")
    ax.plot(snr_grid, f(snr_grid), style)


# =============================================================================
# Main script starts here
# =============================================================================
# Command line interface: the common PyGRB plotting options plus the ones
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
ifo = opts.ifo
if ifo is None:
    parser.error("Please specify an interferometer")

# Default title and caption, used when none are given on the command line
if opts.plot_title is None:
    opts.plot_title = '%s SNR vs Coherent SNR' % ifo
if opts.plot_caption is None:
    caption = "Blue crosses: background triggers.  "
    # Injections are only mentioned when an injection file was provided
    if found_file:
        caption += "Red crosses: injections triggers.  "
    caption += ("Black line: veto line.  "
                "Gray shaded region: vetoed area - The cut is "
                "applied only to the two most sensitive detectors, "
                "which can vary with mass and sky location.  "
                "Green lines: the expected SNR for optimally "
                "oriented injections (mean, min, and max).  "
                "Magenta lines: 2 sigma errors.  "
                "Cyan lines: 3 sigma errors.")
    opts.plot_caption = caption

logging.info("Imported and ready to go.")

# Set output directories.  Creation is attempted directly and an already
# existing directory is tolerated, so that jobs running in parallel and
# sharing the same output directory do not race each other.
outdir = os.path.dirname(os.path.abspath(opts.output_file))
try:
    os.makedirs(outdir)
except OSError:
    if not os.path.isdir(outdir):
        raise

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# The per-IFO data are keyed by the IFOs found above: report an unknown
# --ifo here rather than failing later with a KeyError
if ifo not in ifos:
    parser.error("Interferometer %s is not among those of the trigger "
                 "file (%s)" % (ifo, ', '.join(str(i_ifo) for i_ifo in ifos)))

# Extract trigger data
trig_data = load_data(trig_file, vetoes, ifos, opts)

# Extract (or initialize) injection data
inj_data = load_data(found_file, vetoes, ifos, opts, injections=True)

# Generate plots
logging.info("Plotting...")

# Order the IFOs by sensitivity (mean antenna response times mean normalised
# sigma), most sensitive first
ifo_senstvty = {i_ifo: (trig_data['f_resp_mean'][i_ifo] *
                        trig_data['sigma_mean'][i_ifo])
                for i_ifo in ifos}
ifo_senstvty = collections.OrderedDict(sorted(ifo_senstvty.items(),
                                              key=operator.itemgetter(1),
                                              reverse=True))
loudness_labels = ['Loudest', 'Second loudest', 'Third loudest',
                   'Fourth loudest']

# Determine the maximum coherent SNR value we are dealing with; the guide
# lines are always drawn up to a coherent SNR of at least 50
x_max = plu.axis_max_value(trig_data['snr'], inj_data['snr'], found_file)
max_snr = max(x_max, 50.)

# Determine the maximum single IFO SNR value we are dealing with
y_max = plu.axis_max_value(trig_data['ifo_snr'], inj_data['ifo_snr'],
                           found_file)

# Setup the plots
x_label = "Coherent SNR"
y_label = "%s sngl SNR" % ifo
fig = plt.figure()
ax = fig.gca()
# Plot trigger data
ax.plot(trig_data['snr'], trig_data['ifo_snr'], 'bx')
ax.grid()
# Plot injection data
if found_file:
    ax.plot(inj_data['snr'], inj_data['ifo_snr'], 'r+')
# Sigma-mean, min, max
sigma_stats = {'mean': trig_data['sigma_mean'][ifo],
               'min': trig_data['sigma_min'],
               'max': trig_data['sigma_max']}
# Calculate: coherent SNR * sqrt(response * sigma-mean, min, max)
snr_grid = numpy.arange(0.01, max_snr, 1)
f_resp_mean = trig_data['f_resp_mean'][ifo]
y_data = {key: snr_grid*(f_resp_mean*sigma_stat)**0.5
          for key, sigma_stat in sigma_stats.items()}
for curve in y_data.values():
    ax.plot(snr_grid, curve, 'g-')
# 2 sigma (0.9545)
plot_deviation(0.02275, snr_grid, y_data['min'], ax, 'm-')
plot_deviation(1-0.02275, snr_grid, y_data['max'], ax, 'm-')
# 3 sigma (0.9973)
plot_deviation(0.00135, snr_grid, y_data['min'], ax, 'c-')
plot_deviation(1-0.00135, snr_grid, y_data['max'], ax, 'c-')
# Non-zoomed plot: the veto line is drawn at a fixed single IFO SNR
veto_snr = 4
ax.plot([0, max_snr], [veto_snr, veto_snr], 'k-')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim([0, 1.01*x_max])
ax.set_ylim([0, 1.20*y_max])
# Veto applies to the two most sensitive IFOs, so shade them
loudness_index = list(ifo_senstvty.keys()).index(ifo)
if loudness_index < 2:
    limy = ax.get_ylim()[0]
    # Rectangle between the veto line and the bottom of the plot
    ax.fill([0, max_snr, max_snr, 0],
            [veto_snr, veto_snr, limy, limy],
            color='#dddddd')
# Networks with more IFOs than predefined labels get a numbered label
# instead of raising an IndexError
if loudness_index < len(loudness_labels):
    loudness_label = loudness_labels[loudness_index]
else:
    loudness_label = '%dth loudest' % (loudness_index + 1)
opts.plot_title = opts.plot_title + " (%s SNR)" % loudness_label
# Zoom in if asked to do so
if opts.zoom_in:
    ax.set_xlim([6, 50])
    ax.set_ylim([0, 20])
# Save plot
save_fig_with_metadata(fig, opts.output_file, cmd=' '.join(sys.argv),
                       title=opts.plot_title, caption=opts.plot_caption)
plt.close()
