#!/usr/bin/env python

# Copyright (C) 2021 Francesco Pannarale, Viviana Caceres
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


"""
Plot found/missed injection properties for the triggered search (PyGRB).
"""

import h5py, logging, os.path, argparse, sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pycbc.pnutils, pycbc.results, pycbc.version
from pycbc import init_logging
from pycbc.results import pygrb_postprocessing_utils as ppu

# Render with the Agg backend so that figures are written to file without
# requiring a display
plt.switch_backend('Agg')
# NOTE(review): rc() is called with a group name but no settings, so this
# does not appear to change any rc parameter -- confirm the intent
matplotlib.rc("image")

# Script metadata; version and date are taken from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_plot_injs_results"


# =============================================================================
# Functions
# =============================================================================
def process_var_strings(qty):
    """Add underscores to match HDF column name conventions.

    The conversion is idempotent: a name that already contains the
    underscores (e.g. 'cos_incl') is returned unchanged rather than
    acquiring a second underscore ('cos__incl').

    Parameters
    ----------
    qty : str
        Variable name, with or without underscores (e.g. 'cosabsincl'
        or 'cos_abs_incl').

    Returns
    -------
    str
        The name with underscores inserted (e.g. 'cos_abs_incl').
    """

    # (compact spelling, spelling with underscores), applied in this order
    replacements = (('effdist', 'eff_dist'),
                    ('effsitedist', 'eff_site_dist'),
                    ('skyerror', 'sky_error'),
                    ('cos', 'cos_'),
                    ('abs', 'abs_'),
                    ('coaphase', 'coa_phase'),
                    ('endtime', 'end_time'))

    for compact, canonical in replacements:
        # Collapse any existing canonical spelling first so that the
        # expansion below never doubles an underscore
        qty = qty.replace(canonical, compact).replace(compact, canonical)

    return qty


def load_incl_data(injs, qty):
    """Return the inclination-related quantity `qty` for all injections.

    `qty` is one of 'incl', 'abs_incl', 'cos_incl' or 'cos_abs_incl'.
    Angles ('incl', 'abs_incl') are returned in degrees, cosines are
    returned as they are."""

    inclination = injs['injections/inclination'][:]

    # The raw inclination is always available as a starting point
    angles = {'incl': inclination}

    # The folded inclination is only built for |incl| or cos(|incl|)
    if 'abs_' in qty:
        half_pi = 0.5*np.pi
        angles['abs_incl'] = half_pi - abs(inclination - half_pi)

    # Plain angles: convert to degrees
    if 'cos_' not in qty:
        return np.rad2deg(angles[qty])

    # Cosines: strip the prefix to find which angle to use
    return np.cos(angles[qty.replace('cos_', '')])

# Function to extract total mass, chirp mass, or mass ratio data from an
# injection file
def load_mass_data(injs, qty):
    """Return a derived mass quantity computed from the injected
    component masses: the total mass for 'mtotal', the chirp mass for
    'mchirp', and otherwise the mass ratio (always <= 1)"""

    mass1 = injs['injections/mass1'][:]
    mass2 = injs['injections/mass2'][:]

    if qty == 'mtotal':
        return mass1 + mass2

    if qty == 'mchirp':
        # Only the chirp mass is needed, the symmetric mass ratio is dropped
        mchirp, _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(mass1, mass2)
        return mchirp

    # Mass ratio, inverted wherever it would exceed unity
    ratio = mass2/mass1
    return np.where(ratio > 1, 1./ratio, ratio)


# Function to extract effective distance data from an injection file
def load_effdist_data(injs, qty, opts, ifos):
    """Extract data related to effective distances from raw injection data.

    Parameters
    ----------
    injs : mapping
        Injection data; indexed with 'injections/eff_dist_<x>', where <x>
        is the lower-case first letter of a detector name.
    qty : str
        'eff_site_dist' for the effective distance at the single detector
        opts.ifo; anything else yields the combined effective distance.
    opts : object
        Only opts.ifo is read, and only when qty is 'eff_site_dist'.
    ifos : iterable of str
        Detector names used for the combined effective distance.

    Returns
    -------
    numpy.ndarray
        One effective distance per injection.
    """

    if qty == 'eff_site_dist':
        return injs['injections/eff_dist_%s' % opts.ifo[0].lower()][:]

    # One row per detector. The per-site arrays must be stacked into a
    # single array: applying np.power to a dict of arrays raises TypeError
    site_dists = np.asarray(
        [injs['injections/eff_dist_%s' % ifo[0].lower()][:] for ifo in ifos],
        dtype=float)

    # Effective distance (inverse sum of inverse effective distances)
    return 1. / (1. / site_dists).sum(0)

