#!/usr/bin/env python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot found/missed injection properties for the triggered search (PyGRB).
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import pycbc.version
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
try:
    from glue.ligolw import lsctables
except ImportError:
    pass

# Select the Agg backend (renders to files; no display is needed)
plt.switch_backend('Agg')
# NOTE(review): rc() is given a group name but no keyword settings, so this
# call appears not to change any rcParams -- confirm whether an option
# (e.g. a default colormap) was meant to be set here.
rc("image")

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_injs_results"


# =============================================================================
# Functions
# =============================================================================
def process_var_strings(qty):
    """Add underscores to match xml column name conventions.

    The command line accepts quantity names with or without underscores
    (e.g. both ``cosabsincl`` and ``cos_abs_incl``).  This function maps
    either spelling onto the underscored one used as dictionary key in the
    rest of the script.

    Parameters
    ----------
    qty : str
        Name of the quantity as typed by the user.

    Returns
    -------
    str
        The name with the underscores inserted.  Names that already contain
        the underscores are returned unchanged (the conversion is
        idempotent), so that ``cos_incl`` does not become ``cos__incl``.
    """

    # (compact spelling, underscored spelling), applied in this order
    replacements = (('effdist', 'eff_dist'),
                    ('effsitedist', 'eff_site_dist'),
                    ('skyerror', 'sky_error'),
                    ('cos', 'cos_'),
                    ('abs', 'abs_'))

    for compact, underscored in replacements:
        # Only insert the underscore when it is not there already: a blind
        # replace would turn 'cos_incl' into 'cos__incl'
        if underscored not in qty:
            qty = qty.replace(compact, underscored)

    return qty


# Inclination loader: incl/abs_incl in degrees, or the cosine of either
def load_incl_data(raw_data, qty):
    """Return inclination-related data read from a trigger/injection table.

    raw_data must provide get_column('inclination'); the values are treated
    as radians by the conversions below.  qty is one of 'incl', 'abs_incl',
    'cos_incl' or 'cos_abs_incl'.  Angles are returned in degrees, cosines
    are returned as they are.
    """

    # The plain inclination is the starting point for every request
    angles = {'incl': np.asarray(raw_data.get_column('inclination'))}

    # |incl| is only built on demand: the inclination folded around pi/2
    if 'abs_' in qty:
        angles['abs_incl'] = 0.5*np.pi - abs(angles['incl'] - 0.5*np.pi)

    # No cosine requested: hand back the angle itself, in degrees
    if 'cos_' not in qty:
        return np.rad2deg(angles[qty])

    # Cosine requested: strip the prefix to find which angle to use
    return np.cos(angles[qty.replace('cos_', '')])


# Spin loader: magnitude of the spin vector named by qty (spin1 or spin2)
def load_spin_data(raw_data, qty):
    """Return the spin magnitude computed from raw trigger/injection data.

    The three columns qty+'x', qty+'y' and qty+'z' are read from raw_data
    with get_column and combined into the Euclidean norm, entry by entry.
    """

    # Read the Cartesian components in x, y, z order
    spin_x, spin_y, spin_z = [np.asarray(raw_data.get_column(qty + axis))
                              for axis in ('x', 'y', 'z')]

    # Modulus of the spin vector
    return np.sqrt(spin_x**2 + spin_y**2 + spin_z**2)


# Mass loader: total mass or mass ratio built from the component masses
def load_mtot_q_data(raw_data, qty):
    """Return the total mass or the mass ratio from raw trigger/injection
    data.

    The 'mass1' and 'mass2' columns are read with get_column.  For
    qty == 'mtotal' their sum is returned; for any other qty the ratio
    mass2/mass1 is returned, inverted wherever it exceeds 1 so that the
    result is never larger than 1.
    """

    mass1 = np.asarray(raw_data.get_column('mass1'))
    mass2 = np.asarray(raw_data.get_column('mass2'))

    if qty == 'mtotal':
        return mass1 + mass2

    # Mass ratio: invert the entries where mass2 is the heavier component
    ratio = mass2/mass1
    return np.where(ratio > 1, 1./ratio, ratio)


