#!/usr/bin/env python
#
# Copyright (C) 2019 Gino Contestabile, Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot single IFO SNR vs coherent SNR for a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import collections
import operator
from matplotlib import pyplot as plt
from matplotlib import rc
import numpy
import scipy
import h5py
import pycbc.version
from pycbc.detector import Detector
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.results import pygrb_plotting_utils as plu

plt.switch_backend('Agg')
rc('font', size=14)

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_coh_ifosnr"


# =============================================================================
# Functions
# =============================================================================
# Function to load trigger data
def load_data(input_file, vetoes, ifos, opts, injections=False):
    """Load SNR and sensitivity data from a trigger/injection file.

    Parameters
    ----------
    input_file : str or None
        Location of the trigger (or found injection) file.  When it is
        falsy nothing is read and the returned dictionary only contains
        its ``None`` placeholders.
    vetoes : object
        Veto information, forwarded as is to ``ppu.load_triggers``.
    ifos : list of str
        Interferometers for which single detector data is collected.
    opts : object
        Parsed command line options.  Only ``opts.ifo`` is read: it selects
        the interferometer whose normalised sigma extrema are stored.
    injections : bool, optional
        Whether ``input_file`` holds injections.  For the time being this
        only changes the log message, as injections are read with
        ``ppu.load_triggers`` too.

    Returns
    -------
    dict
        ``'coherent'``: coherent SNRs; ``'single'``: per-IFO single
        detector SNRs; ``'f_resp_mean'``: per-IFO mean of the antenna
        response; ``'sigma_mean'``: per-IFO mean of sigma normalised by the
        response-weighted sum over all IFOs; ``'sigma_max'`` and
        ``'sigma_min'``: extrema of that normalised sigma for ``opts.ifo``.
    """

    # Initialize the dictionary
    data = {
        'coherent': None,
        'single': dict((ifo, None) for ifo in ifos),
        'f_resp_mean': dict((ifo, None) for ifo in ifos),
        'sigma_mean': dict((ifo, None) for ifo in ifos),
        'sigma_max': None,
        'sigma_min': None,
    }

    # Nothing to read: hand back the placeholders
    if not input_file:
        return data

    if injections:
        logging.info("Loading injections...")
    else:
        logging.info("Loading triggers...")
    # Injections will eventually be read by a dedicated load_injections
    trigs = ppu.load_triggers(input_file, vetoes)

    time = trigs['network/end_time_gc'][:]
    num_trigs = len(time)

    # Load SNR data
    data['coherent'] = trigs['network/coherent_snr'][:]

    # Get single ifo SNR data
    for ifo in ifos:
        att = ifo[0].lower()
        data['single'][ifo] = trigs['%s/snr_%s' % (ifo, att)][:]

    # Get sigma for each ifo, read once per ifo as an array so that the
    # arithmetic below does not rely on implicit list conversions
    sigma = dict((ifo, numpy.asarray(trigs['%s/sigmasq' % ifo][:]))
                 for ifo in ifos)

    # Create array for sigma_tot
    sigma_tot = numpy.zeros(num_trigs)

    # Get parameters necessary for antenna responses
    longitude = numpy.degrees(trigs['network/longitude'])
    latitude = numpy.degrees(trigs['network/latitude'])

    # Get antenna response based parameters
    for ifo in ifos:
        antenna = Detector(ifo)
        ifo_f_resp = ppu.get_antenna_responses(antenna, longitude,
                                               latitude, time)

        # Get the average for f_resp_mean and calculate sigma_tot
        data['f_resp_mean'][ifo] = ifo_f_resp.mean()
        sigma_tot += sigma[ifo] * ifo_f_resp

    # Normalise sigma and store its statistics
    for ifo in ifos:
        sigma_norm = sigma[ifo] / sigma_tot
        if sigma_norm.size:
            sigma_mean = sigma_norm.mean()
            sigma_max = sigma_norm.max()
            sigma_min = sigma_norm.min()
        else:
            # No triggers: min/max are undefined, so fall back to zero
            # for every ifo rather than only for the requested one
            sigma_mean = sigma_max = sigma_min = 0
        data['sigma_mean'][ifo] = sigma_mean
        if ifo == opts.ifo:
            data['sigma_max'] = sigma_max
            data['sigma_min'] = sigma_min

    logging.info("%d triggers found.", num_trigs)

    return data