# Function to extract spin related data from an injection file
def load_spin_modulus_data(injs, qty):
    """Return the spin magnitude of the binary component named by `qty`
    ('spin1' or 'spin2'), built from its x, y and z components"""

    spin_x, spin_y, spin_z = (injs['injections/%s%s' % (qty, axis)][:]
                              for axis in 'xyz')

    # Euclidean norm of the three components
    return np.sqrt(spin_x**2 + spin_y**2 + spin_z**2)


# Function to extract desired data from an injection file
def load_data(input_file, keys, opts, ifos):
    """Create a dictionary containing the data specified by the
    list of keys extracted from an injection file.

    Parameters
    ----------
    input_file : str
        Path of the HDF injection file.
    keys : iterable of str
        Names of the quantities to load. Names that match none of the
        cases below are skipped and get no entry in the result.
    opts : object
        Parsed options; opts.seg_files is read for 'end_time' and opts is
        handed to load_effdist_data for the effective distances.
    ifos : iterable of str
        Detector names, handed to load_effdist_data.

    Returns
    -------
    dict
        Maps each recognised key to its array of values.
    """

    data_dict = {}

    easy_keys = ['coa_phase', 'distance', 'mass1', 'mass2',
                 'polarization', 'spin1x', 'spin1y', 'spin1z',
                 'spin2x', 'spin2y', 'spin2z']

    # Every dataset is copied into memory with [:], so the file can be
    # closed as soon as the loop is over instead of being left open
    with h5py.File(input_file, 'r') as injs:
        for qty in keys:
            if qty in easy_keys:
                data_dict[qty] = injs['injections/%s' % qty][:]
            elif qty == 'end_time':
                # Times are expressed relative to the GRB time
                data_dict[qty] = injs['injections/end_time'][:]
                grb_time = ppu.get_grb_time(opts.seg_files)
                data_dict[qty] -= grb_time
            elif qty in ['latitude', 'longitude']:
                data_dict[qty] = np.rad2deg(injs['injections/%s' % qty][:])
            elif qty in ['mtotal', 'q', 'mchirp']:
                data_dict[qty] = load_mass_data(injs, qty)
            elif qty in ['eff_site_dist', 'eff_dist']:
                data_dict[qty] = load_effdist_data(injs, qty, opts, ifos)
            elif 'incl' in qty:
                data_dict[qty] = load_incl_data(injs, qty)
            # This handles spin1 and spin2, i.e. spin magnitudes, as
            # components are dealt with in easy_keys (first if)
            elif 'spin' in qty:
                data_dict[qty] = load_spin_modulus_data(injs, qty)

    return data_dict


def load_trig_data(input_file, vetoes):
    """Log the operation and return the triggers that
    ppu.load_triggers reads from `input_file` given `vetoes`"""

    logging.info("Loading triggers...")

    return ppu.load_triggers(input_file, vetoes)


# Function to cherry-pick a subset of full_data specified by fap_mask.
# Reweighted SNR values are picked separately since they are extracted
# in different manners
def grab_injs_subset(full_data, fap_mask):
    """Return a new dictionary with the same keys as full_data in which
    every entry is indexed with fap_mask"""

    return {name: values[fap_mask] for name, values in full_data.items()}


# =============================================================================
# Main script starts here
# =============================================================================
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)
parser.add_argument('--injection-file',
                    help="The hdf injection file to plot", required=True)
parser.add_argument('--offsource-file',
                    help="The hdf offsource trigger file", required=True)
admitted_vars = ['coa_phase', 'distance', 'latitude', 'longitude','mass1',
                 'mass2', 'polarization', 'spin1x', 'spin1y', 'spin1z',
                 'spin2x', 'spin2y', 'spin2z', 'end_time', 'mtotal', 'q',
                 'mchirp', 'eff_site_dist', 'eff_dist', 'incl', 'cos_incl',
                 'abs_incl', 'cos_abs_incl', 'spin1', 'spin2', 'sky_error']


