#!/usr/bin/env python

# Copyright (C) 2020 Collin Capano
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Makes a plot of MCMC parameters saved over checkpoint history.."""

import logging
import argparse

import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot

import pycbc
from pycbc.inference import io

# Command-line interface: two required file paths plus one boolean switch per
# panel that can be drawn.
parser = argparse.ArgumentParser(description=__doc__)
for _flag, _help in (
        ('--input-file', 'Path to the input HDF file.'),
        ('--output-file', 'Name of the output plot.')):
    parser.add_argument(_flag, type=str, required=True, help=_help)

# (short flag, long flag, help text) for each optional panel; the order here
# is the order in which the options are listed in the --help output
_panel_switches = (
    ('-t', '--plot-checkpoint-dt',
     'Plot the wall-clock time between checkpoints.'),
    ('-a', '--plot-act',
     'Plot ACT vs checkpoint iteration.'),
    ('-n', '--plot-effective-nsamples',
     'Plot the number of effective samples versus checkpoint iteration.'),
    ('-b', '--plot-nchains-burned-in',
     'Plot the number of chains that were burned in versus checkpoint '
     'iteration. Note that for ensemble samplers, this will be all or '
     'nothing.'),
)
for _short, _long, _help in _panel_switches:
    parser.add_argument(_short, _long, action='store_true', help=_help)
opts = parser.parse_args()

pycbc.init_logging(True)

# one subplot row is needed for every panel that was requested
nplots = (int(opts.plot_act) + int(opts.plot_effective_nsamples)
          + int(opts.plot_nchains_burned_in) + int(opts.plot_checkpoint_dt))
if not nplots:
    raise ValueError("nothing to plot")

# load the data
logging.info('Loading data')

fp = io.loadfile(opts.input_file, 'r')
# Everything needed for plotting is copied into memory inside the try block so
# that the file is closed even if a requested dataset is missing.
try:
    history = fp['sampler_info/checkpoint_history']
    # iteration number at which each checkpoint was written; used as the
    # x-axis of every panel
    iterations = history['niterations'][()]

    if opts.plot_checkpoint_dt:
        checkpoint_dt = history['checkpoint_dt'][()]

    if opts.plot_act:
        try:
            raw_acts = history['act'][()]
        except KeyError:
            raise ValueError("ACT history was not saved")
        if raw_acts.ndim == 2:
            # separate acts for each chain (rows) at each checkpoint
            # (columns): average over the chains with a finite ACT at each
            # checkpoint; checkpoints with no finite ACT are left at inf
            acts = numpy.full(raw_acts.shape[1], numpy.inf)
            for ii in range(raw_acts.shape[1]):
                aa = raw_acts[:, ii]
                aa = aa[numpy.isfinite(aa)]
                if aa.size > 0:
                    acts[ii] = aa.mean()
        else:
            acts = raw_acts

    if opts.plot_effective_nsamples:
        nsamples = history['effective_nsamples'][()]

    if opts.plot_nchains_burned_in:
        try:
            burn_in_iter = history['burn_in_iteration'][()]
        except KeyError:
            raise ValueError("Burn-in history not saved")
        nchains = fp.nchains
        # a chain is counted as burned in at a checkpoint when its recorded
        # burn-in iteration is positive
        burned_in = burn_in_iter > 0
        if burn_in_iter.ndim == 1:
            # ensemble sampler; all or none
            nchains_burned_in = nchains * burned_in.astype(int)
        else:
            # one row per chain: count the burned-in chains per checkpoint
            nchains_burned_in = burned_in.sum(axis=0)
finally:
    fp.close()

# plot
logging.info("Plotting")
fig, axes = pyplot.subplots(nrows=nplots, figsize=(8, 3*nplots))
# a single row yields a bare Axes rather than an array; wrap it so the code
# below can always treat ``axes`` as a sequence
if nplots == 1:
    axes = [axes]
# pad the shared x-range by one iteration on either side
xlimits = (iterations.min() - 1, iterations.max() + 1)
# panels are handed out top to bottom in the order the sections appear below
panels = iter(axes)

if opts.plot_checkpoint_dt:
    ax = next(panels)
    # divide by 60 for the y-axis, which is labelled in minutes
    ax.plot(iterations, checkpoint_dt/60., lw=2)
    ax.set_ylabel('wallclock dt (m)')

if opts.plot_act:
    ax = next(panels)
    ax.plot(iterations, acts, lw=2, zorder=1)
    if raw_acts.ndim == 2:
        # draw the history of every individual chain faintly underneath
        for chain_acts in raw_acts:
            ax.plot(iterations, chain_acts, lw=1, color='C1', alpha=0.3,
                    zorder=0)
    ax.set_ylabel('ACT')

if opts.plot_effective_nsamples:
    ax = next(panels)
    ax.plot(iterations, nsamples, lw=2, zorder=1)
    ax.set_ylabel(r'eff. N samples')

if opts.plot_nchains_burned_in:
    ax = next(panels)
    ax.plot(iterations, nchains_burned_in, lw=2, zorder=1)
    # dashed reference line marking the total number of chains
    ax.axhline(nchains, ls='--', color='C3', zorder=0)
    ax.set_ylabel(r'N chains burned in')

# settings shared by every panel
bottom = len(axes) - 1
for row, panel in enumerate(axes):
    panel.set_xlim(*xlimits)
    panel.grid(ls=':', zorder=-1)
    if row == bottom:
        # only the lowest panel carries the x-axis label and tick labels
        panel.set_xlabel('iteration')
    else:
        panel.set_xticklabels([])

fig.savefig(opts.output_file, bbox_inches='tight')
