#!/usr/bin/env python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Determine efficiency and exclusion distances of a PyGRB run."""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import scipy
from scipy import stats
import pycbc.version
from pycbc.detector import Detector
from pycbc import init_logging
from pycbc.results import save_fig_with_metadata
from pycbc.results import pygrb_postprocessing_utils as ppu
try:
    from glue.ligolw import lsctables
except ImportError:
    pass

# Select the Agg backend before any figure is created
# (Agg is matplotlib's non-interactive raster backend)
plt.switch_backend('Agg')
# NOTE(review): rc() is given a group name but no settings, so this call
# presumably leaves the matplotlib configuration untouched -- confirm whether
# specific "image" options were meant to be set here
rc("image")

# Script metadata; version and date are taken from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_efficiency"


def efficiency_with_errs(found_bestnr, num_injections, num_mc_injs=0):
    """Calculate the fraction of recovered injections and its error bars.

    Used for the efficiency/sensitive distance plots.  The last element of
    both input arrays is discarded before any calculation (the caller in
    this script stores the totals over all bins there).

    The error bars are the distances between the recovered fraction and the
    lower/upper bounds
    ``(n*(2k + 1) -/+ sqrt(4*n*k*(n - k) + n**2)) / (2*n*(n + 1))``,
    with ``k`` the found and ``n`` the total injections per bin, i.e., the
    Wilson score interval evaluated at one standard deviation (z = 1).

    Parameters
    ----------
    found_bestnr : numpy.ndarray
        Number of found injections in each bin, plus one trailing element
        that is ignored.
    num_injections : numpy.ndarray
        Total number of injections in each bin, plus one trailing element
        that is ignored.
    num_mc_injs : int, optional
        Number of Monte-Carlo iterations accumulated in the two arrays.
        When non-zero, the counts are divided by it before the error bars
        are computed; the fraction itself is unaffected.  Python and NumPy
        integers are both accepted.

    Returns
    -------
    err_low_mc : numpy.ndarray
        Fraction minus the lower bound of the interval.
    err_high_mc : numpy.ndarray
        Upper bound of the interval minus the fraction.
    fraction : numpy.ndarray
        Found injections divided by total injections, bin by bin.  Bins
        without injections yield NaN.

    Raises
    ------
    TypeError
        If num_mc_injs is not an integer.
    """

    # NumPy integer scalars are not instances of the builtin int, so they
    # must be allowed explicitly
    if not isinstance(num_mc_injs, (int, np.integer)):
        err_msg = "The parameter num_mc_injs is the number of Monte-Carlo "
        err_msg += "injections.  It must be an integer."
        raise TypeError(err_msg)

    # Drop the trailing element of both arrays
    only_found_injs = found_bestnr[:-1]
    all_injs = num_injections[:-1]
    fraction = only_found_injs / all_injs

    # Divide by Monte-Carlo iterations: the error bars must reflect the
    # number of injections actually performed, not the accumulated count
    if num_mc_injs:
        only_found_injs = only_found_injs / num_mc_injs
        all_injs = all_injs / num_mc_injs

    # Centre and half-width of the interval share the same denominator
    err_common = all_injs * (2 * only_found_injs + 1)
    err_denom = 2 * all_injs * (all_injs + 1)
    err_vary = 4 * all_injs * only_found_injs * (all_injs - only_found_injs) \
                + all_injs**2
    err_vary = err_vary**0.5
    err_low = (err_common - err_vary)/err_denom
    err_low_mc = fraction - err_low
    err_high = (err_common + err_vary)/err_denom
    err_high_mc = err_high - fraction

    return err_low_mc, err_high_mc, fraction


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the common PyGRB plot parser and add the options that are
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(
    description=__doc__,
    version=__version__,
)
parser.add_argument(
    "-F", "--offsource-file",
    action="store",
    required=True,
    default=None,
    help="Location of off-source trigger file.",
)
# Unlike offsource-file and trig-file, this file holds on-source data only
parser.add_argument(
    "--onsource-file",
    action="store",
    default=None,
    help="Location of on-source trigger file.",
)
parser.add_argument(
    "--background-output-file",
    default=None,
    required=True,
    help="Detection efficiency output file.",
)
parser.add_argument(
    "--onsource-output-file",
    default=None,
    required=True,
    help="Exclusion distance output file.",
)
parser.add_argument(
    "-g", "--glitch-check-factor",
    action="store",
    type=float,
    default=1.0,
    help="When deciding "
         "exclusion efficiencies this value is multiplied "
         "to the offsource around the injection trigger to "
         "determine if it is just a loud glitch.",
)
parser.add_argument(
    "-C", "--cluster-window",
    action="store",
    type=float,
    default=0.1,
    help="The cluster window used "
         "to cluster triggers in time.",
)
# Option groups shared with the other PyGRB post-processing scripts
ppu.pygrb_add_missed_injs_input_opt(parser)
ppu.pygrb_add_injmc_opts(parser)
ppu.pygrb_add_bestnr_opts(parser)
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# Check options: injections are processed only if both the found and the
# missed injection files are given; providing just one of them is an error
if (opts.found_file is None) and (opts.missed_file is None):
    do_injections = False