def _normalize_var_name(qty):
    """Map a variable name given with or without underscores onto the
    spelling used in admitted_vars.

    All underscores are dropped before process_var_strings inserts them
    again, so that both 'cosincl' and 'cos_incl' yield 'cos_incl'."""

    return process_var_strings(qty.replace('_', ''))


# argparse applies `type` before checking `choices`: normalising the names
# at this stage is what allows underscores to be omitted, as the help
# messages state
parser.add_argument("-x", "--x-variable", default=None, required=True,
                    type=_normalize_var_name, choices=admitted_vars,
                    help="Quantity to plot on the horizontal axis. " +
                    "(Underscores may be omitted in specifying this option).")
parser.add_argument("--x-log", action="store_true",
                    help="Use log horizontal axis")
parser.add_argument("-y", "--y-variable", default=None, required=True,
                    type=_normalize_var_name, choices=admitted_vars,
                    help="Quantity to plot on the vertical axis. " +
                    "(Underscores may be omitted in specifying this option).")
parser.add_argument("--y-log", action="store_true",
                    help="Use log vertical axis")
parser.add_argument('--colormap',default='cividis_r',
                   help="Type of colormap to be used for the plots.")
parser.add_argument('--gradient-far', action='store_true',
                    help="Show far of found injections as a gradient")
parser.add_argument('--far-type', choices=('inclusive', 'exclusive'),
                    default='inclusive',
                    help="Type of far to plot for the color. Choices are "
                         "'inclusive' or 'exclusive'. Default = 'inclusive'")
parser.add_argument('--missed-on-top', action='store_true',
                    help="Plot missed injections on top of found ones and "
                         "high FAR on top of low FAR")
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options
# --injection-file is declared with required=True, so the parser has already
# rejected a command line without it: no further check is needed here.

# The variable names were normalised by the parser (see _normalize_var_name)
# and must not be processed a second time
x_qty = opts.x_variable
y_qty = opts.y_variable

if 'eff_site_dist' in [x_qty, y_qty] and opts.ifo is None:
    err_msg = "A value for --ifo must be provided for "
    err_msg += "site specific effective distance"
    parser.error(err_msg)

# Local shorthands for values needed in several places below
outfile = opts.output_file
trig_file = os.path.abspath(opts.offsource_file)
f = h5py.File(opts.injection_file, 'r')
grb_time = ppu.get_grb_time(opts.seg_files)

# Create the directory that will hold the plot, unless it is already there
logging.info("Setting output directory.")
outdir = os.path.dirname(os.path.abspath(outfile))
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# Detectors and vetoes associated with the offsource trigger file
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category)

# Offsource triggers (time-slides not yet available) and the largest
# reweighted SNR among them
logging.info("Loading triggers.")
trig_data = load_trig_data(trig_file, vetoes)
offsource_reweighted_snr = trig_data['network/reweighted_snr'][:]
max_reweighted_snr = max(offsource_reweighted_snr)

# =======================
# Post-process injections
# =======================
# Triggers, missed injections (i.e., not recovered at all), and injections
# recovered in some form. Trigs/injs at vetoed times are discarded.

# Indices of all found injections
found = f['found/injection_index'][:]
# Indices of all missed injections
missed = f['missed/all'][:]
# Indices of injections found surviving vetoes
found_after_vetoes = f['found_after_vetoes/injection_index'][:]
# Indices of found but vetoed injections (sorted and without repetitions)
missed_after_vetoes = np.setdiff1d(found, found_after_vetoes).astype(int)
# Indices of injections found but not louder than background
# are populated further down

inj_data = load_data(opts.injection_file, [x_qty, y_qty], opts, ifos)
logging.info("Triggers and missed/found injections loaded.")