# Function to extract effective distance data from a trigger/injection file
def load_effdist_data(raw_data, qty, opts, sites):
    """Extract data related to effective distances from raw
    trigger/injection data.

    Parameters
    ----------
    raw_data : table-like
        Object providing get_column(name) for the 'eff_dist_<site>' columns.
    qty : str
        'eff_site_dist' for the effective distance at a single site; any
        other value selects the combined effective distance.
    opts : object
        Parsed options; only opts.ifo is read (and only for
        'eff_site_dist'): its first character, lower-cased, picks the site.
    sites : iterable of str
        Site letters used to build the combined effective distance.

    Returns
    -------
    numpy.ndarray
        The requested effective distance, one entry per table row.
    """

    if qty == 'eff_site_dist':
        # Always return an array, like the other loaders do: the callers
        # rely on .size and on boolean-mask indexing
        column = 'eff_dist_%s' % opts.ifo[0].lower()
        return np.asarray(raw_data.get_column(column))

    # One column per site; keying by site means a repeated site letter is
    # counted only once
    site_dists = dict((ifo, raw_data.get_column('eff_dist_%s' % ifo.lower()))
                      for ifo in sites)

    # Effective distance (inverse sum of inverse effective distances).
    # Forcing a float dtype keeps np.power from rejecting integer columns
    # raised to a negative power.
    stacked = np.asarray(list(site_dists.values()), dtype=float)
    return np.power(np.power(stacked, -1).sum(0), -1)


# Gather the quantities requested for the plot from a trigger/injection file
def load_data(raw_data, keys, opts, sites):
    """Build a dictionary holding the data for each entry of keys, read
    from a trigger/injection file.

    Quantities that are not recognised by any branch below (e.g. sky_error)
    are skipped silently and get no entry in the returned dictionary.
    """

    # Quantities stored verbatim in a column: requested name -> column name
    direct_columns = {'mchirp': 'mchirp',
                      'mass1': 'mass1',
                      'mass2': 'mass2',
                      'distance': 'distance',
                      'spin1x': 'spin1x',
                      'spin1y': 'spin1y',
                      'spin1z': 'spin1z',
                      'spin2x': 'spin2x',
                      'spin2y': 'spin2y',
                      'spin2z': 'spin2z',
                      'ra': 'longitude',
                      'dec': 'latitude'}

    loaded = {}

    for qty in keys:
        if qty in direct_columns:
            column = raw_data.get_column(direct_columns[qty])
            loaded[qty] = np.asarray(column)
        elif qty == 'time':
            # Geocentric end time (seconds plus nanoseconds), measured
            # from the GRB time
            seconds = np.asarray(raw_data.get_column('geocent_end_time'))
            nanoseconds = raw_data.get_column('geocent_end_time_ns')
            end_times = seconds + np.asarray(nanoseconds*10**-9)
            end_times -= ppu.get_grb_time(opts.seg_files)
            loaded[qty] = end_times
        elif qty in ('mtotal', 'q'):
            loaded[qty] = load_mtot_q_data(raw_data, qty)
        elif qty in ('eff_site_dist', 'eff_dist'):
            loaded[qty] = load_effdist_data(raw_data, qty, opts, sites)
        elif 'incl' in qty:
            loaded[qty] = load_incl_data(raw_data, qty)
        # Only the spin magnitudes (spin1, spin2) get here: the spin
        # components are direct columns, caught by the first branch
        elif 'spin' in qty:
            loaded[qty] = load_spin_data(raw_data, qty)

    return loaded