elif (opts.found_file) and opts.missed_file:
    do_injections = True
else:
    err_msg = "Must provide both found and missed file if running injections."
    parser.error(err_msg)

# Fall back on the SNR threshold when no new SNR threshold is given
if not opts.newsnr_threshold:
    opts.newsnr_threshold = opts.snr_threshold

# Store options used multiple times in local variables
outdir = os.path.split(os.path.abspath(opts.background_output_file))[0]
trig_file = opts.offsource_file
onsource_file = opts.onsource_file
found_file = opts.found_file
missed_file = opts.missed_file
# The injection set name is the '_'-separated field that follows 'INSPINJ'
# in the missed injections file name.  It can only be determined (and is
# only needed, for the plot titles) when injections are being processed:
# without this guard the script would fail on a missing missed_file.
if do_injections:
    missed_file_name = os.path.split(os.path.abspath(missed_file))[1]
    inj_set_name = missed_file_name.split('INSPINJ')[1].split('_')[1]
else:
    inj_set_name = None
chisq_index = opts.chisq_index
chisq_nhigh = opts.chisq_nhigh
# Calibration errors, and their dc_cal counterparts, keyed by detector
cal_errs = {}
cal_dc_errs = {}
for ifo_name in ('G1', 'H1', 'K1', 'L1', 'V1'):
    cal_errs[ifo_name] = getattr(opts, ifo_name.lower() + '_cal_error')
    cal_dc_errs[ifo_name] = getattr(opts, ifo_name.lower() + '_dc_cal_error')
snr_thresh = opts.snr_threshold
sngl_snr_thresh = opts.sngl_snr_threshold
new_snr_thresh = opts.newsnr_threshold
null_grad_thresh = opts.null_grad_thresh
null_grad_val = opts.null_grad_val
# The null SNR threshold option is a comma-separated list of floats
null_thresh = list(map(float, opts.null_snr_threshold.split(',')))
upper_dist = opts.upper_inj_dist
lower_dist = opts.lower_inj_dist
num_bins = opts.num_bins
wav_err = opts.waveform_error
cluster_window = opts.cluster_window
glitch_check_fac = opts.glitch_check_factor
num_mc_injs = opts.num_mc_injections
# Initialize random number generator
np.random.seed(opts.seed)
logging.info("Setting random seed to %d.", opts.seed)

# Make sure the directory hosting the output exists
logging.info("Setting output directory.")
if not os.path.isdir(outdir):
    os.makedirs(outdir)

# IFOs and vetoes relevant to the off-source trigger file
ifos, vetoes = ppu.extract_ifos_and_vetoes(trig_file, opts.veto_files,
                                           opts.veto_category)

# Triggers, time-slides, and segment dictionary are all read from the
# off-source trigger file
logging.info("Loading triggers.")
trigs = ppu.load_xml_table(trig_file, lsctables.MultiInspiralTable.tableName)
logging.info("%d triggers loaded.", len(trigs))
logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)
logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Build the trials and count them over all slides
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(opts.seg_files, segment_dict,
                                  ifos, slide_dict, vetoes)
total_trials = 0
for slide_id in slide_dict:
    total_trials += len(trial_dict[slide_id])
logging.info("%d trials generated.", total_trials)

# Basic trigger properties, stored as dictionaries keyed by slide
trig_time, trig_snr, trig_bestnr = ppu.extract_basic_trig_properties(
    trial_dict, trigs, slide_dict, segment_dict, opts)