# Plot lines representing deviations based on non-central chi-square
def plot_deviation(percentile, snr_grid, y, ax, style):
    """Plot a deviation curve based on the non-central chi-square.

    Parameters
    ----------
    percentile : float
        Probability at which the percent point function (ppf) of the
        non-central chi-square distribution is evaluated.
    snr_grid : numpy.ndarray
        Abscissa values of the curve.
    y : numpy.ndarray
        Values on ``snr_grid`` whose square is used as the non-centrality
        parameter.
    ax : object
        Target of the plot: ``ax.plot(x, y, style)`` is called exactly once.
    style : str
        Format string handed to ``ax.plot``.
    """

    # Load the submodules explicitly: a bare "import scipy" does not
    # guarantee that scipy.stats and scipy.interpolate are available
    from scipy import interpolate, stats

    # ncx2: non-central chi-squared with 2 degrees of freedom;
    # ppf: percent point function
    raw_vals = stats.ncx2.ppf(percentile, 2, y*y)**0.5

    # Plotting raw_vals directly may show a "saturation": work around it
    # by interpolating over the first half of the distinct (sorted) values
    # and extrapolating to the rest of the grid
    y_vals = numpy.unique(raw_vals)
    x_vals = snr_grid[0:len(y_vals)]
    n_vals = int(len(y_vals)/2)

    # interp1d needs at least two points: with fewer than that, plot the
    # ppf values as they are rather than failing
    if n_vals < 2:
        ax.plot(snr_grid, raw_vals, style)
        return

    f = interpolate.interp1d(x_vals[0:n_vals], y_vals[0:n_vals],
                             bounds_error=False, fill_value="extrapolate")
    ax.plot(snr_grid, f(snr_grid), style)


# =============================================================================
# Main script starts here
# =============================================================================
# Build the command line parser: the options shared by the PyGRB plotting
# scripts come from the utility module, the two below are added here
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument("-t", "--trig-file", action="store",
                    default=None, required=True,
                    help="The location of the trigger file")
parser.add_argument("-z", "--zoom-in", default=False, action="store_true",
                    help="Output file a zoomed in version of the plot.")
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
trig_file = os.path.abspath(opts.trig_file)
found_file = os.path.abspath(opts.found_file) if opts.found_file else None
ifo = opts.ifo
if ifo is None:
    parser.error("Please specify an interferometer")

if opts.plot_title is None:
    opts.plot_title = '%s SNR vs Coherent SNR' % ifo
if opts.plot_caption is None:
    opts.plot_caption = "Blue crosses: background triggers.  "
    if found_file:
        opts.plot_caption += "Red crosses: injections triggers.  "
    opts.plot_caption += (
        "Black line: veto line.  "
        "Gray shaded region: vetoed area - The cut is "
        "applied only to the two most sensitive detectors, "
        "which can vary with mass and sky location.  "
        "Green lines: the expected SNR for optimally "
        "oriented injections (mean, min, and max).  "
        "Magenta lines: 2 sigma errors.  "
        "Cyan lines: 3 sigma errors."
    )

logging.info("Imported and ready to go.")

# Set output directories
outdir = os.path.split(os.path.abspath(opts.output_file))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Fail with a clear message, rather than with a lookup error further down,
# if the requested interferometer is not one of the available ones
if ifo not in ifos:
    parser.error("Interferometer %s is not among the available ones (%s)"
                 % (ifo, ", ".join(ifos)))

# Extract trigger data
trig_data = load_data(trig_file, vetoes, ifos, opts)

