#!/usr/bin/python

import h5py
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import argparse

# Build the command-line interface.
# The parser is constructed exactly once (a previous, immediately
# overwritten ArgumentParser() call has been dropped).
# NOTE(review): usage="" replaces argparse's auto-generated usage synopsis
# with an empty string - confirm that this is intentional.
parser = argparse.ArgumentParser(usage="",
    description="Plot ratios of VTs calculated at various IFARs using pycbc_page_sensitivity for two analyses")
# The two input HDF files whose VT curves are compared (ratio is one / two).
parser.add_argument('--vt-file-one', type=str, required=True,
                    help='HDF file containing VT curves, first set of data for comparison')
parser.add_argument('--vt-file-two', type=str, required=True,
                    help='HDF file containing VT curves, second set of data for comparison')
# Short tags inserted into the y-axis label for each data set.
parser.add_argument('--desc-one', type=str, required=True,
                    help='descriptor tag for first set of data (short, for use in subscript)')
parser.add_argument('--desc-two', type=str, required=True,
                    help='descriptor tag for second set of data (short, for use in subscript)')
parser.add_argument('--outfile', type=str, required=True,
                    help='Output file to save to')
# One or more IFAR values; each is snapped to the nearest available value.
parser.add_argument('--ifars', type=float, required=True, nargs='+',
                    help='IFAR values to plot VT ratio for. Note that the plotted values '
                         'will be the closest values available from the VT files')
# store_true flags already default to False, so no explicit default is given.
parser.add_argument('--log-x', action='store_true',
                    help='Use logarithmic x-axis')
parser.add_argument('--log-y', action='store_true',
                    help='Use logarithmic y-axis')
args = parser.parse_args()

# Load in the two datasets. The mode is given explicitly as read-only: this
# script only reads from the inputs, and h5py releases before 3.0 fell back
# to a read/write mode when no mode was passed.
ffirst_set = h5py.File(args.vt_file_one, 'r')
fsecond_set = h5py.File(args.vt_file_two, 'r')

# Find the index closest to the given IFAR value. The IFAR grid is read from
# disk once and reused for every requested value.
# NOTE(review): the indices come from the first file's 'xvals' only and are
# later applied to both files - this assumes both files share the same
# 'xvals' grid; confirm against how the inputs are produced.
ifar_grid = ffirst_set['xvals'][:]
idxs = [np.argmin(np.abs(ifar_grid - ifv)) for ifv in args.ifars]

keys = ffirst_set['data'].keys()
# sanitise the input so that the files have the same binning parameter and
# bins. An explicit exception is raised (rather than using ``assert``) so
# that the check is still performed when Python runs with -O.
if set(keys) != set(fsecond_set['data'].keys()):
    raise ValueError("keys do not match for the two files")

# Plot styling: figure resolution and the axis formatter's 'limits' setting
# are configured before the figure is created.
plt.rc('figure', dpi=300)
plt.rc('axes.formatter', limits=[-3, 4])

# A single wide figure with one set of axes and a background grid.
fig_mi = plt.figure(figsize=(10, 4))
ax_mi = fig_mi.gca()
ax_mi.grid(True, zorder=1)

# Names of the entries in the first file's 'data' group, one per plotted bin.
bin_names = tuple(ffirst_set['data'].keys())

# Tick labels: everything after the last '\in' of each name, prefixed by '$ '.
labels = ['$ ' + name.split('\\in')[-1] for name in bin_names]

# Name of the splitting parameter: the text before the first '\in' of the
# first entry, with surrounding '$' and whitespace removed, re-wrapped in '$'.
first_name_head = bin_names[0].split('\\in')[0]
x_param = r'$' + first_name_head.strip('$').strip() + r'$'

# x positions: the number found between the first '[' and the following ','
# in each label.
xpos = np.array([float(text.split('[')[1].split(',')[0]) for text in labels])

# offset different ifars by 1/20th of the mean distance between parameters
# (measured in log space when a logarithmic x-axis is requested).
#
# The number of positions is checked explicitly: np.diff of a single-element
# array returns an empty array whose mean is nan (with a RuntimeWarning), it
# does not raise, so an exception handler would never provide the fallback
# and the points would be placed at nan.
if len(xpos) > 1:
    spacing_source = np.log(xpos) if args.log_x else xpos
    mean_spacing = np.diff(spacing_source).mean()
    xpos_add_dx = 0.05 * np.ones_like(xpos) * mean_spacing