# Apply fap_mask to every array of full_data and attach the matching
# BestNR values of the found triggers.
def grab_triggers_subset(full_data, fap_mask, found_trigs):
    """Return the entries of full_data selected by fap_mask, together with
    the detection statistic values selected by the same mask.

    The result has one key per key of full_data plus 'bestnr', which is
    always taken from found_trigs['bestnr'].
    """

    data_subset = {qty: values[fap_mask]
                   for qty, values in full_data.items()}
    data_subset['bestnr'] = found_trigs['bestnr'][fap_mask]

    return data_subset


# =============================================================================
# Main script starts here
# =============================================================================
# Command line interface, built on top of the common PyGRB plot parser
parser = ppu.pygrb_initialize_plot_parser(version=__version__,
                                          description=__doc__)
parser.add_argument("-F", "--offsource-file", required=True, default=None,
                    action="store",
                    help="Location of off-source trigger file")

# Quantities that can be put on either axis: the underscored names come
# first, followed by the spellings without underscores
_canonical_vars = ['mchirp', 'mtotal', 'q', 'distance',
                   'eff_dist', 'eff_site_dist', 'time',
                   'sky_error', 'ra', 'dec', 'incl',
                   'cos_incl', 'abs_incl', 'cos_abs_incl']
# spin1, spin1x, spin1y, spin1z, spin2, spin2x, spin2y, spin2z
_spin_vars = ['spin%d%s' % (body, axis)
              for body in (1, 2) for axis in ('', 'x', 'y', 'z')]
_alias_vars = ['effdist', 'effsitedist', 'skyerror',
               'cosincl', 'absincl', 'cosabsincl']
admitted_vars = _canonical_vars + _spin_vars + _alias_vars

_underscore_note = "(Underscores may be omitted in specifying this option)."
parser.add_argument("-x", "--x-variable", required=True, default=None,
                    choices=admitted_vars,
                    help="Quantity to plot on the horizontal axis. " +
                    _underscore_note)
parser.add_argument("-y", "--y-variable", required=True, default=None,
                    choices=admitted_vars,
                    help="Quantity to plot on the vertical axis. " +
                    _underscore_note)
parser.add_argument("--x-log", help="Use log horizontal axis",
                    action="store_true")
parser.add_argument("--y-log", help="Use log vertical axis",
                    action="store_true")

# Exactly one of the two plotting orders must be chosen
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--found-missed", action="store_true",
                   help="Plot found injections over missed injections.")
group.add_argument("--missed-found", action="store_false",
                   help="Plot missed injections over found injections.")

# Options shared with the other PyGRB post-processing scripts
ppu.pygrb_add_missed_injs_input_opt(parser)
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
# Both injection files are needed: refuse to go on if either is missing
if (opts.found_file is None) or (opts.missed_file is None):
    err_msg = "Must provide both found and missed injections files."
    raise RuntimeError(err_msg)

# Names of the quantities to plot, in their underscored form
x_qty = process_var_strings(opts.x_variable)
y_qty = process_var_strings(opts.y_variable)

# The site specific effective distance needs to know which site to use
if 'eff_site_dist' in [x_qty, y_qty] and opts.ifo is None:
    err_msg = "A value for --ifo must be provided for "
    err_msg += "site specific effective distance"
    parser.error(err_msg)
# Fall back on the SNR threshold when no (truthy) newSNR threshold is set
if not opts.newsnr_threshold:
    opts.newsnr_threshold = opts.snr_threshold

# Store options used multiple times in local variables
outfile = opts.output_file
# opts.found_missed is True only when --found-missed was given:
# --missed-found is a store_false flag with its own destination, so it
# leaves opts.found_missed at its default of False
fm_or_mf = "found_missed" if opts.found_missed else "missed_found"
trig_file = opts.offsource_file
found_file = opts.found_file
missed_file = opts.missed_file
chisq_index = opts.chisq_index
chisq_nhigh = opts.chisq_nhigh
snr_thresh = opts.snr_threshold
sngl_snr_thresh = opts.sngl_snr_threshold
new_snr_thresh = opts.newsnr_threshold
null_grad_thresh = opts.null_grad_thresh
null_grad_val = opts.null_grad_val
# The null SNR thresholds arrive as one comma-separated string
null_thresh = list(map(float, opts.null_snr_threshold.split(',')))
grb_time = ppu.get_grb_time(opts.seg_files)

