#!/usr/bin/env python

import argparse
import logging
import os

import h5py
import matplotlib.pyplot as plt
import numpy
from matplotlib.colors import LogNorm

# Command-line interface. The description is spelled out explicitly because
# this file has no module docstring, so __doc__ would be None.
parser = argparse.ArgumentParser(
    description='Make a scatter plot of KDE values stored in HDF5 files, '
                'either against one parameter or as the colour of a '
                'parameter-vs-parameter plot.')
parser.add_argument('--signal-file',
                    help='HDF5 file containing a "signal_kde" dataset. Needed '
                         'when --which-kde is signal_kde or ratio_kde')
parser.add_argument('--template-file', required=True,
                    help='HDF5 file containing a "template_kde" dataset and '
                         'the parameter datasets named by --param')
parser.add_argument('--param', nargs='+', required=True,
                    help='Specify one parameter name for a kde_vs_param plot, or '
                         'two parameter names for a param_vs_param plot. Param '
                         'names must exist as datasets in the input files')
parser.add_argument('--log-axis', nargs='+', choices=['True', 'False'], required=True,
                    help='For each parameter, specify True for a log axis and False '
                         'for a linear axis')
# Both of the following are needed for any plot to be produced, so they are
# required up front instead of failing later with a RuntimeError.
parser.add_argument('--plot-type', choices=['kde_vs_param', 'param_vs_param'],
                    required=True,
                    help='kde_vs_param plots the KDE against one parameter; '
                         'param_vs_param plots two parameters coloured by the KDE')
parser.add_argument('--which-kde', choices=['signal_kde', 'template_kde', 'ratio_kde'],
                    required=True,
                    help='Which KDE to plot; ratio_kde is signal_kde divided by '
                         'template_kde')
parser.add_argument('--plot-dir', required=True,
                    help='Directory in which the output PNG is written')
# NOTE(review): --verbose is parsed but not referenced in the visible script.
parser.add_argument('--verbose', action='count')
args = parser.parse_args()

# Cross-check option combinations that argparse cannot express on its own.
# Missing selectors are reported explicitly; otherwise an omitted --plot-type
# would fall into the param_vs_param branch and give a misleading message.
if args.plot_type is None:
    parser.error('--plot-type is required')
if args.which_kde is None:
    parser.error('--which-kde is required')

if args.plot_type == 'kde_vs_param':
    if len(args.param) != 1:
        parser.error('For kde_vs_param, give exactly one parameter name')
else:
    # argparse `choices` guarantees the only other value is param_vs_param.
    if len(args.param) != 2:
        parser.error('For param_vs_param, give exactly two parameter names')
if len(args.param) != len(args.log_axis):
    parser.error('Must specify either log (True) or non-log (False) for each parameter')

# signal_kde is only loaded when a signal file is given; fail here with a
# clear message instead of a NameError when the plot needs it.
if args.which_kde in ('signal_kde', 'ratio_kde') and not args.signal_file:
    parser.error('--signal-file is required when --which-kde is ' + args.which_kde)

# Read the datasets into memory. `[:]` copies each dataset out of the file,
# so the files can be closed as soon as the reads are done.
if args.signal_file:
    with h5py.File(args.signal_file, 'r') as signal_data:
        signal_kde = signal_data['signal_kde'][:]

# The parameter datasets are read from the template file only.
with h5py.File(args.template_file, 'r') as template_data:
    template_kde = template_data['template_kde'][:]
    param0 = template_data[args.param[0]][:]
    if len(args.param) > 1:
        param1 = template_data[args.param[1]][:]

# Build and save the requested scatter plot.
if args.plot_type not in ('kde_vs_param', 'param_vs_param'):
    raise RuntimeError('Unknown plot type!', args.plot_type)

# Select the KDE values once; both plot types use the same selection.
if args.which_kde == 'signal_kde':
    kde_values = signal_kde
elif args.which_kde == 'template_kde':
    kde_values = template_kde
elif args.which_kde == 'ratio_kde':
    kde_values = signal_kde / template_kde
else:
    raise RuntimeError('Unknown value of which_kde', args.which_kde)

fig, ax = plt.subplots(1, figsize=(12, 7), constrained_layout=True)

if args.plot_type == 'kde_vs_param':
    # KDE on the x axis (always log), parameter on the y axis.
    ax.scatter(kde_values, param0, marker=".", c="r", s=5)
    ax.set_xlabel(args.which_kde, fontsize=15)
    ax.set_ylabel(args.param[0], fontsize=15)
    ax.set_xscale('log')
    ax.set_yscale('log' if args.log_axis[0] == 'True' else 'linear')
    plot_name = args.which_kde + '_vs_' + args.param[0] + '.png'
else:
    # Two parameters on the axes, KDE shown as log-normalised colour.
    im = ax.scatter(param0, param1, marker=".", c=kde_values, cmap='RdBu_r',
                    s=5, norm=LogNorm())
    cbar = fig.colorbar(im, ax=ax)
    ax.set_xlabel(args.param[0], fontsize=15)
    ax.set_ylabel(args.param[1], fontsize=15)
    ax.set_xscale('log' if args.log_axis[0] == 'True' else 'linear')
    ax.set_yscale('log' if args.log_axis[1] == 'True' else 'linear')
    cbar.ax.set_ylabel(args.which_kde, rotation=270, fontsize=15)
    plot_name = (args.which_kde + '_' + args.param[0] + '_vs_'
                 + args.param[1] + '.png')

# Set the tick label size without replacing the tick labels themselves.
# set_xticklabels/set_yticklabels with a name string would overwrite the
# numeric labels with the string's individual characters.
ax.tick_params(axis='both', which='both', labelsize=13)

# Join with the directory so --plot-dir works with or without a trailing
# separator (plain concatenation silently produced "<dir><name>.png").
plot_loc = os.path.join(args.plot_dir, plot_name)
fig.savefig(plot_loc)