# Loudest BestNR within every trial of every slide; trials that contain no
# trigger keep the initial value of zero
time_veto_max_bestnr = {}
for slide_id in slide_dict:
    slide_maxima = np.zeros(len(trial_dict[slide_id]))
    for trial_idx, trial in enumerate(trial_dict[slide_id]):
        in_trial = (trial[0] <= trig_time[slide_id]) & \
                   (trig_time[slide_id] < trial[1])
        if in_trial.any():
            slide_maxima[trial_idx] = max(trig_bestnr[slide_id][in_trial])
    time_veto_max_bestnr[slide_id] = slide_maxima

logging.info("SNR and bestNR maxima calculated.")

# Pair each off-source trigger with its BestNR and order the pairs from the
# loudest to the quietest
sorted_trigs = ppu.sort_trigs(trial_dict, trigs, slide_dict, segment_dict)
offsource_trigs = []
for slide_id in slide_dict:
    offsource_trigs += zip(trig_bestnr[slide_id], sorted_trigs[slide_id])
offsource_trigs = sorted(offsource_trigs, key=lambda pair: pair[0])[::-1]

# ==========================
# Background statistics
# NOTE(review): judging by the variable names, max_bestnr is the loudest
# background BestNR and full_time_veto_max_bestnr gathers the per-trial
# maxima of all slides -- confirm against ppu.max_median_stat.  The second
# returned value is not used by this script.
# ==========================
max_bestnr, _, full_time_veto_max_bestnr = ppu.max_median_stat(
    slide_dict, time_veto_max_bestnr, trig_bestnr, total_trials)


# =======================
# Load on source triggers
# =======================
if onsource_file:

    # Get trigs
    on_trigs = ppu.load_xml_table(onsource_file, "multi_inspiral")
    logging.info("%d onsource triggers loaded.", len(on_trigs))

    # Defaults describing an on-source without events: no loudest trigger,
    # zero BestNR and a false-alarm probability of one.  Setting them up
    # front guarantees all three names exist whatever branch is taken below.
    loud_on_bestnr_trigs = None
    loud_on_bestnr = 0
    loud_on_fap = 1

    # Calculate BestNRs and record loudest trig by BestNR
    if on_trigs:
        on_trigs_bestnrs = ppu.get_bestnrs(on_trigs,
                                           q=chisq_index, n=chisq_nhigh,
                                           null_thresh=null_thresh,
                                           snr_threshold=snr_thresh,
                                           sngl_snr_threshold=sngl_snr_thresh,
                                           chisq_threshold=new_snr_thresh,
                                           null_grad_thresh=null_grad_thresh,
                                           null_grad_val=null_grad_val)
        # argmax picks the first trigger reaching the maximum BestNR, which
        # is the one a stable descending sort would place first, without
        # having to sort the whole table
        loudest_idx = int(np.argmax(on_trigs_bestnrs))
        loud_on_bestnr_trigs = on_trigs[loudest_idx]
        loud_on_bestnr = on_trigs_bestnrs[loudest_idx]

    # If the loudest event has bestnr = 0, there is no event at all!
    if loud_on_bestnr == 0:
        loud_on_bestnr_trigs = None
        loud_on_fap = 1

    logging.info("Onsource analysed.")

    # False-alarm probability of the loudest on-source event: fraction of
    # off-source trials whose loudest BestNR exceeds the on-source one
    if loud_on_bestnr_trigs is not None:
        num_trials_louder = 0
        for slide_id in slide_dict:
            num_trials_louder += np.count_nonzero(
                time_veto_max_bestnr[slide_id] > loud_on_bestnr)
        loud_on_fap = num_trials_louder/total_trials

else:
    # Without on-source data, the median of the loudest BestNR values of
    # the off-source trials is computed instead.  The leading empty array
    # keeps the concatenation valid when there are no slides.
    tot_off_snr = np.concatenate(
        [np.array([])] +
        [time_veto_max_bestnr[slide_id] for slide_id in slide_dict])
    med_snr = np.median(tot_off_snr)

