#!/bin/env python

# Copyright 2020 Gareth S Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
"""
Script to plot the output of pycbc_multiifo_dtphase in a corner plot
"""

import matplotlib, h5py, numpy as np, copy
import itertools, sys, logging, argparse
matplotlib.use('agg')
from matplotlib import pyplot as plt
from pycbc.events import coinc_rate
from pycbc import init_logging, version
from pycbc.results import save_fig_with_metadata

def marginalise_pdf(pdf, dimensions_to_keep):
    """Sum a multi-dimensional histogram over all but the requested axes.

    Parameters
    ----------
    pdf : numpy.ndarray
        The N-dimensional histogram / PDF. It is never modified.
    dimensions_to_keep : container of int
        Axis numbers (0-based, non-negative) that should survive in the
        output. Every other axis is summed over.

    Returns
    -------
    numpy.ndarray or numpy scalar
        A new array whose axes are the kept axes of ``pdf`` in their
        original order. If every axis is kept, an independent copy of
        ``pdf`` is returned; if none is kept, the total sum is returned.
    """
    drop_axes = tuple(d for d in range(pdf.ndim)
                      if d not in dimensions_to_keep)
    if not drop_axes:
        # Nothing to sum: hand back a copy so the caller can never alias
        # (and accidentally modify) the input array.
        return pdf.copy()
    # ndarray.sum always allocates its result, so there is no need to copy
    # the (potentially very large) input first. Summing all unwanted axes
    # in a single call also avoids building intermediate arrays.
    return pdf.sum(axis=drop_axes)

# Command-line interface. The longer help texts are collected first so
# that the option definitions below stay compact.
input_help = ("Input phasetd histogram file, made using "
              "pycbc_multiifo_dtphase")
ref_ifo_help = ("IFO to use as reference for P/T/A differences. Default = "
                "first IFO in 'ifos' attribute of input file.")
norm_help = ("Normalization of the histogram. Choose from pdf or so that "
             "the maximum (unmarginalized) bin height is unity. "
             "Default=pdf")
color_help = "Color scale shared between the plots? Default=No"

