#!/usr/bin/env python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot distribution of BestNR/SNR/SNR after cuts of triggers in a PyGRB run.
"""

# =============================================================================
# Preamble
# =============================================================================
import os
import logging
import sys
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import pycbc.version
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
try:
    from glue.ligolw import lsctables
except ImportError:
    pass

plt.switch_backend('Agg')
rc("image", cmap="cividis_r")

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_stats_distribution"


# =============================================================================
# Main script starts here
# =============================================================================
# Build the base parser through the ppu helper, then register the options
# that are specific to this executable.
parser = ppu.pygrb_initialize_plot_parser(
    description=__doc__,
    version=__version__,
)
parser.add_argument(
    "-F", "--offsource-file",
    action="store",
    required=True,
    default=None,
    help="Location of off-source trigger file",
)
parser.add_argument(
    "-x", "--x-variable",
    default=None,
    required=True,
    choices=["bestnr", "snr", "snruncut"],
    help="Quantity to plot on the horizontal axis.",
)
# NOTE: an integer "-S"/"--seed" option (default 1234, "Seed to initialize
# Monte Carlo.") used to be declared here and is currently disabled.
# Let the ppu helper add its BestNR-related options before parsing.
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options: when the NewSNR threshold is falsy (not provided, or zero,
# since the test is on truthiness) fall back to the SNR threshold.
if not opts.newsnr_threshold:
    opts.newsnr_threshold = opts.snr_threshold

# Store options used multiple times in local variables
trig_file = opts.offsource_file
chisq_nhigh = opts.chisq_nhigh
stat = opts.x_variable

# Set output directory: the directory that will contain the output file
logging.info("Setting output directory.")
outdir = os.path.dirname(os.path.abspath(opts.output_file))
# exist_ok=True makes this safe when another process creates the directory
# between a separate existence check and the creation call; an existing
# non-directory at this path still raises, as before.
os.makedirs(outdir, exist_ok=True)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category
)

# Load triggers, time-slides, and segment dictionary, all read from the
# off-source trigger file.
# NOTE: lsctables comes from a guarded import at the top of the file; if that
# import failed, the table-name lookup below raises a NameError.
logging.info("Loading triggers.")
multi_inspiral_table_name = lsctables.MultiInspiralTable.tableName
trigs = ppu.load_xml_table(trig_file, multi_inspiral_table_name)
logging.info("%d triggers loaded.", len(trigs))

logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)

logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials and count them over all slide IDs
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(
    opts.seg_files, segment_dict, ifos, slide_dict, vetoes
)
total_trials = sum(len(trial_dict[slide_id]) for slide_id in slide_dict)
logging.info("%d trials generated.", total_trials)

# Extract basic trigger properties and store as dictionaries; they are
# indexed by slide ID in the code that follows.
trig_time, trig_snr, trig_bestnr = ppu.extract_basic_trig_properties(
    trial_dict, trigs, slide_dict, segment_dict, opts
)

# Calculate SNR and BestNR values and maxima.
# For every slide ID each dictionary holds one array with one entry per trial,
# initialised to zero; a trial without any qualifying trigger keeps that zero.
time_veto_max_snr = {
    slide_id: np.zeros(len(trial_dict[slide_id])) for slide_id in slide_dict
}
time_veto_max_bestnr = {
    slide_id: np.zeros(len(trial_dict[slide_id])) for slide_id in slide_dict
}
time_veto_max_snr_uncut = {
    slide_id: np.zeros(len(trial_dict[slide_id])) for slide_id in slide_dict
}

for slide_id in slide_dict:
    for j, trial in enumerate(trial_dict[slide_id]):
        # Select the triggers whose time lies in [trial[0], trial[1])
        slide_times = trig_time[slide_id]
        in_trial = (trial[0] <= slide_times) & (slide_times < trial[1])
        if in_trial.any():
            slide_snrs = trig_snr[slide_id]
            slide_bestnrs = trig_bestnr[slide_id]
            # Max SNR
            time_veto_max_snr[slide_id][j] = max(slide_snrs[in_trial])
            # Max BestNR
            time_veto_max_bestnr[slide_id][j] = max(slide_bestnrs[in_trial])
            # Max SNR for triggers passing SBVs, identified here as the
            # triggers with a non-zero BestNR
            in_trial_passing_sbv = in_trial & (slide_bestnrs != 0)
            if in_trial_passing_sbv.any():
                time_veto_max_snr_uncut[slide_id][j] = \
                    max(slide_snrs[in_trial_passing_sbv])

# This is the data that will be plotted: the per-slide maxima are combined
# by the ppu helpers (only the third value returned by max_median_stat is
# used here).
full_time_veto_max_snr = ppu.sort_stat(time_veto_max_snr)
full_time_veto_max_snr_uncut = ppu.sort_stat(time_veto_max_snr_uncut)
_, _, full_time_veto_max_bestnr = ppu.max_median_stat(
    slide_dict, time_veto_max_bestnr, trig_bestnr, total_trials
)

# If the first BestNR entry is not zero, i.e., no trial returned a no-event
# (BestNR = 0), prepend a 0. to all three arrays: this forces the histograms
# to start at (0, 1).
if full_time_veto_max_bestnr[0] != 0.:
    (full_time_veto_max_snr,
     full_time_veto_max_snr_uncut,
     full_time_veto_max_bestnr) = [
         np.concatenate(([0.], maxima))
         for maxima in (full_time_veto_max_snr,
                        full_time_veto_max_snr_uncut,
                        full_time_veto_max_bestnr)
     ]

logging.info("SNR and bestNR maxima calculated.")


# =========
# Make plot
# =========
# Lookup tables keyed by the allowed values of --x-variable: the horizontal
# axis label and the data to plot.
x_label_dict = dict(
    bestnr="BestNR",
    snr="SNR",
    snruncut="SNR after signal based vetoes",
)
data_dict = dict(
    bestnr=full_time_veto_max_bestnr,
    snr=full_time_veto_max_snr,
    snruncut=full_time_veto_max_snr_uncut,
)
# Single-axes figure with labelled axes, a grid and a logarithmic vertical
# axis
fig = plt.figure()
ax = fig.gca()
ax.set_xlabel(x_label_dict[stat])
ax.set_ylabel("False alarm probability")
ax.grid(True)
ax.set_yscale("log")
# Some standard plot settings
ymax = 1.2
normalization = 1.
epsilon = 1e-8
num_bins = 50
# Figure out binning: num_bins equal-width bins spanning the data, plus one
# extra bin from the largest value to infinity used only for histogramming
min_stat = min(data_dict[stat])
max_stat = max(data_dict[stat])
bins = np.linspace(min_stat, max_stat, num_bins + 1, endpoint=True)
bins_bg = np.append(bins, float('Inf'))
dx = bins[1] - bins[0]
# Left edge of every histogram bin. These are the edges np.histogram returns
# for bins_bg without the trailing infinity, so they do not depend on the
# entry being histogrammed and are set once instead of inside the loop.
x = bins
# Histogram each background instance separately and take stats
hist_sum = np.zeros(len(bins), dtype=float)
sq_hist_sum = np.zeros(len(bins), dtype=float)
# Make histogram
for instance in data_dict[stat]:
    y, _ = np.histogram(instance, bins=bins_bg)
    # Reverse cumulative sum: counts at or above each left bin edge
    y = y[::-1].cumsum()[::-1]
    hist_sum += y
    sq_hist_sum += y*y
# Get statistics
N = len(data_dict[stat])
means = hist_sum / N
# Sample variance across instances. With a single instance N - 1 is zero and
# the numerator vanishes, so divide by 1 to obtain a zero spread rather than
# 0/0 = NaN. Clip at zero so that rounding can never hand a negative number
# to the square root.
variances = (sq_hist_sum - hist_sum*means) / max(N - 1, 1)
stds = np.sqrt(np.maximum(variances, 0.))
# Floor at epsilon: the values are drawn on a logarithmic axis
means[means <= epsilon] = epsilon
upper = means + stds
lower = means - stds
lower[lower <= epsilon] = epsilon
# Normalize
means = means*normalization
upper = upper*normalization
lower = lower*normalization
# Plot mean, shifting the horizontal positions by half a bin width
shifted_x = x + dx/2
ax.plot(shifted_x, means, 'r+', markersize=12)
# For aesthetic reasons: add one more point, half a bin further to the right
# and repeating the last values, to ensure that the width of the bar on the
# far right is the same as all the others
band_x = np.append(shifted_x, shifted_x[-1] + dx/2)
band_upper = np.append(upper, upper[-1])
band_lower = np.append(lower, lower[-1])
# Shade in the 1-standard deviation area
ax.fill_between(band_x, band_lower, band_upper,
                alpha=0.3, facecolor='y', step='mid')
# Wrap up: set the axis ranges, then save the figure with its metadata
x_limits = (0.9 * min_stat, 1.1 * max_stat)
y_limits = (0.6/N, ymax)
ax.set_xlim(x_limits)
ax.set_ylim(y_limits)
plot_title = "Cumulative distribution of background triggers"
plot_caption = "Background cumulative distribution of the %s " % (
    x_label_dict[stat],)
fig_path = opts.output_file
command_line = ' '.join(sys.argv)
save_fig_with_metadata(fig, fig_path, cmd=command_line,
                       title=plot_title, caption=plot_caption)
plt.close()

logging.info("Plots complete.")