else:
    # With fewer than two positions there is no spacing to measure, so use a
    # fixed offset instead.
    xpos_add_dx = 0.05

# set the x ticks to be the positions given in the labels
plt.xticks(xpos, labels, rotation='horizontal')

# Marker colours, reused cyclically if there are more IFARs than colours.
colors = ['#7b85d4', '#f37738', '#83c995', '#d7369e', '#c4c9d8', '#859795']


def _per_bin(h5file, group, idx):
    """Return entry ``idx`` of ``h5file[group][key]`` for every key in ``keys``."""
    return np.array([h5file[group][key][idx] for key in keys])


# Offsets are counted relative to the middle of the IFAR list so that the
# group of markers is centred on each tick position.
half_span = float(len(args.ifars) - 1) / 2.

# loop through each IFAR and plot the VT ratio with error bars
for count, idv in enumerate(idxs):
    data1 = _per_bin(ffirst_set, 'data', idv)
    data2 = _per_bin(fsecond_set, 'data', idv)

    errhigh1 = _per_bin(ffirst_set, 'errorhigh', idv)
    errlow1 = _per_bin(ffirst_set, 'errorlow', idv)

    errhigh2 = _per_bin(fsecond_set, 'errorhigh', idv)
    errlow2 = _per_bin(fsecond_set, 'errorlow', idv)

    # Ratio of the two VTs; its error bars combine the two relative errors
    # in quadrature and scale the result by the ratio.
    ys = data1 / data2
    rel_low = np.sqrt((errlow1 / data1) ** 2 + (errlow2 / data2) ** 2)
    rel_high = np.sqrt((errhigh1 / data1) ** 2 + (errhigh2 / data2) ** 2)
    yerr_errlow = rel_low * ys
    yerr_errhigh = rel_high * ys

    # Shift this IFAR's markers sideways; on a log axis the shift is applied
    # to log(x) and transformed back.
    shift = xpos_add_dx * (count - half_span)
    if args.log_x:
        xvals = np.exp(np.log(xpos) + shift)
    else:
        xvals = xpos + shift

    ax_mi.errorbar(xvals, ys,
        yerr=[yerr_errlow, yerr_errhigh], fmt='o', markersize=7, linewidth=5,
        label='IFAR = %d yr' % ffirst_set['xvals'][idv], capsize=5,
        capthick=2, mec='k', color=colors[count % len(colors)])

# Pick the pyplot call matching the requested (log-x, log-y) combination;
# when neither flag is set no call is made and the axes are left unchanged.
scale_calls = {
    (True, True): plt.loglog,
    (True, False): plt.semilogx,
    (False, True): plt.semilogy,
}
apply_scale = scale_calls.get((args.log_x, args.log_y))
if apply_scale is not None:
    apply_scale()
# Set the tick positions and labels after any change of axis scale.
plt.xticks(xpos, labels, rotation='horizontal')

# get the limit of the x axes, and draw a black line in order to highlight
# equal comparison (a ratio of one)
xlimits = plt.xlim()
plt.plot(xlimits, [1, 1], 'k', lw=3, zorder=0)
plt.xlim(xlimits) # reassert the x limits so that the plot doesn't expand

# Legend placed above the axes with one column per requested IFAR; the
# object returned by legend() is reused instead of being looked up again.
legend = ax_mi.legend(bbox_to_anchor=(0.5, 1.01), ncol=len(args.ifars),
                      loc='lower center')
legend.get_title().set_fontsize('14')
legend.get_frame().set_alpha(0.7)

ax_mi.set_xlabel(x_param)
# The y label is built from a single raw string so that every LaTeX
# backslash is literal; previously one fragment was a non-raw string
# containing '\m', an invalid escape sequence that newer Python versions
# warn about. The resulting label text is unchanged.
ax_mi.set_ylabel(r'$\frac{VT(\mathrm{%s})}{VT(\mathrm{%s})}$'
                 % (args.desc_one, args.desc_two))
plt.tight_layout()

# Write the figure to the requested output file. When the text after the
# last '.' in the name is exactly 'pdf', the PDF format is passed
# explicitly; otherwise savefig is called with the file name alone.
outfile_suffix = args.outfile.rsplit('.', 1)[-1]
save_options = {'format': 'PDF'} if outfile_suffix == 'pdf' else {}
fig_mi.savefig(args.outfile, **save_options)

plt.close()
