#!/usr/bin/env python

# Copyright (C) 2017 Christopher M. Biwer, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Plots the fraction of injections with their parameter value recovered
within a credible interval versus credible interval.
"""

import sys
import argparse
import logging
import itertools

import numpy

from scipy import stats

import matplotlib
matplotlib.use('agg')
from matplotlib import pyplot as plt
import matplotlib.style as style

import pycbc
import pycbc.results.plot
from pycbc.results import save_fig_with_metadata
from pycbc.inference import (option_utils, io)

from pycbc import __version__

# build the command-line interface on top of the results parser from
# pycbc.inference.io, adding the options specific to this script
parser = io.ResultsArgumentParser(description=__doc__)
parser.add_argument(
    '--version', action='version', version=__version__,
    help='show version number and exit')
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Path to save output plot.")
parser.add_argument(
    "--verbose", action="store_true",
    help="Allows print statements.")
parser.add_argument(
    "--injection-hdf-group", default="injections",
    help=("HDF group that contains injection values. "
          "Default is 'injections'."))
parser.add_argument(
    "--do-ks-test", action="store_true", default=False,
    help=("Perform a KS test between the percentile-percentile "
          "plot for each parameter and the (expected) uniform "
          "distribution. Results are printed in the legend."))
# options contributed by shared helpers: the injection/samples map option
# and the plot-style option (default style: seaborn-colorblind)
option_utils.add_injsamples_map_opt(parser)
pycbc.results.plot.add_style_opt_to_parser(
    parser, default='seaborn-colorblind')
opts = parser.parse_args()

# apply the plotting style requested on the command line
pycbc.results.plot.set_style_from_cli(opts)

# configure logging; --verbose controls how much is reported
pycbc.init_logging(opts.verbose)

# read the parameters to plot, their labels, and the samples; the first
# returned item is not needed here
logging.info('Loading parameters')
_, parameters, labels, samples = io.results_from_cli(opts)

# results_from_cli may return something other than a list (handled here by
# wrapping it), so make sure the samples can be zipped with the input files
if not isinstance(samples, list):
    samples = [samples]

# loop over input files and its samples
logging.info("Plotting")
# maps each parameter name to a list holding one measured percentile per
# input file, in the order the files were given
measured_percentiles = {}

# opts.input_file is pointed at one file at a time before the injections are
# loaded, so keep the full list around to iterate over and to restore later
# NOTE(review): assumes io.injections_from_cli reads opts.input_file
input_files = opts.input_file
for input_file, input_samples in zip(input_files, samples):
    # load the injections for this file only
    opts.input_file = input_file
    inj_parameters = io.injections_from_cli(opts)

    for param in parameters:
        # percentage of the samples that are <= the injected value
        measured = stats.percentileofscore(
            input_samples[param], inj_parameters[param], kind='weak')
        measured_percentiles.setdefault(param, []).append(measured)

# undo the temporary override so the parsed options again describe every
# input file rather than just the last one processed
opts.input_file = input_files

# Build endless iterators over colors and line styles. Every color of the
# active style is used with a solid line first, then with a dotted line, then
# with a dash-dot line, giving 3 * (number of colors) unique combinations
# before repeating (seaborn-colorblind has 6 colors, so 18 by default).
prop_cycle = matplotlib.rcParams['axes.prop_cycle']
colors = itertools.cycle([entry['color'] for entry in prop_cycle])
line_styles = itertools.cycle(
    [dash for dash in ('-', ':', '-.') for _ in range(len(prop_cycle))])

# create figure for plotting
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111)
# for each parameter, plot the fraction of injections recovered at or below
# each measured percentile against that percentile
for param in parameters:
    label = labels[param]
    # flatten before sorting so the sort runs over injections even if each
    # stored percentile is a length-1 array rather than a scalar
    # NOTE(review): whether that happens depends on what
    # io.injections_from_cli returns for a parameter
    meas = numpy.sort(numpy.ravel(measured_percentiles[param]))
    # fraction of injections with a measured percentile <= each value; on a
    # sorted array this is the right-hand insertion index divided by the
    # number of injections (same result as a 'weak' percentileofscore of
    # every element, without one full pass over the array per element)
    expected = (numpy.searchsorted(meas, meas, side='right')
                / float(meas.size))
    # percentiles are in [0, 100]; the axes use fractions in [0, 1]
    fraction = meas / 100.
    # perform ks test against the uniform distribution on [0, 1]
    if opts.do_ks_test:
        ks, pvalue = stats.kstest(fraction, 'uniform')
        label = '{} $D_{{KS}}$: {:.3f} p-value: {:.3f}'.format(
            label, ks, pvalue)
    ax.plot(fraction, expected, c=next(colors), ls=next(line_styles),
            label=label)

# lay the legend out in more columns as the number of parameters grows:
# one column for up to 9 entries, two for up to 18, three beyond that
entries_per_column = 9
nparams = len(parameters)
legend_columns = (3 if nparams > 2 * entries_per_column
                  else 2 if nparams > entries_per_column
                  else 1)
ax.legend(ncol=legend_columns)

# axis labels
ax.set_ylabel(r"Fraction of Injections Recovered in Credible Interval")
ax.set_xlabel(r"Credible Interval")

# background grid
ax.grid()

# reference diagonal (the 1:1 line); it carries no label, so it does not
# appear in the legend
ax.plot([0, 1], [0, 1], color="gray", linestyle="dashed", zorder=9)

# save plot
caption = 'Percentile-percentile plot.'
if opts.do_ks_test:
    # the KS statistic and p-value only appear in the legend when the test
    # was requested, so only describe them in the caption in that case
    caption += (
        ' The value of the KS statistic '
        '$D_{KS}$ is given in the legend. This gives the maximum distance '
        'between the observed line and the expected (dashed) line; i.e., '
        'it gives the maximum distance between the measured CDF and the '
        'expected (uniform) CDF. The associated two-tailed p-value gives '
        'the probability of getting a maximum distance (either above '
        'or below the expected line) larger than the observed $D_{KS}$ '
        'assuming that the measured CDF is the same as the expected. '
        'In other words, the larger (smaller) the p-value ($D_{KS}$), the '
        'more likely the measured distribution is the same as the '
        'expected.')
# write the figure along with its caption and the command line that made it
save_fig_with_metadata(fig, opts.output_file,
                       caption=caption,
                       cmd=' '.join(sys.argv),
                       fig_kwds={'bbox_inches': 'tight'})

# done
logging.info("Done")