# Extract (or initialize) injection data
inj_data = load_data(found_file, vetoes, ifos, opts, injections=True)

# Generate plots
logging.info("Plotting...")

# Order the IFOs by sensitivity (mean response times mean normalised
# sigma), most sensitive first
ifo_senstvty = {}
for i_ifo in ifos:
    ifo_senstvty[i_ifo] = (trig_data['f_resp_mean'][i_ifo] *
                           trig_data['sigma_mean'][i_ifo])
ifo_senstvty = collections.OrderedDict(sorted(ifo_senstvty.items(),
                                              key=operator.itemgetter(1),
                                              reverse=True))
loudness_labels = ['Loudest', 'Second loudest', 'Third loudest']

# Determine the maximum coherent SNR value we are dealing with
x_max = plu.axis_max_value(trig_data['coherent'], inj_data['coherent'],
                           found_file)
# The curves and the veto line always extend to a coherent SNR of at
# least 50
max_snr = x_max
if x_max < 50.:
    max_snr = 50.

# Determine the maximum single IFO SNR value we are dealing with
y_max = plu.axis_max_value(trig_data['single'][ifo],
                           inj_data['single'][ifo], found_file)

# Setup the plots
x_label = "Coherent SNR"
y_label = "%s sngl SNR" % ifo
fig = plt.figure()
ax = fig.gca()
# Plot trigger data
ax.plot(trig_data['coherent'], trig_data['single'][ifo], 'bx')
ax.grid()
# Plot injection data
if found_file:
    ax.plot(inj_data['coherent'], inj_data['single'][ifo], 'r+')
# Sigma-mean, min, max
sigma_stats = {'mean': trig_data['sigma_mean'][ifo],
               'min': trig_data['sigma_min'],
               'max': trig_data['sigma_max']}
# Calculate: snr_grid * sqrt(response * sigma-mean, min, max) on a grid
# with unit steps in coherent SNR
snr_grid = numpy.arange(0.01, max_snr, 1)
f_resp_mean = trig_data['f_resp_mean'][ifo]
y_data = dict((key, snr_grid * (f_resp_mean * sigma_stats[key])**0.5)
              for key in sigma_stats)
for key in y_data:
    ax.plot(snr_grid, y_data[key], 'g-')
# 2 sigma (0.9545)
plot_deviation(0.02275, snr_grid, y_data['min'], ax, 'm-')
plot_deviation(1-0.02275, snr_grid, y_data['max'], ax, 'm-')
# 3 sigma (0.9973)
plot_deviation(0.00135, snr_grid, y_data['min'], ax, 'c-')
plot_deviation(1-0.00135, snr_grid, y_data['max'], ax, 'c-')
# Non-zoomed plot: draw the veto line at a single IFO SNR of 4
veto_snr = 4
ax.plot([0, max_snr], [veto_snr, veto_snr], 'k-')
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim([0, 1.01*x_max])
ax.set_ylim([0, 1.20*y_max])
# Veto applies to the two most sensitive IFOs, so shade them
loudness_index = list(ifo_senstvty.keys()).index(ifo)
if loudness_index < 2:
    limy = ax.get_ylim()[0]
    polyx = [0, max_snr, max_snr, 0]
    polyy = [veto_snr, veto_snr, limy, limy]
    ax.fill(polyx, polyy, color='#dddddd')
# Networks may have more IFOs than there are spelled out labels: fall back
# to a numbered label instead of running off the end of the list
if loudness_index < len(loudness_labels):
    loudness_label = loudness_labels[loudness_index]
else:
    loudness_label = '%dth loudest' % (loudness_index + 1)
opts.plot_title = opts.plot_title + " (%s SNR)" % loudness_label
# Zoom in if asked to do so
if opts.zoom_in:
    ax.set_xlim([6, 50])
    ax.set_ylim([0, 20])
# Save plot
save_fig_with_metadata(fig, opts.output_file, cmd=' '.join(sys.argv),
                       title=opts.plot_title, caption=opts.plot_caption)
plt.close()