# Set output directory
logging.info("Setting output directory.")
# Directory that will contain the output file; it is created if missing
outdir = os.path.split(os.path.abspath(outfile))[0]
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Load triggers, time-slides, and segment dictionary
# NOTE(review): lsctables is only defined if the guarded glue.ligolw import
# at the top of the file succeeded; otherwise the next statement raises a
# NameError
logging.info("Loading triggers.")
trigs = ppu.load_xml_table(trig_file, lsctables.MultiInspiralTable.tableName)
logging.info("%d triggers loaded.", len(trigs))
logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)
logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(opts.seg_files, segment_dict,
                                  ifos, slide_dict, vetoes)
# Total number of trials, summed over all the time slides
total_trials = sum([len(trial_dict[slide_id]) for slide_id in slide_dict])
logging.info("%d trials generated.", total_trials)

# Extract basic trigger properties and store as dictionaries
# (both results are indexed by slide_id below; the second returned value is
# not needed here)
trig_time, _, trig_bestnr = \
    ppu.extract_basic_trig_properties(trial_dict, trigs, slide_dict, segment_dict, opts)

# Calculate SNR and BestNR values and maxima
# One entry per trial of each slide, initialised to zero
time_veto_max_bestnr = {}
for slide_id in slide_dict:
    num_slide_segs = len(trial_dict[slide_id])
    time_veto_max_bestnr[slide_id] = np.zeros(num_slide_segs)

for slide_id in slide_dict:
    for j, trial in enumerate(trial_dict[slide_id]):
        # Triggers with trial[0] <= time < trial[1], i.e. inside the trial
        trial_cut = (trial[0] <= trig_time[slide_id])\
                          & (trig_time[slide_id] < trial[1])
        # A trial without triggers keeps its initial value of zero
        if not trial_cut.any():
            continue
        # Max BestNR
        time_veto_max_bestnr[slide_id][j] = \
                        max(trig_bestnr[slide_id][trial_cut])

# max_bestnr is the threshold an injection must exceed to be called louder
# than background further down; full_time_veto_max_bestnr is compared with
# single BestNR values to build the p-values (presumably it collects the
# per-trial maxima of all slides -- confirm in ppu.max_median_stat)
max_bestnr, _, full_time_veto_max_bestnr = ppu.max_median_stat(slide_dict,
                                                               time_veto_max_bestnr,
                                                               trig_bestnr,
                                                               total_trials)

logging.info("BestNR maxima calculated.")

# =======================
# Post-process injections
# =======================
# Triggers, missed injections (i.e., not recovered at all), and injections
# recovered in some form. Trigs/injs at vetoed times are discarded.
# Note that the found file is read twice: once for the triggers and once
# (sim_table=True) for the injections
found_trigs = ppu.load_injections(found_file, vetoes, label="triggers")
missed_injs = ppu.load_injections(missed_file, vetoes, sim_table=True, label="missed injections")
found_injs = ppu.load_injections(found_file, vetoes, sim_table=True, label="found injections")
logging.info("Triggers and missed/found injections loaded.")

# Extract the detection statistic of found triggers
found_trig = {}
found_trig['bestnr'] = ppu.get_bestnrs(found_trigs, q=chisq_index, n=chisq_nhigh,
                                       null_thresh=null_thresh,
                                       snr_threshold=snr_thresh,
                                       sngl_snr_threshold=sngl_snr_thresh,
                                       chisq_threshold=new_snr_thresh,
                                       null_grad_thresh=null_grad_thresh,
                                       null_grad_val=null_grad_val)

# Extract the necessary data from the missed injections for the plot
# (sites holds the first character of each interferometer name)
sites = [ifo[0] for ifo in ifos]
missed_inj = load_data(missed_injs, [x_qty, y_qty], opts, sites)
logging.info("%d missed injections analysed.", len(missed_injs))

