#!/usr/bin/env python
"""Plots the recovered versus injected parameter values from a population
of injections.
"""

import sys
import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy
import pycbc
import pycbc.version
from matplotlib import cm
from pycbc import inject
from pycbc import __version__
from pycbc.inference import (option_utils, io)
from pycbc.results import save_fig_with_metadata

# Build the command-line interface. The parser comes from pycbc.inference.io
# and uses this module's docstring as its description.
parser = io.ResultsArgumentParser(description=__doc__)

# Script-specific options, as (flag, add_argument keyword arguments) pairs.
# They are registered in the order listed.
_script_options = (
    ("--version", dict(help="Prints version information.",
                       action="version", version=__version__)),
    ("--output-file", dict(help="Path to save output plot.",
                           type=str, required=True)),
    ("--verbose", dict(help="Allows print statements.",
                       action="store_true")),
    ("--percentiles", dict(help="Percentiles to use as limits.",
                           type=float, nargs=2, default=[5, 95])),
)
for _flag, _kwargs in _script_options:
    parser.add_argument(_flag, **_kwargs)

# Option groups supplied by pycbc.inference.option_utils: the scatter options
# first, then the injection-samples map option.
for _add_options in (option_utils.add_scatter_option_group,
                     option_utils.add_injsamples_map_opt):
    _add_options(parser)

opts = parser.parse_args()

# set logging
pycbc.init_logging(opts.verbose)

# read results
fp, parameters, labels, samples = io.results_from_cli(opts)

# This script plots exactly one parameter. Raise explicitly instead of using
# ``assert`` so the check still runs when Python is started with -O.
if len(parameters) != 1:
    raise ValueError("exactly one parameter must be given to plot; got {}: {}"
                     .format(len(parameters), list(parameters)))
parameter = parameters[0]
label = labels[parameter]

# create figure
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111)

# typecast to list for iteration: a single input file may yield a bare object
# rather than a list, and the code below loops over files
samples = samples if isinstance(samples, list) else [samples]
fp = fp if isinstance(fp, list) else [fp]

# load the injected parameter values
injs = io.injections_from_cli(opts)

# If a z-arg is specified, load samples for it from every input file. The
# median of each file's z samples colors that file's error bar, and the
# overall sample range gives the default colorbar limits.
if opts.z_arg is not None:
    logging.info("Getting samples for colorbar")
    zlbl = opts.z_arg_labels[opts.z_arg]
    zvals = []
    min_zval = numpy.inf
    max_zval = -numpy.inf
    # NOTE: the loop variable must not be called ``fp``; ``fp`` is the list of
    # open files that the plotting loop iterates over afterwards.
    for zfp in fp:
        zsamples = zfp.samples_from_cli(opts, parameters=opts.z_arg)
        vals = zsamples[opts.z_arg]
        zvals.append(numpy.median(vals))

        # update range of colorbar
        min_zval = min(min_zval, vals.min())
        max_zval = max(max_zval, vals.max())

    # create colormap; plt.get_cmap is used because matplotlib.cm.get_cmap
    # is no longer available in recent matplotlib releases
    cmap = plt.get_cmap(opts.scatter_cmap)
    # compare against None so that an explicit limit of 0 is honored
    vmin = opts.vmin if opts.vmin is not None else min_zval
    vmax = opts.vmax if opts.vmax is not None else max_zval
    norm = mpl.colors.Normalize(vmin, vmax)

# Draw one vertical error bar per input file: x is the injected value, and the
# bar spans the requested percentiles of (recovered - injected), split at the
# median.
logging.info("Plotting")
_per_file = zip(opts.input_file, fp, samples)
for idx, (_, _, file_samples) in enumerate(_per_file):
    # recovered samples and the injected value for this file
    recovered = file_samples[parameter]
    injected = injs[parameter][idx]

    # credible-interval bounds; the two percentiles may be given in any order
    bounds = [numpy.percentile(recovered, q) for q in opts.percentiles]
    lower = numpy.min(bounds)
    upper = numpy.max(bounds)
    median = numpy.median(recovered)

    # color by the z-arg median when one was requested, otherwise black
    bar_color = cmap(norm(zvals[idx])) if opts.z_arg else "black"

    # error bar only: no connecting line between injections
    ax.errorbar([injected],
                [median - injected],
                yerr=[[median - lower], [upper - median]],
                ecolor=bar_color, linestyle="None", zorder=10)

# Create a colorbar. ``cmap``, ``norm`` and ``zlbl`` are only defined when a
# z-arg was given, so use the same ``is not None`` test that guards them.
if opts.z_arg is not None:
    cax, _ = cbar.make_axes(ax)
    cb2 = cbar.ColorbarBase(cax, cmap=cmap, norm=norm)
    # the z-arg label is stored in ``zlbl``
    cb2.set_label(r"Recovered Median " + zlbl)

# set labels; the y axis shows the difference between the recovered median
# and the injected value
ax.set_ylabel(r"Recovered {0} - Injected {0}".format(label))
ax.set_xlabel(r"Injected {}".format(label))

# add grid to plot
ax.grid()

# add a zero-difference reference line (recovered == injected)
ax.axhline(0, linestyle="dashed", color="gray", zorder=9)

# save plot, recording the caption and the command line alongside the figure
caption = ('Difference between recovered and injected value vs injected '
           'value. The vertical lines show the width of the {} to {} '
           'percentile credible interval.'
           .format(opts.percentiles[0], opts.percentiles[1]))
save_fig_with_metadata(fig, opts.output_file,
                       caption=caption,
                       cmd=' '.join(sys.argv),
                       fig_kwds={'bbox_inches': 'tight'})

# done
logging.info("Done")

