#!/usr/bin/env python
# Copyright (C) 2016 Alex Nielsen, Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Make Alex Nielsen's fitting factor plots for a single banksim run.
"""
from __future__ import division

import sys
import h5py
import argparse
import numpy
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import pycbc.version
from pycbc import results

# Script metadata; the version string and date are read from pycbc.version.
__author__ = ("Alex Nielsen <alex.nielsen@ligo.org>, "
              "Ian Harry <ian.harry@ligo.org>")
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_banksim_plot_fitting_factors"

# Command-line interface.
# NOTE(review): usage='' replaces argparse's auto-generated usage line with an
# empty one -- confirm this is intended before removing it.
parser = argparse.ArgumentParser(usage='',
    description="Plot fitting factor distribution.")
parser.add_argument("--version", action="version", version=__version__)
# A single HDF file is expected: it is handed directly to h5py.File below.
parser.add_argument('--input-file', required=True,
                    help="Input HDF file holding the injection parameters "
                         "and recovered matches of a single banksim run.")
parser.add_argument('--output-file', required=True,
                    help="Output file.")
parser.add_argument('--filter-injections', action='store_true', default=False,
                    help="If true only consider and plot injections that are "
                         "marked in the input HDF file as passing the filter "
                         "that was supplied when generating the HDF file.")
# Plotting options
parser.add_argument('--plot-title',
                    help="If given, use this as the plot title")
parser.add_argument('--plot-caption',
                    help="If given, use this as the plot caption")


opt = parser.parse_args()

# Read every dataset that is plotted into memory. Using the file as a context
# manager guarantees it is closed even if one of the dataset lookups raises.
with h5py.File(opt.input_file, 'r') as curr_fp:
    m1 = curr_fp['inj_params/mass1'][:]
    m2 = curr_fp['inj_params/mass2'][:]
    s1z = curr_fp['inj_params/spin1z'][:]
    s2z = curr_fp['inj_params/spin2z'][:]
    match = curr_fp['trig_params/match'][:]
    if opt.filter_injections:
        # Keep only the injections selected by the 'filtered_points' dataset.
        # NOTE(review): this is used as an index into every array, so it is
        # presumably a boolean mask of the same length -- confirm against the
        # code that writes the HDF file.
        bool_arr = curr_fp['filtered_points'][:]
        m1 = m1[bool_arr]
        m2 = m2[bool_arr]
        s1z = s1z[bool_arr]
        s2z = s2z[bool_arr]
        match = match[bool_arr]

# Derived quantities used on the plot axes. These only need the in-memory
# arrays, so they are computed after the file has been closed.
mtot = m1 + m2
eta = m1*m2 / (mtot*mtot)
effspin = (s1z*m1 + s2z*m2) / mtot


# Colormap shared by all scatter panels.
cmap = plt.get_cmap('magma')

def adjust_axes(ax_instance):
    """Adjust axes so that there is always a range of >= 0.1 on the plots.

    For each of the x and y axes, if the current span (upper limit minus
    lower limit) is smaller than 0.1, the limits are widened by 0.1 on both
    sides. Axes whose span is already at least 0.1 are left untouched.

    Parameters
    ----------
    ax_instance : object
        The axes to adjust. It must provide ``get_xlim``/``get_ylim``
        returning a ``(low, upp)`` pair and ``set_xlim``/``set_ylim``
        accepting a two-element sequence (e.g. a matplotlib ``Axes``).

    Returns
    -------
    None
        The axes are modified in place.
    """
    # Operate on the axes that was passed in, never on a module-level global:
    # the limits are read from ``ax_instance`` so they must be written back
    # to the same object.
    low, upp = ax_instance.get_xlim()
    if upp - low < 0.1:
        ax_instance.set_xlim([low-0.1, upp+0.1])
    low, upp = ax_instance.get_ylim()
    if upp - low < 0.1:
        ax_instance.set_ylim([low-0.1, upp+0.1])


# Marker size used for every scatter point.
pointsize = 4

# 3x2 grid of panels; the right-hand margin is left free for the colorbar.
fig, axarr = plt.subplots(nrows=3, ncols=2, figsize=(10, 8))
fig.subplots_adjust(left=0.1, bottom=0.1, right=0.8, top=0.96, wspace=0.3,
                    hspace=0.3)

# The top two rows are scatter plots coloured by the recovered match.
# Each entry is (row, column, x values, y values, x label, y label).
scatter_panels = [
    (0, 0, s1z, s2z, 'Injected spin1', 'Injected spin2'),
    (0, 1, eta, effspin, 'Injected eta', 'Injected effective spin'),
    (1, 0, m1, m2, 'Injected mass1', 'Injected mass2'),
    (1, 1, mtot, effspin, 'Injected total mass', 'Injected effective spin'),
]
for row, col, xvals, yvals, xlabel, ylabel in scatter_panels:
    # The panel currently being drawn is always bound to the name ``cax``,
    # matching the convention used throughout this script.
    cax = axarr[row, col]
    cax.autoscale_view('tight')
    # ``image`` ends up referring to the scatter of the last panel; it is the
    # mappable handed to the colorbar further down.
    image = cax.scatter(xvals, yvals, c=match, zorder=10, s=pointsize,
                        linewidth=0, cmap=cmap)
    cax.set_xlabel(xlabel)
    cax.set_ylabel(ylabel)
    adjust_axes(cax)

# Bottom-left panel: histogram of the fitting factors on a log count axis.
cax = axarr[2, 0]
cax.autoscale_view('tight')
hist, bins = numpy.histogram(match, bins=50)
width = 1.0 * (bins[1] - bins[0])
center = (bins[:-1] + bins[1:]) / 2
cax.bar(center, hist, align='center', width=width, edgecolor="none")
cax.set_xlabel('fitting factor')
cax.set_ylabel('number')
cax.set_yscale('log')
# Keep the fixed [0.9, 1000] range for small runs, but raise the upper limit
# when a bin holds more entries than that so the tallest bars are not clipped.
cax.set_ylim([0.9, max(1000, 1.1 * hist.max())])

# Bottom-right panel: cumulative fraction of injections versus fitting factor.
cax = axarr[2, 1]
cax.autoscale_view('tight')
# Guard the actual divisor: with no injections every bin is zero and the
# histogram is left unnormalised rather than dividing by zero.
num_points = hist.sum()
if num_points:
    normed_hist = hist / num_points
else:
    normed_hist = hist
cumulative = numpy.cumsum(normed_hist)
cax.bar(center, cumulative, align='center', width=width, edgecolor="none")
cax.set_xlabel('fitting factor')
cax.set_ylabel('cumulative fraction')
if len(m1) > 50:
    cax.set_yscale('log')
    # A lower limit of 0 is not representable on a log axis (matplotlib warns
    # and ignores it), so only the upper limit is pinned here.
    cax.set_ylim(top=1)
else:
    cax.set_ylim([0, 1])

# Dedicated axes for the colorbar in the margin to the right of the panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.05, 0.05, 0.91])
# Only draw the colorbar when there is at least one injection to colour.
if len(mtot):
    cbar = fig.colorbar(image, cax=cbar_ax)
    cbar.set_label('Recovered fitting factor')

# Fall back to a generic title and caption when none were given.
if opt.plot_title is None:
    opt.plot_title = 'Fitting factor plots'
if opt.plot_caption is None:
    opt.plot_caption = ("A sequence of plots showing the fitting factor of "
                        "the input injections as a function of both spins and "
                        "masses (top 2 rows), as well as showing cumulative "
                        "and non-cumulative distributions of fitting factor.")

# Request a higher resolution for PNG output. The check is on the file
# extension, so a '.png' appearing elsewhere in the path does not trigger it.
fig_kwds = {}
if opt.output_file.endswith('.png'):
    fig_kwds['dpi'] = 200

# Write the figure together with its title, caption and the command line used.
results.save_fig_with_metadata(fig, opt.output_file,
                               fig_kwds=fig_kwds, title=opt.plot_title,
                               cmd=' '.join(sys.argv),
                               caption=opt.plot_caption)