# Extract the necessary data from the found injections for the plot
found_inj = load_data(found_injs, [x_qty, y_qty], opts, sites)

# Handle separately the special case of plotting the sky_error: this
# quantity is not defined for *missed* injections
if 'sky_error' in [x_qty, y_qty]:
    found_inj['ra'] = np.asarray(found_injs.get_column('longitude'))
    found_inj['dec'] = np.asarray(found_injs.get_column('latitude'))
    found_trig['ra'] = np.asarray(found_trigs.get_column('ra'))
    found_trig['dec'] = np.asarray(found_trigs.get_column('dec'))
    # Angle between injected and recovered sky positions:
    # cos(theta) = cos(dec_i - dec_t)
    #              - cos(dec_i) cos(dec_t) [1 - cos(ra_i - ra_t)],
    # which is the spherical law of cosines rewritten
    found_inj['sky_error'] = np.arccos(np.cos(found_inj['dec'] - found_trig['dec']) -\
                                       np.cos(found_inj['dec']) * np.cos(found_trig['dec']) *\
                                       (1 - np.cos(found_inj['ra'] - found_trig['ra'])))
    # Empty array: nothing is drawn for the missed injections in this case
    missed_inj['sky_error'] = np.array([])

# Construct conditions for injection:
# 1) found louder than background,
zero_fap = np.zeros(len(found_injs)).astype(bool)
zero_fap_cut = found_trig['bestnr'] > max_bestnr
zero_fap = zero_fap | (zero_fap_cut)
# 2) found (bestnr > 0) but not louder than background (non-zero FAP)
nonzero_fap = ~zero_fap & (found_trig['bestnr'] != 0)
# 3) missed after being recovered (i.e., vetoed)
# (what is left over: not louder than background and with bestnr == 0)
missed = (~zero_fap) & (~nonzero_fap)

# Separate triggers into:
# 1) Zero FAP (zero_fap) 'g_found'
g_found = grab_triggers_subset(found_inj, zero_fap, found_trig)
# 2) Non-zero FAP (nonzero_fap) 'g_ifar'
g_ifar = grab_triggers_subset(found_inj, nonzero_fap, found_trig)
# 'stat' is the number of entries of full_time_veto_max_bestnr that are
# louder than the injection trigger, divided by the number of trials; it is
# the quantity shown as p-value in the plot
g_ifar['stat'] = np.zeros([len(g_ifar['bestnr'])])
for ix, bestnr in enumerate(g_ifar['bestnr']):
    g_ifar['stat'][ix] = (full_time_veto_max_bestnr > bestnr).sum()
g_ifar['stat'] = g_ifar['stat'] / total_trials
# 3) Missed due to vetoes (missed) 'g_missed2'
g_missed2 = grab_triggers_subset(found_inj, missed, found_trig)

# Statistics: missed-found
# (indices sorting the non-zero FAP injections by increasing 'stat')
MF = np.argsort(g_ifar['stat'])
# Statistics: found-missed
# (same indices in reverse, i.e. decreasing 'stat')
FM = MF[::-1]

logging.info("%d found injections analysed.", len(found_injs))

# Post-processing of injections ends here

# ==========
# Make plots
# ==========

# Plot results of injection campaign with found ones on top of missed ones
# (found-missed) or vice-versa (missed-found)

# Info for site-specific plots
sitename = {'G1':'GEO', 'H1':'Hanford', 'L1':'Livingston', 'V1':'Virgo',
            'K1':'KAGRA'}