# =======================
# Post-process injections
# =======================
if do_injections:

    # Triggers and injections recovered in some form: discard ones at vetoed
    # times
    found_trigs = ppu.load_injections(found_file, vetoes)
    found_injs = ppu.load_injections(found_file, vetoes, sim_table=True)

    logging.info("Missed/found injections/triggers loaded.")

    # Extract columns of found injections
    found_inj = {}
    found_inj['ra'] = np.asarray(found_injs.get_column('longitude'))
    found_inj['dec'] = np.asarray(found_injs.get_column('latitude'))
    found_inj['time'] = \
        np.asarray(found_injs.get_column('geocent_end_time')) + \
        np.asarray(found_injs.get_column('geocent_end_time_ns') * 10**-9)
    found_inj['dist'] = np.asarray(found_injs.get_column('distance'))

    # BestNR of the triggers associated to the found injections
    found_trig = {}
    found_trig['bestnr'] = ppu.get_bestnrs(found_trigs, q=chisq_index,
                                           n=chisq_nhigh,
                                           null_thresh=null_thresh,
                                           snr_threshold=snr_thresh,
                                           sngl_snr_threshold=sngl_snr_thresh,
                                           chisq_threshold=new_snr_thresh,
                                           null_grad_thresh=null_grad_thresh,
                                           null_grad_val=null_grad_val)

    # Construct conditions for injection:
    # 1) found louder than background,
    zero_fap = found_trig['bestnr'] > max_bestnr

    # 2) found (bestnr > 0) but not louder than background (non-zero FAP)
    nonzero_fap = ~zero_fap & (found_trig['bestnr'] != 0)

    # 3) missed after being recovered (i.e., vetoed) are not used here
    # missed = (~zero_fap) & (~nonzero_fap)

    # Non-zero FAP triggers (g_ifar): their statistic is the fraction of
    # off-source trials with a loudest BestNR above the trigger BestNR
    g_ifar = {}
    g_ifar['bestnr'] = found_trig['bestnr'][nonzero_fap]
    g_ifar['stat'] = np.array(
        [(full_time_veto_max_bestnr > bestnr).sum()
         for bestnr in g_ifar['bestnr']], dtype=float) / total_trials

    # Set the sigma values
    inj_sigma = found_trigs.get_sigmasqs()
    # If the sigmasqs are not populated, we can still do calibration errors,
    # but only in the 1-detector case
    for ifo in ifos:
        if np.any(inj_sigma[ifo] == 0):
            logging.info("%s: sigmasq not set for at least one trigger.", ifo)
        if not np.any(inj_sigma[ifo] != 0):
            logging.info("%s: sigmasq not set for any trigger.", ifo)
            if len(ifos) == 1:
                msg = "Since this is a single ifo analysis, sigmasq will be "
                msg += "set to unity for all triggers in order to build the "
                msg += "calibration errors."
                logging.info(msg)
                inj_sigma[ifo][:] = 1.

    # Antenna response of each detector at the found injections
    f_resp = {}
    for ifo in ifos:
        antenna = Detector(ifo)
        f_resp[ifo] = ppu.get_antenna_responses(antenna,
                                                found_inj['ra'], found_inj['dec'],
                                                found_inj['time'])

    # Sum over detectors of sigma times antenna response.  The products are
    # formed detector by detector (by key) so that the result does not rely
    # on inj_sigma and f_resp listing the detectors in the same order.
    inj_sigma_tot = sum(inj_sigma[ifo] * f_resp[ifo] for ifo in ifos)

    # Mean relative contribution of each detector to that sum
    inj_sigma_mean = {}
    for ifo in ifos:
        inj_sigma_mean[ifo] = ((inj_sigma[ifo]*f_resp[ifo])/inj_sigma_tot).mean()

    logging.info("%d found injections analysed.", len(found_injs))

    # Missed injections (ones not recovered at all)
    missed_injs = ppu.load_injections(missed_file, vetoes, sim_table=True)

    # Process missed injections 'missed_inj'
    missed_inj = {}
    missed_inj['dist'] = np.asarray(missed_injs.get_column('distance'))

    logging.info("%d missed injections analysed.", len(missed_injs))

    # Create new set of injections for efficiency calculations: they are
    # drawn uniformly in distance, starting from upper_dist and over a range
    # as wide as [lower_dist, upper_dist]
    total_injs = len(found_injs) + len(missed_injs)
    long_inj = {}
    long_inj['dist'] = stats.uniform.rvs(size=total_injs) * (upper_dist-lower_dist) +\
                       upper_dist

    logging.info("%d long distance injections created.", total_injs)

    # Set distance bins and data arrays: bins of equal width start at
    # lower_dist and their lower edges stop before 2*upper_dist - lower_dist
    bin_width = (upper_dist - lower_dist) / num_bins
    bin_lower_edges = np.arange(lower_dist,
                                upper_dist + (upper_dist - lower_dist),
                                bin_width)
    dist_bins = list(zip(bin_lower_edges, bin_lower_edges + bin_width))

    # One slot per distance bin, plus a last one accumulating all bins
    num_dist_bins_plus_one = len(dist_bins) + 1
    num_injections = {}
    found_max_bestnr = {}
    found_on_bestnr = {}
    for key in ['mc', 'no_mc']:
        num_injections[key] = np.zeros(num_dist_bins_plus_one)
        found_max_bestnr[key] = np.zeros(num_dist_bins_plus_one)
        found_on_bestnr[key] = np.zeros(num_dist_bins_plus_one)

    # Construct FAP list for all found injections (zero unless the
    # injection falls in the non-zero FAP category)
    inj_fap = np.zeros(len(found_injs))
    inj_fap[nonzero_fap] = g_ifar['stat']

    # Calculate the amplitude error
    # Begin by calculating the components from each detector, which are
    # then added in quadrature
    cal_error = 0
    for ifo in ifos:
        cal_error += cal_errs[ifo]**2 * inj_sigma_mean[ifo]**2
    cal_error = cal_error**0.5

    max_dc_cal_error = max(cal_dc_errs.values())

    # Calibration phase uncertainties are neglected
    logging.info("Calibration amplitude uncertainty calculated.")

    # Now create the numbers for the efficiency plots; these include calibration
    # and waveform errors. These are incorporated by running over each injection
    # num_mc_injs times, where each time we draw a random value of distance

    # Distribute injections
    # NOTE: the loop on num_mc_injs would fill up the *_inj['dist_mc']'s at the
    # same time, so filling them up sequentially will vary the numbers a little
    # (this is an MC, order of operations matters!)
    found_inj['dist_mc'] = ppu.mc_cal_wf_errs(num_mc_injs, found_inj['dist'],
                                              cal_error, wav_err, max_dc_cal_error)
    missed_inj['dist_mc'] = ppu.mc_cal_wf_errs(num_mc_injs, missed_inj['dist'],
                                               cal_error, wav_err, max_dc_cal_error)
    long_inj['dist_mc'] = ppu.mc_cal_wf_errs(num_mc_injs, long_inj['dist'],
                                             cal_error, wav_err, max_dc_cal_error)

    logging.info("MC injection set distributed with %d iterations.",\
                 num_mc_injs)

    # Check injections against on source: the reference FAP and BestNR are
    # those of the loudest on-source event or, without on-source data, 0.5
    # and the median of the off-source trial maxima
    if onsource_file:
        reference_fap = loud_on_fap
        reference_bestnr = loud_on_bestnr
    else:
        reference_fap = 0.5
        reference_bestnr = med_snr
    more_sig_than_onsource = (inj_fap <= reference_fap)

    # Louder than the whole background (same condition as zero FAP)
    max_bestnr_cut = zero_fap

    # Check louder than on source
    on_bestnr_cut = found_trig['bestnr'] > reference_bestnr

    # Check whether injection is found for the purposes of exclusion
    # distance calculation.
    # Found: if louder than all on source
    # Missed: if not louder than loudest on source
    found_excl = on_bestnr_cut & (more_sig_than_onsource) & \
                (found_trig['bestnr'] != 0)
    # If not missed, double check bestnr against nearby triggers: the
    # injection is kept only if no trigger within cluster_window of it has
    # a BestNR that, scaled by glitch_check_fac, exceeds the injection one.
    # The indices are collected before the loop modifies found_excl.
    for inj_idx in np.flatnonzero(found_excl):
        # 0 is the zero-lag timeslide
        near_bestnr = trig_bestnr[0]\
                      [np.abs(trig_time[0]-found_inj['time'][inj_idx])
                       < cluster_window]
        found_excl[inj_idx] = not (near_bestnr * glitch_check_fac >
                                   found_trig['bestnr'][inj_idx]).any()

    # Loop over each random instance of the injection set: the first one
    # (k = 0) is stored under 'no_mc', all the others are accumulated under
    # 'mc'
    for k in range(num_mc_injs+1):
        key = 'no_mc' if k == 0 else 'mc'
        found_dist = found_inj['dist_mc'][k, :]
        missed_dist = missed_inj['dist_mc'][k, :]
        long_dist = long_inj['dist_mc'][k, :]
        # Loop over the distance bins
        for j, dist_bin in enumerate(dist_bins):
            # Construct distance cut
            found_dist_cut = (dist_bin[0] <= found_dist) &\
                             (found_dist < dist_bin[1])
            missed_dist_cut = (dist_bin[0] <= missed_dist) &\
                              (missed_dist < dist_bin[1])
            long_dist_cut = (dist_bin[0] <= long_dist) &\
                            (long_dist < dist_bin[1])

            # Count all injections in this distance bin
            num_pass = found_dist_cut.sum() + missed_dist_cut.sum() + \
                long_dist_cut.sum()
            # Count only zero FAR injections
            num_zero_far = (found_dist_cut & max_bestnr_cut).sum()
            # Count number found for exclusion
            num_excl = (found_dist_cut & (found_excl)).sum()

            # Record number of injections, number found for exclusion
            # and number of zero FAR, both in the bin and in the last slot
            num_injections[key][j] += num_pass
            num_injections[key][-1] += num_pass
            found_max_bestnr[key][j] += num_zero_far
            found_max_bestnr[key][-1] += num_zero_far
            found_on_bestnr[key][j] += num_excl
            found_on_bestnr[key][-1] += num_excl

    logging.info("Found/missed injection efficiency calculations completed.")

