#!/usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import itertools
import logging
import sys

import numpy

from matplotlib import use
use('agg')
from matplotlib import pyplot as plt

import pycbc
from pycbc import __version__
from pycbc import (distributions, results, waveform)
from pycbc.inference.option_utils import (ParseParametersArg,
                                          prior_from_config)
from pycbc.workflow import WorkflowConfigParser
from pycbc.results.scatter_histograms import create_multidim_plot


def cartesian(arrays):
    """Build the cartesian product of several iterables as a numpy array.

    Parameters
    ----------
    arrays : list of iterables
        The iterables to combine; they are unpacked into
        ``itertools.product``.

    Returns
    -------
    numpy.ndarray
        Array with one entry per combination, in the order in which
        ``itertools.product`` yields them.
    """
    rows = []
    for combination in itertools.product(*arrays):
        # convert every combination on its own before stacking them
        rows.append(numpy.array(combination))
    return numpy.array(rows)

# command line usage
parser = argparse.ArgumentParser(
    usage="pycbc_inference_plot_prior [--options]",
    description="Plots prior distributions.")
# add input options
parser.add_argument("--config-files", type=str, nargs="+", required=True,
                    help="A file parsable by "
                         "pycbc.workflow.WorkflowConfigParser.")
# ParseParametersArg handles the PARAM[:LABEL] syntax; the script later reads
# both opts.parameters and opts.parameters_labels
parser.add_argument("--parameters", nargs="+", action=ParseParametersArg,
                    metavar="PARAM[:LABEL]",
                    help="Only plot the given parameters. May also provide a "
                         "label for each parameter. If provided, the "
                         "parameters can be any of the parameters in the "
                         "[variable_params] section, derived parameters from "
                         "them, or any function of them. Syntax for "
                         "functions is python; any math functions in the "
                         "numpy library may be used. Can optionally also "
                         "specify a LABEL for each parameter. If no LABEL is "
                         "provided, PARAM will be used as the LABEL. If LABEL "
                         "is the same as a parameter in "
                         "pycbc.waveform.parameters, the label "
                         "property of that parameter will be used. If not "
                         "provided, will plot all of the parameters in the "
                         "[variable_params] section of the config file.")
parser.add_argument("--prior-section", type=str, default="prior",
                    help="Name of section with distribution configurations. "
                         "Default is 'prior'.")
# add output options
parser.add_argument("--nsamples", type=int, default=10000,
                    help="The number of samples to draw from the prior for "
                         "plotting. Default is 10000.")
parser.add_argument("--output-file", type=str, required=True,
                    help="Path to output plot.")
# verbose option; the flag is handed to pycbc.init_logging after parsing
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Enable verbose logging output.")
parser.add_argument("--version", action="version", version=__version__,
                    help="show version number and exit")

# read the options given on the command line
opts = parser.parse_args()

# configure logging according to the --verbose flag
pycbc.init_logging(opts.verbose)

# load every configuration file into a single parser
logging.info("Reading configuration files")
config_parser = WorkflowConfigParser(opts.config_files)

# build the prior from the requested section of the configuration
logging.info("Loading prior")
prior = prior_from_config(config_parser, prior_section=opts.prior_section)

# draw the samples that will be plotted
logging.info("Drawing samples from prior")
samples = prior.rvs(opts.nsamples)

# decide which parameters are plotted and how each one is labelled
if opts.parameters is not None:
    # explicit request: use the parameters and labels stored on the options
    parameters = opts.parameters
    labels = opts.parameters_labels
else:
    # no request: plot every variable argument of the prior, labelled with
    # the ``label`` of the matching attribute in the waveform module when
    # there is one, and with the parameter name itself otherwise
    parameters = prior.variable_args
    labels = {
        name: getattr(getattr(waveform.parameters, name, None), "label", name)
        for name in parameters}

logging.info("Plotting")
# only the figure is needed afterwards; the second return value is discarded
fig, _ = create_multidim_plot(
    parameters, samples,
    labels=labels,
    plot_marginal=True,
    marginal_percentiles=[5, 50, 95],
    plot_scatter=False,
    plot_density=True,
    plot_contours=True,
    contour_percentiles=[50, 90],)


# set DPI
fig.set_dpi(200)

# save figure with meta-data; the command line is recorded alongside the plot
caption = """This plot shows the probability density function (PDF) from the 
prior distributions."""
title = "Prior Distribution"
results.save_fig_with_metadata(fig, opts.output_file,
                               cmd=" ".join(sys.argv),
                               title=title,
                               caption=caption)
# close the figure that was just saved rather than whichever one happens to
# be pyplot's current figure
plt.close(fig)

# exit
logging.info("Done")