# Take care of axes labels
# (sitename.get returns None for an unknown or missing --ifo, but that label
# is only used when eff_site_dist is plotted)
axis_labels_dict = {'mchirp': "Chirp Mass (solar masses)",
                    'mtotal': "Total mass (solar masses)",
                    'q': "Mass ratio",
                    'distance': "Distance (Mpc)",
                    'eff_site_dist': "%s effective distance (Mpc)" % sitename.get(opts.ifo),
                    'eff_dist': "Inverse sum of effective distances (Mpc)",
                    'time': "Time since %d" % grb_time,
                    'sky_error': "Rec. sky error (radians)",
                    'ra': "Right ascension",
                    'dec': "Declination",
                    'incl': "Inclination (iota)",
                    'abs_incl': 'Magnitude of inclination (|iota|)',
                    'cos_incl': "cos(iota)",
                    'cos_abs_incl': "cos(|iota|)",
                    'spin1': "Spin on 1st binary component",
                    'spin1x': "Spin x-component of 1st binary component",
                    'spin1y': "Spin y-component of 1st binary component",
                    'spin1z': "Spin z-component of 1st binary component",
                    'spin2': "Spin on 2nd binary component",
                    'spin2x': "Spin x-component of 2nd binary component",
                    'spin2y': "Spin y-component of 2nd binary component",
                    'spin2z': "Spin z-component of 2nd binary component"}
x_label = axis_labels_dict[x_qty]
y_label = axis_labels_dict[y_qty]
fig = plt.figure()
xscale = "log" if opts.x_log else "linear"
yscale = "log" if opts.y_log else "linear"
ax = fig.gca()
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Define p-value colour
cmap = plt.get_cmap('cividis_r')
# Set color for out-of-range values
#cmap.set_over('g')

# Define the 'found' injection colour
# (the colour the map assigns to 0, wrapped in an array with a single row so
# that scatter reads it as one colour for all points)
fnd_col = cmap(0)
fnd_col = np.array([fnd_col])
# The two branches draw the same four groups of injections in opposite
# order and with different marker sizes; empty groups are skipped
if fm_or_mf == "found_missed":
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty], c="black", marker="x", s=10)
    if g_missed2[x_qty].size:
        ax.scatter(g_missed2[x_qty], g_missed2[y_qty], c="red", marker="x", s=10)
    if g_ifar[x_qty].size:
        # FM order: decreasing p-value, so the smallest p-values come last
        p = ax.scatter(g_ifar[x_qty][FM], g_ifar[y_qty][FM], c=g_ifar['stat'][FM],
                       cmap=cmap, vmin=0, vmax=1, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
    if g_found[x_qty].size:
        ax.scatter(g_found[x_qty], g_found[y_qty], c=fnd_col, marker="+", s=30)
elif fm_or_mf == "missed_found":
    if g_found[x_qty].size:
        ax.scatter(g_found[x_qty], g_found[y_qty], c=fnd_col, marker="+", s=15)
    if g_ifar[x_qty].size:
        # MF order: increasing p-value, so the largest p-values come last
        p = ax.scatter(g_ifar[x_qty][MF], g_ifar[y_qty][MF], c=g_ifar['stat'][MF],
                       cmap=cmap, vmin=0, vmax=1, s=40,
                       edgecolor="w", linewidths=2.0)
        cb = plt.colorbar(p, label="p-value")
    if g_missed2[x_qty].size:
        ax.scatter(g_missed2[x_qty], g_missed2[y_qty], c="red", marker="x", s=40)
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty], c="black", marker="x", s=40)
ax.grid()

# Handle axis limits when plotting spins
# The limits are rounded outwards to the next multiple of 0.1 and computed
# from missed and found injections together.  Spin magnitudes start from 0;
# spin components may be negative, in which case the lower limit follows
# the data instead of cutting it off at 0.
for qty, set_lim in ((x_qty, ax.set_xlim), (y_qty, ax.set_ylim)):
    if "spin" not in qty:
        continue
    spin_vals = np.concatenate((missed_inj[qty], found_inj[qty]))
    # Nothing to base the limits on (no missed and no found injections)
    if not spin_vals.size:
        continue
    set_lim([min(0, np.floor(10 * spin_vals.min()) / 10),
             np.ceil(10 * spin_vals.max()) / 10])