parser = argparse.ArgumentParser()
parser.add_argument('--version', action="version",
                    version=version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--input-file', required=True, help=input_help)
parser.add_argument('--reference-ifo', help=ref_ifo_help)
parser.add_argument('--norm', choices=['pdf', 'max'], default='pdf',
                    help=norm_help)
parser.add_argument('--shared-color-scale', action='store_true',
                    help=color_help)
parser.add_argument('--output-file', required=True, help="Output png file")

# Parse the command line and set up logging according to --verbose
args = parser.parse_args()
init_logging(args.verbose)

# Hard-code some lists to refer to later. Dimensions come in groups of
# three (time, phase, SNR ratio), one group per non-reference IFO.
width_names = ['twidth', 'pwidth', 'swidth']
plot_lab = ["Time difference", "Phase difference", "SNR ratio"]
plot_leg = ["", "0-2pi phase diff limits", "SNR ratio limits"]

# Read everything that is needed from the file in one go. The context
# manager guarantees the file is closed, even if a check below raises.
with h5py.File(args.input_file, 'r') as f:
    ifos = f.attrs['ifos']
    ref_ifo = args.reference_ifo or ifos[0]

    if ref_ifo not in ifos:
        raise RuntimeError("Reference IFO is not in file")

    logging.info("IFOs: %s", " ".join(ifos))
    logging.info("Reference IFO: %s", ref_ifo)

    # One bin width per parameter type, in the order of width_names
    type_widths = [f.attrs[name] for name in width_names]
    srbmin = f.attrs['srbmin']
    srbmax = f.attrs['srbmax']

    # Load the datasets once rather than re-reading them for every use
    param_bin = f[ref_ifo]['param_bin'][:]
    weights = f[ref_ifo]['weights'][:]

# The param_bin indices will have named datafields
dtns = param_bin.dtype.names
ndims = len(dtns)

# Work out the bin widths in each dimension
widths = [type_widths[i % 3] for i in range(ndims)]

# lims gives the physically allowed differences for GW time-of-flight
# between IFOs, 2-pi phase difference range, and the limits on SNR
# ratio from the dtphase file setup. These need to be reconstructed
# from the file attributes. The time entry is a placeholder which is
# filled in per IFO pair below.
swidth = type_widths[2]
slim = [srbmin * swidth, (srbmax + 1) * swidth]
lims = [[], [0, 2 * np.pi], slim]

# Which IFOs are the differences between? The order here defines which
# group of three dimensions belongs to which IFO pair.
other_ifos = [ifo for ifo in ifos if ifo != ref_ifo]
ifostrs = ["%s-%s" % (ifo, ref_ifo) for ifo in other_ifos]
tdiffs = [coinc_rate.multiifo_noise_coincident_area([ref_ifo, ifo], 0)
          for ifo in other_ifos]

# Set up strings for the legends and axis labels
tlegs = ["%s time of flight (s)" % ifostr for ifostr in ifostrs]
legends = [plot_leg[i % 3] for i in range(ndims)]
labels = [ifostrs[i // 3] + ' ' + plot_lab[i % 3] for i in range(ndims)]

# convert the limits on the P/T/A dimensions into limits for each dimension
limits = [lims[i % 3] for i in range(ndims)]

# Time dimensions have different limits and legends for each IFO pair.
# The time dimension of the n-th non-reference IFO is dimension 3 * n,
# matching the indexing used for `labels` above.
for n, (tdiff, tleg) in enumerate(zip(tdiffs, tlegs)):
    limits[3 * n] = (-tdiff / 2., tdiff / 2.)
    legends[3 * n] = tleg

# Get the minimum and maximum occupied bins for each dimension. These are
# cast to Python ints so that the arithmetic below cannot overflow if the
# bins are stored in a narrow integer dtype.
max_bin = [int(param_bin[dtn].max()) for dtn in dtns]
min_bin = [int(param_bin[dtn].min()) for dtn in dtns]

# Convert these bins into physical values for use in converting
hist_min = [mn * w for mn, w in zip(min_bin, widths)]
hist_max = [(mx + 1) * w for mx, w in zip(max_bin, widths)]
# Plot range: the occupied range padded by 10% on either side
hist_range = [(1.1 * mn - 0.1 * mx, 1.1 * mx - 0.1 * mn)
              for mx, mn in zip(hist_max, hist_min)]

# Set up a zero-valued array to be filled by histogram values
arr_size = tuple(mx - mn + 1 for mx, mn in zip(max_bin, min_bin))
hist_vals = np.zeros(arr_size)

logging.info("filling histogram")
# Index is the difference between stored bin value and the minimum bin
# value - this keeps RAM manageable. All bins are assigned in a single
# vectorised operation.
# NOTE(review): assumes each bin appears at most once in param_bin; with
# repeated bins numpy does not guarantee which weight is kept.
idx = tuple(param_bin[dtn].astype(np.intp) - mn
            for dtn, mn in zip(dtns, min_bin))
hist_vals[idx] = weights

# Replace any un-filled bins with the maximum penalty (minimum weight)
min_weight = weights.min()
hist_vals[hist_vals < min_weight] = min_weight

# Normalise the histogram in place, as requested by --norm
if args.norm.lower() != 'pdf':
    # Scale so that the tallest bin has a height of exactly one
    hist_vals /= hist_vals.max()
else:
    # Turn the histogram into a PDF: divide by the total weight times
    # the volume of a single bin
    hist_vals /= (np.prod(widths) * hist_vals.sum())

# Colour normalisation object handed to the 2D plots
if not args.shared_color_scale:
    colnorm = matplotlib.colors.LogNorm()
else:
    logging.info("getting min/max plot values")
    # The colour range must cover every 2D panel, so find the extreme
    # values over all marginalisations down to two dimensions
    min_mrg_weight, max_weight = np.inf, -np.inf
    for dim_pair in itertools.combinations(np.arange(ndims), 2):
        pair_pdf = marginalise_pdf(hist_vals, dim_pair)
        min_mrg_weight = min(min_mrg_weight, pair_pdf.min())
        max_weight = max(max_weight, pair_pdf.max())
    # Lower end of the scale sits a factor of ten below the smallest value
    colnorm = matplotlib.colors.LogNorm(vmin=min_mrg_weight/10,
                                        vmax=max_weight)

logging.info("Plotting")
fig, axes = plt.subplots(ndims, ndims, sharex='col', sharey='row',
                         figsize=(3 * ndims, 3 * ndims))

# different types of dimension will be different colors
dcols = ['k', 'r', 'b']
# different IFO combinations will be different linestyles
ltypes = ['-', ':', '--']

# Add some space for the colorbar to be added
left_space = 0.1
top_space = 0.98 if args.shared_color_scale else (1 - left_space)
bot_space = 2 * left_space - 0.02 if args.shared_color_scale else left_space
fig.subplots_adjust(top=top_space, bottom=bot_space, left=left_space,
                    right=(1 - left_space), wspace=0.02, hspace=0.02)

# Line style used to mark the physical limits of each dimension
limit_styles = [dict(c=dcols[d % 3], linestyle=ltypes[d // 3])
                for d in range(ndims)]

for i in range(ndims):
    logging.info("Dimension %d: %s", i, labels[i])
    # Sum all dimensions except for this one
    marg_pdf = marginalise_pdf(hist_vals, [i])

    # Using sharex and sharey means that the 1D PDF heights will be out
    # of range. Using a twin axis here means that we can share the
    # x-axis but use its own y-axis
    margax = axes[i, i].twinx()
    logging.info("Plotting marginalised histogram")
    # Work out the x-values of the marginalised histogram, moving the
    # plot points to the center of the bin
    i_values = widths[i] * np.arange(min_bin[i], max_bin[i] + 1) \
        + widths[i] / 2
    margax.plot(i_values, marg_pdf)
    margax.semilogy()
    # plot on the physical limits for this dimension
    yl = margax.get_ylim()
    i_lo, i_hi = limits[i]
    for edge in (i_lo, i_hi):
        margax.plot([edge, edge], yl, **limit_styles[i])
    # reiterate y limits, as drawing the limit lines may have moved them
    margax.set_ylim(yl)
    margax.set_xlim(hist_range[i])
    margax.set_ylabel("Probability Density" if args.norm == 'pdf'
                      else "Summed Histogram Height")

    # Loop through other dimensions - stop at i-1 so we don't do the
    # same combination twice
    for j in range(i):
        logging.info("%s vs %s heatmap", labels[i], labels[j])
        twod_pdf = marginalise_pdf(hist_vals, [i, j])

        # A norm object with unset limits is autoscaled by the first
        # image it is attached to and keeps those limits afterwards.
        # Unless a shared scale was requested, every panel therefore
        # needs its own norm, otherwise all panels would silently reuse
        # the range of the first one.
        if args.shared_color_scale:
            image_norm = colnorm
        else:
            image_norm = matplotlib.colors.LogNorm()

        # 2-dimensional color plot of the PDF
        im = axes[i, j].imshow(np.rot90(twod_pdf), aspect='auto',
                               extent=[hist_min[j], hist_max[j],
                                       hist_min[i], hist_max[i]],
                               norm=image_norm)

        # This is probably overkill
        axes[i, j].set_ylim(hist_range[i])
        axes[i, j].set_xlim(hist_range[j])

        # Plot physical limits: vertical lines for dimension j,
        # horizontal lines for dimension i
        j_lo, j_hi = limits[j]
        for edge in (j_lo, j_hi):
            axes[i, j].plot([edge, edge], hist_range[i],
                            **limit_styles[j])
        for edge in (i_lo, i_hi):
            axes[i, j].plot(hist_range[j], [edge, edge],
                            **limit_styles[i])

    # turn off axes where we haven't plotted anything
    for j in range(i + 1, ndims):
        axes[i, j].axis('off')
    # Add labels to the x axes at the bottom
    axes[ndims - 1, i].set_xlabel(labels[i])
    # Add labels to the y axes on the left hand side
    axes[i, 0].set_ylabel(labels[i])

if args.shared_color_scale:
    # Dedicated axis underneath the grid to hold the colorbar. It starts
    # in line with the first column and ends at the last 2D panel.
    cbar_left = left_space
    cbar_bottom = bot_space / 2.5
    cbar_width = (1 - 2 * left_space) * (ndims - 1.) / ndims
    cbar_height = bot_space / 8
    cb_ax = fig.add_axes([cbar_left, cbar_bottom, cbar_width, cbar_height])
    # Every image uses the same normalisation in this mode, so the most
    # recently drawn one can stand in for all of them
    cb = fig.colorbar(im, cax=cb_ax, orientation='horizontal')
    if args.norm == 'pdf':
        cb_label = "Probability Density"
    else:
        cb_label = "Summed Histogram Height"
    cb.set_label(cb_label)

# The legend lives in the (otherwise unused) top right axis. A dummy line
# from a point to itself is drawn per dimension purely so that its style
# and label appear in the legend.
leg_ax = axes[0, ndims-1]
for dim in range(ndims):
    leg_ax.plot([0, 0], [0, 0], label=legends[dim],
                color=dcols[dim % 3], linestyle=ltypes[dim // 3])
leg_ax.axis('off')
leg_ax.legend(loc='lower left')

# Write the figure to disk together with its descriptive metadata
logging.info("Saving with metadata")

# Title and caption stored alongside the figure
figure_title = "Corner plot of the %s PTA histogram" % "-".join(ifos)
if args.norm == 'pdf':
    norm_statement = "Normalized to be the PDF of the parameter space. "
elif args.norm == 'max':
    norm_statement = ("Normalized such that the maximum un-normalized "
                      "histogram bin values is unity. ")
caption = "".join([
    "Plots in various dimensions of the ",
    " ".join(ifos),
    " phase, time and SNR difference histogram. ",
    norm_statement,
    " The reference IFO for these differences is ",
    ref_ifo,
    ".",
])

# The command line used is recorded as well, for reproducibility
command_line = ' '.join(sys.argv)
save_fig_with_metadata(fig, args.output_file, title=figure_title,
                       caption=caption, cmd=command_line)

logging.info("Done")
