#!/usr/bin/env python

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot PyCBC's single-detector triggers over the search parameter space.
"""

import logging
import argparse
import numpy as np
import matplotlib
matplotlib.use('agg')
import pylab as pl
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator
import h5py
import pycbc.pnutils
import pycbc.events
import pycbc.results
import pycbc.io
import sys
import pycbc.version
from pycbc.events.stat import sngl_statistic_dict

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

# Input files
parser.add_argument('--single-trig-file', required=True,
                    help='Path to file containing single-detector triggers '
                         'in HDF5 format. Required')
parser.add_argument('--bank-file', required=True,
                    help='Path to file containing template bank in HDF5 '
                         'format. Required')
parser.add_argument('--veto-file', type=str,
                    help='Optional path to file containing veto segments')
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')

# Trigger selection and output
parser.add_argument('--filter-string', default=None, type=str,
                    help='Optional, boolean expression for filtering triggers')
parser.add_argument('--min-snr', default=0., type=float,
                    help='Only plot triggers above the given SNR')
parser.add_argument('--output-file', type=str, required=True,
                    help='Destination path for plot')

# Plotted quantities: x and y take any single-detector trigger parameter
# name; the color scale is either 'density' or a key of
# sngl_statistic_dict.
for axis_name in ('x', 'y'):
    parser.add_argument('--%s-var' % axis_name, required=True,
                        choices=pycbc.io.SingleDetTriggers.get_param_names(),
                        help='Parameter to plot on the %s-axis. Required'
                             % axis_name)
parser.add_argument('--z-var', required=True,
                    choices=['density'] + list(sngl_statistic_dict.keys()),
                    help='Quantity to plot on the color scale. Required')
parser.add_argument('--detector', required=True,
                    help='Detector. Required')

# Binning and axis scaling
parser.add_argument('--grid-size', type=int, default=80,
                    help='Bin resolution (larger = smaller bins)')
for axis_name in ('x', 'y'):
    parser.add_argument('--log-%s' % axis_name, action='store_true',
                        help='Use log scale for %s-axis' % axis_name)

# Optional plot limits: --min-x, --max-x, --min-y, --max-y, --min-z, --max-z
for axis_name in ('x', 'y', 'z'):
    for bound, bound_word in (('min', 'minimum'), ('max', 'maximum')):
        parser.add_argument('--%s-%s' % (bound, axis_name), type=float,
                            help='Optional %s %s value'
                                 % (bound_word, axis_name))
opts = parser.parse_args()

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# Build the trigger filter expression: an optional SNR threshold AND-ed with
# any user-supplied expression. None means no filtering is requested.
filter_terms = []
if opts.min_snr > 0.:
    filter_terms.append('(self.snr>%f)' % (opts.min_snr))
if opts.filter_string is not None:
    filter_terms.append(opts.filter_string)
filter_func = ' & '.join(filter_terms) if filter_terms else None

trigs = pycbc.io.SingleDetTriggers(opts.single_trig_file, opts.bank_file,
                                   opts.veto_file, opts.segment_name,
                                   filter_func, opts.detector)

x = getattr(trigs, opts.x_var)
y = getattr(trigs, opts.y_var)

# Keep only triggers inside the requested x/y window; a bound left unset on
# the command line imposes no cut. The resulting boolean ``mask`` is kept
# so the color-scale values can be selected consistently.
mask = np.ones(len(x), dtype=bool)
for values, lower, upper in ((x, opts.min_x, opts.max_x),
                             (y, opts.min_y, opts.max_y)):
    if lower is not None:
        mask &= values >= lower
    if upper is not None:
        mask &= values <= upper
x = x[mask]
y = y[mask]

# Styling shared by both hexbin modes. ``mincnt`` is intentionally left at
# matplotlib's default (None). When a statistic is plotted, that blanks bins
# containing no triggers. When density is plotted, zero-count bins are masked
# by the log color scale. An explicit ``mincnt=0`` makes matplotlib >= 3.8
# apply ``reduce_C_function`` to empty bins, which fails for ``max``.
hexbin_style = {
    'gridsize': opts.grid_size,
    'linewidths': 0.03
}
if opts.log_x:
    hexbin_style['xscale'] = 'log'
if opts.log_y:
    hexbin_style['yscale'] = 'log'

logging.info('Plotting')
fig = pl.figure()
ax = fig.gca()

# NOTE: recent matplotlib versions reject a ``norm`` passed together with
# ``vmin``/``vmax`` keywords. So when a LogNorm is used, the optional
# --min-z/--max-z limits are handed to the norm itself. A limit of None
# means "autoscale from the data".
if opts.z_var == 'density':
    # Color by trigger count per bin. Non-empty bins hold at least one
    # trigger, so the log scale starts at 1 unless --min-z overrides it.
    norm = LogNorm(vmin=1 if opts.min_z is None else opts.min_z,
                   vmax=opts.max_z)
    hb = ax.hexbin(x, y, norm=norm, **hexbin_style)
    fig.colorbar(hb, ticks=LogLocator(subs=range(10)))
elif opts.z_var in sngl_statistic_dict:
    # Color each bin by the largest value of the chosen statistic among the
    # triggers it contains (same x/y window as the plotted points).
    cb_style = {}
    z = getattr(trigs, opts.z_var)[mask]
    min_z = z.min() if opts.min_z is None else opts.min_z
    max_z = z.max() if opts.max_z is None else opts.max_z
    # A log scale only makes sense for a strictly positive range spanning
    # more than a decade; the positivity check also avoids dividing by zero.
    if min_z > 0 and max_z / min_z > 10:
        hexbin_style['norm'] = LogNorm(vmin=opts.min_z, vmax=opts.max_z)
        cb_style['ticks'] = LogLocator(subs=range(10))
    else:
        hexbin_style['vmin'] = opts.min_z
        hexbin_style['vmax'] = opts.max_z
    hb = ax.hexbin(x, y, C=z, reduce_C_function=max, **hexbin_style)
    fig.colorbar(hb, **cb_style)
else:
    raise RuntimeError('z_var = %s is not recognized!' % (opts.z_var))

ax.set_xlabel(opts.x_var)
ax.set_ylabel(opts.y_var)
ax.set_title(opts.z_var.title() + ' of %s triggers ' % (opts.detector))
title = '%s of %s triggers over %s and %s' % (opts.z_var.title(),
                         opts.detector, opts.x_var.title(), opts.y_var.title())
fig_caption = ("This plot shows the %s of single detector triggers for the %s "
               "detector. %s is shown on the colorbar axis against %s and %s "
               "on the x- and y-axes." % (opts.z_var, opts.detector,
                                   opts.z_var.title(), opts.x_var, opts.y_var))
pycbc.results.save_fig_with_metadata(fig, opts.output_file, title=title,
                                     caption=fig_caption, cmd=' '.join(sys.argv),
                                     fig_kwds={'dpi': 200})

logging.info('Done')