# Handle axis limits when plotting inclination
if "incl" in x_qty or "incl" in y_qty:
    # Largest inclination shown, rounded up to a multiple of 10 degrees
    max_inc = np.pi
    max_inc_deg = np.rad2deg(max_inc)
    max_inc_deg = np.ceil(max_inc_deg/10.0)*10
    max_inc = np.deg2rad(max_inc_deg)
    # Angles are plotted in degrees; |incl| covers half the range of incl
    if x_qty == "incl":
        ax.set_xlim(0, max_inc_deg)
    elif x_qty == "abs_incl":
        ax.set_xlim(0, max_inc_deg*0.5)
    if y_qty == "incl":
        ax.set_ylim(0, max_inc_deg)
    elif y_qty == "abs_incl":
        ax.set_ylim(0, max_inc_deg*0.5)
    # Cosines: put the ticks at the cosine of round angles.  The test must
    # look for 'cos_' inside each name: a plain membership test against the
    # list [x_qty, y_qty] never matches 'cos_incl' or 'cos_abs_incl'.
    if "cos_" in x_qty or "cos_" in y_qty:
        tt = np.asarray([0, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 180])
        tks = np.cos(np.deg2rad(tt))
        tk_labs = ['cos(%d deg)' % tk for tk in tt]
        # The ticks are set on ax explicitly, so that they cannot end up on
        # the colorbar axes
        if "cos_" in x_qty:
            ax.set_xticks(tks)
            ax.set_xticklabels(tk_labs, fontsize=10)
            # Rotate the (long) labels of the horizontal axis
            fig.autofmt_xdate()
            ax.set_xlim(np.cos(max_inc), 1)
        if "cos_" in y_qty:
            ax.set_yticks(tks)
            ax.set_yticklabels(tk_labs, fontsize=10)
            ax.set_ylim(np.cos(max_inc), 1)

# Take care of caption
plot_caption = opts.plot_caption
if plot_caption is None:
    # Default caption: one sentence for each kind of marker in the plot
    plot_caption = (
        "Black cross indicates no trigger was found "
        "coincident with the injection.\n"
        "Red cross indicates a trigger was found "
        "coincident with the injection but it was vetoed.\n"
        "Yellow plus indicates that a trigger was found "
        "coincident with the injection and it was louder "
        "than all events in the offsource.\n"
        "Coloured circle indicates that a trigger was "
        "found coincident with the injection but it was "
        "not louder than all offsource events. The colour "
        "bar gives the p-value of the trigger."
    )

# Take care of title
plot_title = opts.plot_title
if plot_title is None:
    # How each quantity is referred to in the default title
    title_dict = {
        'mchirp': "chirp mass",
        'mtotal': "total mass",
        'q': "mass ratio",
        'distance': "distance (Mpc)",
        'eff_dist': "inverse sum of effective distances",
        'eff_site_dist': "site specific effective distance",
        'time': "time",
        'ra': "right ascension",
        'dec': "declination",
        'incl': "inclination",
        'cos_incl': "inclination",
        'abs_incl': "inclination",
        'cos_abs_incl': "inclination",
        'spin1': "spin",
        'spin1x': "spin x-component",
        'spin1y': "spin y-component",
        'spin1z': "spin z-component",
        'spin2': "spin",
        'spin2x': "spin x-component",
        'spin2y': "spin y-component",
        'spin2z': "spin z-component",
    }
    # The sky error has a title of its own (and no entry in title_dict)
    if "sky_error" in [x_qty, y_qty]:
        plot_title = "Sky error of recovered injections"
    else:
        plot_title = "Injection recovery with respect to %s and %s" % (
            title_dict[x_qty], title_dict[y_qty])

# Wrap up
plt.tight_layout()
save_fig_with_metadata(fig, outfile, title=plot_title, caption=plot_caption,
                       cmd=' '.join(sys.argv))
plt.close()
logging.info("Plots complete.")
