#!/usr/bin/env python
# Copyright (C) 2018--2019 Gareth Davies
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import argparse
from matplotlib.pyplot import cm
from math import ceil

def parse_args(argv=None):
    """Parse the command line options of the script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default) argparse reads
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed options.
    """
    parser = argparse.ArgumentParser(description="Plot ratios of VTs calculated at"
                                     " various IFARs using pycbc_page_sensitivity"
                                     " (run with --hdf-out option) for two"
                                     " comparable analyses ")
    parser.add_argument('--vt-file-one', required=True,
                        help="HDF file containing VT curves, first set of data"
                        " for comparison")
    parser.add_argument('--vt-file-two', required=True,
                        help="HDF file containing VT curves, second set of data"
                        " for comparison")
    parser.add_argument('--outfile', type=str, required=True,
                        help='Output file to save to')
    parser.add_argument('--plot-title', type=str, required=True,
                        help='(Possibly) quoted string to be title of plot')
    parser.add_argument('--log-x', action='store_true', default=False,
                        help='Use logarithmic x-axis')
    parser.add_argument('--log-y', action='store_true', default=False,
                        help='Use logarithmic y-axis')
    return parser.parse_args(argv)


def load_vt_data(file_one, file_two):
    """Read the VT curves of both analyses into memory.

    Both files are opened read-only and closed again before returning, so
    the returned arrays are plain in-memory copies.

    Parameters
    ----------
    file_one : str
        Path of the HDF file holding the first set of VT curves.
    file_two : str
        Path of the HDF file holding the second set of VT curves.

    Returns
    -------
    xvals : numpy.ndarray
        The 'xvals' dataset, identical in both files.
    curves : list of tuple
        One ``(key, (data, errorlow, errorhigh), (data, errorlow, errorhigh))``
        entry per key of the 'data' group; the first triple comes from
        `file_one`, the second from `file_two`.

    Raises
    ------
    RuntimeError
        If the 'xvals' datasets differ, if the 'data' groups do not contain
        the same keys, or if there are no keys at all.
    """
    with h5py.File(file_one, 'r') as f1, h5py.File(file_two, 'r') as f2:
        xvals = f1['xvals'][:]
        if not np.array_equal(xvals, f2['xvals'][:]):
            raise RuntimeError("IFAR values are not the same between the two "
                               "files")

        # sanitise the input so that the files have the same binning
        # parameter and bins; an explicit raise (rather than an assert)
        # keeps the check active under ``python -O``
        keys = list(f1['data'].keys())
        if set(keys) != set(f2['data'].keys()):
            raise RuntimeError("keys do not match for the two files")
        if not keys:
            raise RuntimeError("No VT curves found in the input files")

        curves = []
        for key in keys:
            first = (f1['data'][key][:], f1['errorlow'][key][:],
                     f1['errorhigh'][key][:])
            second = (f2['data'][key][:], f2['errorlow'][key][:],
                      f2['errorhigh'][key][:])
            curves.append((key, first, second))

    return xvals, curves


def ratio_with_errors(data1, errlow1, errhigh1, data2, errlow2, errhigh2):
    """Compute data1 / data2 and the propagated lower and upper errors.

    The relative lower (upper) errors of the two inputs are added in
    quadrature and scaled by the ratio.

    Parameters
    ----------
    data1, errlow1, errhigh1 : array-like
        Values of the numerator and their lower and upper errors.
    data2, errlow2, errhigh2 : array-like
        Values of the denominator and their lower and upper errors.

    Returns
    -------
    ys : numpy.ndarray
        Element-wise ratio ``data1 / data2``.
    yerr_errlow : numpy.ndarray
        Lower error on the ratio.
    yerr_errhigh : numpy.ndarray
        Upper error on the ratio.
    """
    ys = np.divide(data1, data2)
    rel_low = np.sqrt(np.divide(errlow1, data1)**2 +
                      np.divide(errlow2, data2)**2)
    rel_high = np.sqrt(np.divide(errhigh1, data1)**2 +
                       np.divide(errhigh2, data2)**2)
    return ys, np.multiply(rel_low, ys), np.multiply(rel_high, ys)


def plot_vt_ratios(xvals, curves, plot_title, outfile, log_x=False,
                   log_y=False):
    """Plot one VT-ratio panel per bin and save the figure.

    Parameters
    ----------
    xvals : numpy.ndarray
        Common x values of all curves.
    curves : list of tuple
        Entries as returned by `load_vt_data`.
    plot_title : str
        Title placed above the panels.
    outfile : str
        Path the figure is written to.
    log_x : bool, optional
        Use a logarithmic x-axis.
    log_y : bool, optional
        Use a logarithmic y-axis.
    """
    nkeys = len(curves)
    nrows = int(ceil(nkeys / 2.0))
    # squeeze=False keeps ``axs`` two-dimensional even when there is a
    # single row, so the [row, column] indexing below always works
    fig, axs = plt.subplots(nrows, 2, sharex=True, sharey=True,
                            squeeze=False)

    stitle = fig.suptitle(plot_title)

    colors = cm.rainbow(np.linspace(0, 1, nkeys))
    alpha = 0.6

    lines = []
    labels = []

    # loop through each bin; panels are filled down the first column and
    # then down the second
    for n, (key, first, second) in enumerate(curves):
        c = colors[n]
        data1, errlow1, errhigh1 = first
        data2, errlow2, errhigh2 = second

        ys, yerr_errlow, yerr_errhigh = ratio_with_errors(
            data1, errlow1, errhigh1, data2, errlow2, errhigh2)

        ax = axs[n % nrows, n // nrows]
        ax.axhline(y=1.0, alpha=0.5, ls='--', color='k')
        # black curve first, the coloured curve is drawn on top of it
        ax.plot(xvals, ys, c='k')
        l, = ax.plot(xvals, ys, c=c)
        lines.append(l)
        labels.append(key)
        ax.fill_between(xvals, ys - yerr_errlow, ys + yerr_errhigh,
                        facecolor=c, edgecolor=c, alpha=alpha)
        if log_x:
            ax.set_xscale('log')

        if log_y:
            ax.set_yscale('log')

        print("Bin {0}: minimum ratio {1}, maximum ratio {2}".format(key,
                                                         ys.min(), ys.max()))

    for ax in axs.flat:
        ax.set(xlabel='Inverse False Alarm Rate (years)',
               ylabel='VT ratio')

    for ax in axs.flat:
        ax.label_outer()

    # Positioning the legend seems to be a nightmare
    # Anchor it to the right of the second row of the right-hand column
    # (the centre row of a three-row layout); fall back to the only row
    # when there is just one
    ax = axs[min(1, nrows - 1), 1]
    lgd = ax.legend(handles=lines, labels=labels, loc="center left",
                    borderaxespad=0.1, title="Chirp mass bins",
                    bbox_to_anchor=(1.05, 0.5))

    fig.savefig(outfile, bbox_extra_artists=(lgd, stitle),
                bbox_inches='tight')
    plt.close(fig)


def main():
    """Run the script: parse options, load both files, make the plot."""
    args = parse_args()
    xvals, curves = load_vt_data(args.vt_file_one, args.vt_file_two)
    plot_vt_ratios(xvals, curves, args.plot_title, args.outfile,
                   log_x=args.log_x, log_y=args.log_y)


if __name__ == '__main__':
    main()