# Handle separately the special case of plotting the sky_error: this
# quantity is not defined for *missed* injections
if 'sky_error' in [x_qty, y_qty]:
    # Angles are kept in radians because they are fed to np.cos
    inj_longitude = f['injections/longitude'][:]
    inj_ra = inj_longitude[found]
    inj_dec = f['injections/latitude'][:][found]
    # NOTE(review): these are the sky positions of *all* the loaded triggers;
    # the expression below only broadcasts if there are as many of them as
    # found injections -- confirm where the recovered positions should be
    # read from
    trig_ra = trig_data['network/longitude'][:]
    trig_dec = trig_data['network/latitude'][:]
    # Cosine of the angular separation between injected and recovered
    # positions
    cos_sky_error = np.cos(inj_dec - trig_dec) -\
        np.cos(inj_dec) * np.cos(trig_dec) * (1 - np.cos(inj_ra - trig_ra))
    # Rounding can push the cosine marginally outside [-1, 1]
    sky_error = np.arccos(np.clip(cos_sky_error, -1., 1.))

    # Define inj_data only for found indices. Missed injections are assigned
    # NaN, which keeps the array numeric
    inj_data['sky_error'] = np.full(len(inj_longitude), np.nan)
    inj_data['sky_error'][found] = sky_error

# Extract the necessary data from the missed injections for the plot
missed_inj = {}
for qty in [x_qty, y_qty]:
    missed_inj[qty] = inj_data[qty][missed]
logging.info("%d missed injections analysed.", len(missed))

# Extract the necessary data from the found injections for the plot
found_inj = {x_qty: inj_data[x_qty][found],
             y_qty: inj_data[y_qty][found]}

# Extract the detection statistic of injections found after vetoes
found_after_vetoes_stat = f['found_after_vetoes/stat'][:]

# Separate triggers into:
# 1) Found louder than background
louder_mask = found_after_vetoes_stat > max_reweighted_snr
louder_indices = found_after_vetoes[louder_mask]
found_louder = grab_injs_subset(inj_data, louder_indices)
found_louder['reweighted_snr'] = found_after_vetoes_stat[louder_mask]

# 2) Found quieter than background
quieter_mask = (found_after_vetoes_stat <= max_reweighted_snr) & \
    (found_after_vetoes_stat != 0)
# Indices of injections found (bestnr > 0) but not louder than background
# (non-zero FAP)
quieter_indices = found_after_vetoes[quieter_mask]
found_quieter = grab_injs_subset(inj_data, quieter_indices)
found_quieter['reweighted_snr'] = found_after_vetoes_stat[quieter_mask]

# Extract inclusive/exclusive IFAR for injections found quieter than background
ifar_string = 'found_after_vetoes/ifar' if opts.far_type == 'inclusive' \
    else 'found_after_vetoes/ifar_exc'
found_quieter['ifar'] = f[ifar_string][:][quieter_mask]

# 3) Missed due to vetoes
vetoed = grab_injs_subset(inj_data, missed_after_vetoes)
# The statistic of a vetoed injection is the one stored where its index first
# appears among the found injections. The positions are looked up in a single
# pass rather than by scanning the found indices once per vetoed injection
found_stat = f['found/stat'][:]
first_position = {}
for position, inj_index in enumerate(found):
    first_position.setdefault(inj_index, position)
vetoed_positions = np.array([first_position[inj_index]
                             for inj_index in missed_after_vetoes], dtype=int)
vetoed['reweighted_snr'] = found_stat[vetoed_positions]

# Statistics: found on top (found-missed)
FM = np.argsort(found_quieter['ifar'])
# Statistics: missed on top (missed-found)
MF = FM[::-1]

logging.info("%d found injections analysed.", len(found))

# Post-processing of injections ends here

# ==========
# Make plots
# ==========

# Info for site-specific plots
sitename = {'G1':'GEO', 'H1':'Hanford', 'L1':'Livingston', 'V1':'Virgo',
            'K1':'KAGRA', 'A1': 'India Aundha'}

