#!/usr/bin/env python
# Copyright (C) 2017 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Plots the Gelman-Rubin convergence diagnostic statistic.
"""

import argparse
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.pyplot as plt
import sys
from pycbc import results
from pycbc import __version__ 
from pycbc.inference import (gelman_rubin, io, option_utils)

# build the command line parser; skip_args presumably keeps the parser from
# adding its own walkers option, since --walkers is defined explicitly below
# (TODO confirm against io.ResultsArgumentParser)
parser = io.ResultsArgumentParser(skip_args=['walkers'])

# general options
parser.add_argument(
    "--verbose", action="store_true", default=False,
    help="Print logging info.")
parser.add_argument(
    '--version', action='version', version=__version__,
    help='show version number and exit')

# plot options
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Path to output plot.")
parser.add_argument(
    "--walkers", type=int, nargs="+", default=None,
    help="Specific walkers to plot. Default is plot all walkers.")
parser.add_argument(
    "--hline", type=float, default=None,
    help="Plots a horizontal line.")

# the three required integer options that define the chain segments over
# which the statistic is calculated; added in this order so that the help
# output lists them as start, end, step
for _flag, _help in (
        ("--segment-start", "Index in chain to start calculation."),
        ("--segment-end", "Index in chain to end calculation."),
        ("--segment-step", "Step size in chain to next calculation.")):
    parser.add_argument(_flag, type=int, required=True, help=_help)

# read the options given on the command line
opts = parser.parse_args()

# configure logging: DEBUG messages with --verbose, otherwise only warnings
# and above
log_level = logging.DEBUG if opts.verbose else logging.WARN
logging.basicConfig(level=log_level, format="%(asctime)s : %(message)s")

# enforce that a single iteration was not requested; the statistic is
# plotted over a range of iterations
if opts.iteration is not None:
    raise ValueError("Cannot use --iteration")

# open the results without loading samples; the samples are read later,
# one walker at a time (the fourth returned value is not needed here)
fp, params, labels, _ = io.results_from_cli(opts, load_samples=False,
                                            walkers=None)

# use the walkers given on the command line, or every walker in the file
if opts.walkers is None:
    walkers = range(fp.nwalkers)
else:
    walkers = opts.walkers

# create Figure
fig = plt.figure()

# segment end points returned for the most recently plotted parameter; this
# stays empty if there are no parameters to plot, so that the horizontal
# line below can be skipped instead of failing on an undefined name
ends = []

# loop over parameters, drawing one curve per parameter
for param, label in zip(params, labels):
    logging.info("Plotting parameter %s", param)

    # get samples for each chain; every walker is read separately and
    # treated as its own chain
    chains = [fp.read_samples(param, walkers=i,
                              thin_start=opts.thin_start,
                              thin_interval=opts.thin_interval,
                              thin_end=opts.thin_end)[param] for i in walkers]

    # calculate the Gelman-Rubin convergence diagnostic statistic; each
    # chain is reshaped to a 2-D array with a single row (presumably the
    # layout is parameters x iterations -- TODO confirm against
    # gelman_rubin.walk)
    chains = [chain.reshape(1, len(chain)) for chain in chains]
    _, ends, stats = gelman_rubin.walk(chains, opts.segment_start,
                                       opts.segment_end,
                                       opts.segment_step)

    # plot the first (only) row of the statistic against the segment ends
    plt.plot(ends, stats[0, :], label=label)

# format plot
plt.ylabel("Potential Scale Reduction Factor")
plt.xlabel("Iteration")
plt.legend(labelspacing=0.2)

# plot horizontal line; compare against None so that a value of 0 given on
# the command line is still drawn, and skip the line when no segment ends
# are available to set its extent
if opts.hline is not None and len(ends) > 0:
    plt.hlines(opts.hline, 0, ends[-1], linestyles="dashed")

# text describing the plotted parameters, shared by the caption and title
joined_labels = ", ".join(labels)
caption = """The Gelman-Rubin convergence diagnostic statistic for {parameters}
read from the input file.""".format(parameters=joined_labels)
title = "Gelman-Rubin Convergence for {parameters}".format(
    parameters=joined_labels)

# save figure with meta-data; the command line used is stored with it
results.save_fig_with_metadata(
    fig, opts.output_file,
    cmd=" ".join(sys.argv),
    title=title,
    caption=caption)
plt.close()

# close the results file and exit
fp.close()
logging.info("Done")