# Post-processing of injections ends here

# ==========
# Make plots
# ==========

if do_injections:
    # Calculate distances (horizontal axis) as means
    dist_plot_vals = [np.asarray(dist_bin).mean() for dist_bin in dist_bins]
    # Upper end of the horizontal axis, shared by both plots
    max_plot_dist = 2.*upper_dist - lower_dist

    # Calculate error bars for efficiency/distance plots and datafiles
    # using max BestNR of background
    yerr_low_mc, yerr_high_mc, fraction_mc = efficiency_with_errs(\
         found_max_bestnr['mc'], num_injections['mc'], num_mc_injs=num_mc_injs)
    yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = efficiency_with_errs(\
                               found_max_bestnr['no_mc'], num_injections['no_mc'])

    # Calculate the 50% sensitive distance by linear interpolation between
    # the first bin with efficiency below 50% and the bin preceding it
    eff_low = fraction_no_mc
    eff_idx = np.where(eff_low < 0.5)[0]
    if eff_idx.size == 0:
        sens_dist = -1
        err_msg = "Efficiency does not drop below 50%!"
        logging.error(err_msg)
    elif eff_idx[0] == 0:
        # There is no preceding bin to interpolate with
        sens_dist = 0
        err_msg = "Efficiency below 50% in first bin!"
        logging.error(err_msg)
    else:
        i = eff_idx[0]
        d = dist_plot_vals[i]
        d_low = dist_plot_vals[i-1]
        e = eff_low[i]
        e_low = eff_low[i-1]
        # TODO: insert this in the output pages
        sens_dist = d + (e - 0.5) * (d - d_low) / (e_low - e)

    # Plot efficiency using loudest background
    fig = plt.figure()
    ax = fig.gca()
    ax.plot(dist_plot_vals, (fraction_no_mc), 'g-',
            label='No marginalisation')
    ax.errorbar(dist_plot_vals, (fraction_no_mc),
                yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
    marg_eff = fraction_mc
    # The marginalised curve is skipped if any of its values is NaN
    if not np.isnan(marg_eff.sum()):
        ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
        ax.errorbar(dist_plot_vals, marg_eff, yerr=[yerr_low_mc, yerr_high_mc],
                    c='red')
    ax.legend()
    # Switch the grid on explicitly: grid() without arguments only toggles
    ax.grid(True)
    ax.set_ylim([0, 1])
    ax.set_xlim(0, max_plot_dist)
    ax.set_ylabel("Fraction of injections found louder than loudest background")
    ax.set_xlabel("Distance (Mpc)")
    plot_title = "Detection efficiency - "+inj_set_name
    plot_caption = "Injection recovery efficiency using "
    plot_caption += "BestNR as detection statistic.  "
    plot_caption += "Injections louder than loudest background trigger."
    fig_path = opts.background_output_file
    save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                           title=plot_title, caption=plot_caption)
    plt.close()

    # Calculate error bars for efficiency/distance plots and datafiles
    # using max BestNR of foreground
    yerr_low_no_mc, yerr_high_no_mc, fraction_no_mc = efficiency_with_errs(\
                                found_on_bestnr['no_mc'], num_injections['no_mc'])
    yerr_low, yerr_high, fraction_mc = efficiency_with_errs(found_on_bestnr['mc'],\
                                     num_injections['mc'], num_mc_injs=num_mc_injs)

    # Marginalized efficiency (isf = inverse survival function): the lower
    # error bar, scaled by the normal quantile exceeded with 10%
    # probability, is subtracted from the marginalised fraction
    red_efficiency = (fraction_mc) - (yerr_low) * scipy.stats.norm.isf(0.1)

    # Calculate the 50% and 90% exclusion distances
    excl_dists = {}
    for percentile in [50, 90]:
        target_eff = percentile / 100.
        excl_efficiency = red_efficiency
        eff_idx = np.where(excl_efficiency < target_eff)[0]
        if eff_idx.size == 0:
            # The marginalised efficiency never drops below the target:
            # fall back on the non-marginalised one
            excl_efficiency = fraction_no_mc
            eff_idx = np.where(excl_efficiency < target_eff)[0]
        if eff_idx.size == 0:
            excl_dist = 0
            err_msg = "Efficiency does not drop below %d%%!" % (percentile)
            logging.error(err_msg)
        elif eff_idx[0] == 0:
            # There is no preceding bin to interpolate with
            excl_dist = 0
            err_msg = "Efficiency below %d%% in first bin!" % (percentile)
            logging.error(err_msg)
        else:
            i = eff_idx[0]
            d = dist_plot_vals[i]
            d_low = dist_plot_vals[i-1]
            e = excl_efficiency[i]
            e_low = excl_efficiency[i-1]
            excl_dist = d + (e - target_eff) * (d - d_low) / (e_low - e)
        # TODO: include percentile, excl_dist on output pages
        excl_dists[percentile] = excl_dist

    # Plot efficiency using loudest foreground
    fig = plt.figure()
    ax = fig.gca()
    ax.plot(dist_plot_vals, (fraction_no_mc), 'g-',
            label='No marginalisation')
    ax.errorbar(dist_plot_vals, (fraction_no_mc),
                yerr=[yerr_low_no_mc, yerr_high_no_mc], c='green')
    marg_eff = fraction_mc
    # Curves containing NaN values are skipped
    if not np.isnan(marg_eff.sum()):
        ax.plot(dist_plot_vals, marg_eff, 'r-', label='Marginalised')
        ax.errorbar(dist_plot_vals, marg_eff, yerr=[yerr_low, yerr_high], c='red')
    if not np.isnan(red_efficiency.sum()):
        ax.plot(dist_plot_vals, red_efficiency, 'm-',
                label='Inc. counting errors')
    ax.legend()
    ax.get_legend().get_frame().set_alpha(0.5)
    # Switch the grid on explicitly: grid() without arguments only toggles
    ax.grid(True)
    ax.set_ylabel("Fraction of injections found louder than "+\
                  "loudest foreground")
    ax.set_xlabel("Distance (Mpc)")
    # Mark the 90% exclusion distance
    ax.plot([excl_dists[90]], [0.9], 'gx')
    ax.set_ylim([0, 1])
    ax.set_xlim(0, max_plot_dist)
    plot_title = "Exclusion distance - "+inj_set_name
    plot_caption = "Injection recovery efficiency using "
    plot_caption += "BestNR as detection statistic.  "
    plot_caption += "Injections louder than loudest foreground trigger."
    fig_path = opts.onsource_output_file
    save_fig_with_metadata(fig, fig_path, cmd=' '.join(sys.argv),
                           title=plot_title, caption=plot_caption)
    plt.close()

logging.info("Plots complete.")