# Take care of axes labels
axis_labels_dict = {'mchirp': "Chirp Mass (solar masses)",
                    'mtotal': "Total mass (solar masses)",
                    'q': "Mass ratio",
                    'distance': "Distance (Mpc)",
                    'eff_site_dist': "%s effective distance (Mpc)" % sitename.get(opts.ifo),
                    'eff_dist': "Inverse sum of effective distances (Mpc)",
                    'end_time': "Time since %d (s)" % grb_time,
                    'sky_error': "Rec. sky error (radians)",
                    'coa_phase': "Phase of complex SNR (radians)",
                    'latitude': "Latitude (degrees)",
                    'longitude': "Longitude (degrees)",
                    'incl': "Inclination (iota)",
                    'abs_incl': 'Magnitude of inclination (|iota|)',
                    'cos_incl': "cos(iota)",
                    'cos_abs_incl': "cos(|iota|)",
                    'mass1': "Mass of 1st binary component (solar masses)",
                    'mass2': "Mass of 2nd binary component (solar masses)",
                    'polarization': "Polarization phase (radians)",
                    'spin1': "Spin on 1st binary component",
                    'spin1x': "Spin x-component of 1st binary component",
                    'spin1y': "Spin y-component of 1st binary component",
                    'spin1z': "Spin z-component of 1st binary component",
                    'spin2': "Spin on 2nd binary component",
                    'spin2x': "Spin x-component of 2nd binary component",
                    'spin2y': "Spin y-component of 2nd binary component",
                    'spin2z': "Spin z-component of 2nd binary component"}

x_label = axis_labels_dict[x_qty]
y_label = axis_labels_dict[y_qty]

fig = plt.figure()
xscale = "log" if opts.x_log else "linear"
yscale = "log" if opts.y_log else "linear"
ax = fig.gca()
ax.set_xscale(xscale)
ax.set_yscale(yscale)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)

# Define p-value colour: honour --colormap, whose default is 'cividis_r'
cmap = plt.get_cmap(opts.colormap)

# Define the 'found' injection colour: the first colour of the colormap
fnd_col = np.array([cmap(0)])

# The two plotting modes draw the same four layers in opposite orders and
# with different marker sizes
cross_size = 40 if opts.missed_on_top else 10
plus_size = 15 if opts.missed_on_top else 30
quieter_order = MF if opts.missed_on_top else FM


def _draw_missed():
    """Black crosses: injections that were not recovered at all."""
    if missed_inj[x_qty].size and missed_inj[y_qty].size:
        ax.scatter(missed_inj[x_qty], missed_inj[y_qty], c="black",
                   marker="x", s=cross_size)


def _draw_vetoed():
    """Red crosses: injections that were found but vetoed."""
    if vetoed[x_qty].size:
        ax.scatter(vetoed[x_qty], vetoed[y_qty], c="red", marker="x",
                   s=cross_size)


def _draw_quieter():
    """Coloured circles with colour bar: found, not louder than background."""
    if found_quieter[x_qty].size:
        # NOTE(review): the points are coloured by reweighted SNR on a scale
        # fixed to [0, 1] while the colour bar is labelled "p-value" --
        # confirm which quantity is meant to be shown
        points = ax.scatter(found_quieter[x_qty][quieter_order],
                            found_quieter[y_qty][quieter_order],
                            c=found_quieter['reweighted_snr'][quieter_order],
                            cmap=cmap, vmin=0, vmax=1, s=40,
                            edgecolor="w", linewidths=2.0)
        plt.colorbar(points, label="p-value")


def _draw_louder():
    """Plus signs: injections found louder than all the background."""
    if found_louder[x_qty].size:
        ax.scatter(found_louder[x_qty], found_louder[y_qty], c=fnd_col,
                   marker="+", s=plus_size)


# Whatever is drawn last ends up on top
layers = [_draw_missed, _draw_vetoed, _draw_quieter, _draw_louder]
if opts.missed_on_top:
    layers.reverse()
for draw_layer in layers:
    draw_layer()
ax.grid()

# Handle axis limits when plotting spins
def _spin_axis_upper_limit(qty):
    """Return the largest value of `qty` among missed and found injections,
    rounded up to the next multiple of 0.1, or None if both sets are empty
    (taking the maximum of an empty array raises ValueError)."""

    maxima = [subset[qty].max() for subset in (missed_inj, found_inj)
              if subset[qty].size]
    if not maxima:
        return None

    return np.ceil(10 * max(maxima)) / 10


# NOTE(review): the lower limit is 0 for the spin components too, not just
# for the magnitudes -- confirm this is intended
if "spin" in x_qty:
    x_upper_limit = _spin_axis_upper_limit(x_qty)
    if x_upper_limit is not None:
        ax.set_xlim([0, x_upper_limit])
if "spin" in y_qty:
    y_upper_limit = _spin_axis_upper_limit(y_qty)
    if y_upper_limit is not None:
        ax.set_ylim([0, y_upper_limit])

# Handle axis limits when plotting inclination
if "incl" in x_qty or "incl" in y_qty:
    # Largest inclination, rounded up to a multiple of 10 degrees
    max_inc = np.pi
    max_inc_deg = np.rad2deg(max_inc)
    max_inc_deg = np.ceil(max_inc_deg/10.0)*10
    max_inc = np.deg2rad(max_inc_deg)
    if x_qty == "incl":
        ax.set_xlim(0, max_inc_deg)
    elif x_qty == "abs_incl":
        ax.set_xlim(0, max_inc_deg*0.5)
    if y_qty == "incl":
        ax.set_ylim(0, max_inc_deg)
    elif y_qty == "abs_incl":
        ax.set_ylim(0, max_inc_deg*0.5)
    # Cosines: ticks are placed at the cosines of round angles and labelled
    # with those angles. The test is a substring test on each variable name:
    # testing "cos_" for membership in [x_qty, y_qty] can never succeed
    if "cos_" in x_qty or "cos_" in y_qty:
        tt = np.asarray([0, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 180])
        tks = np.cos(np.deg2rad(tt))
        tk_labs = ['cos(%d deg)' % tk for tk in tt]
        if "cos_" in x_qty:
            plt.xticks(tks, tk_labs, fontsize=10)
            fig.autofmt_xdate()
            ax.set_xlim(np.cos(max_inc), 1)
        if "cos_" in y_qty:
            plt.yticks(tks, tk_labs, fontsize=10)
            # NOTE(review): this call acts on the x-axis tick labels even
            # though the cosine is on the y axis -- confirm it is wanted
            fig.autofmt_xdate()
            ax.set_ylim(np.cos(max_inc), 1)

# Caption: fall back to a description of the markers if none was provided
plot_caption = opts.plot_caption
if plot_caption is None:
    plot_caption = (
        "Black cross indicates no trigger was found "
        "coincident with the injection.\n"
        "Red cross indicates a trigger was found "
        "coincident with the injection but it was vetoed.\n"
        "Yellow plus indicates that a trigger was found "
        "coincident with the injection and it was louder "
        "than all events in the offsource.\n"
        "Coloured circle indicates that a trigger was "
        "found coincident with the injection but it was "
        "not louder than all offsource events. The colour "
        "bar gives the p-value of the trigger."
    )

# Title: fall back to one built from the two plotted quantities
plot_title = opts.plot_title
if plot_title is None:
    if "sky_error" in (x_qty, y_qty):
        plot_title = "Sky error of recovered injections"
    else:
        # How each quantity is referred to within the title
        title_dict = {'mchirp': "chirp mass",
                      'mtotal': "total mass",
                      'q': "mass ratio",
                      'distance': "distance (Mpc)",
                      'eff_dist': "inverse sum of effective distances",
                      'eff_site_dist': "site specific effective distance",
                      'end_time': "time",
                      'coa_phase': "phase of complex SNR",
                      'latitude': "latitude",
                      'longitude': "longitude",
                      'incl': "inclination",
                      'cos_incl': "inclination",
                      'abs_incl': "inclination",
                      'cos_abs_incl': "inclination",
                      'mass1': "mass",
                      'mass2': "mass",
                      'polarization': "polarization",
                      'spin1': "spin",
                      'spin1x': "spin x-component",
                      'spin1y': "spin y-component",
                      'spin1z': "spin z-component",
                      'spin2': "spin",
                      'spin2x': "spin x-component",
                      'spin2y': "spin y-component",
                      'spin2z': "spin z-component"}
        plot_title = "Injection recovery with respect to %s and %s" % (
            title_dict[x_qty], title_dict[y_qty])

# Save the figure together with its metadata and wrap up
plt.tight_layout()
pycbc.results.save_fig_with_metadata(fig, outfile, title=plot_title,
                                     caption=plot_caption,
                                     cmd=' '.join(sys.argv))
plt.close()
logging.info("Plots complete.")