#!/usr/bin/env python
import numpy, h5py, argparse, matplotlib, sys
matplotlib.use('Agg')
import pylab, pycbc.results, pycbc.version
from pycbc.events import veto
from pycbc.io import get_chisq_from_file_choice, chisq_choices

# Command-line interface. The trigger file and the output file are both
# needed unconditionally further down, so they are marked as required to get
# a clear usage error instead of an opaque failure on a None path.
parser = argparse.ArgumentParser()
parser.add_argument('--trigger-file', required=True,
                    help='Single ifo trigger file')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--veto-file',
                    help='Optional, file of veto segments to remove triggers')
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')
parser.add_argument('--min-snr', type=float,
                    help='Optional, Minimum SNR to plot')
parser.add_argument('--output-file', required=True,
                    help='Path of the plot file to write')
# The contour values are intentionally kept as the strings typed by the
# user: they are converted with float() where needed and printed verbatim
# in the contour labels.
parser.add_argument('--newsnr-contours', nargs='*', default=[],
                    help="List of newsnr values to draw contours at.")
parser.add_argument('--chisq-choice', choices=chisq_choices,
                    default='traditional',
                    help='Which chisquared to plot. Default=traditional')
args = parser.parse_args()

# Open the trigger file read-only. The first top-level key is taken as the
# detector (ifo) name, and all datasets are read from that group.
trigger_file = h5py.File(args.trigger_file, 'r')
ifo = list(trigger_file.keys())[0]
f = trigger_file[ifo]

# Load the SNR column and the requested flavour of chi-squared.
snr = f['snr'][:]
chisq = get_chisq_from_file_choice(f, args.chisq_choice)

def snr_from_chisq(chisq, newsnr, q=6.):
    """Return the SNR that gives a fixed reweighted SNR at each chisq value.

    Parameters
    ----------
    chisq : numpy.ndarray
        One-dimensional array of chi-squared values.
    newsnr : float or str
        Target reweighted SNR; anything accepted by ``float()``.
    q : float, optional
        Exponent used in the reweighting. Default 6.

    Returns
    -------
    numpy.ndarray
        Array of the same length as ``chisq``. Entries where ``chisq <= 1``
        equal ``newsnr``; entries where ``chisq > 1`` are scaled up by
        ``(0.5 * (1 + chisq ** (q / 2))) ** (1 / q)``.
    """
    target = float(newsnr)
    # Start from the un-reweighted case, then correct only the chisq > 1 part.
    contour_snr = numpy.full(len(chisq), target)
    above_one = chisq > 1.
    reweight = (0.5 * (1. + chisq[above_one] ** (q / 2.))) ** (-1. / q)
    contour_snr[above_one] = target / reweight
    return contour_snr

# Optional veto cut: keep only the triggers whose end time is selected by
# veto.indices_outside_segments for the given veto file / segment name.
if args.veto_file:
    end_times = f['end_time'][:]
    keep, _ = veto.indices_outside_segments(
        end_times, [args.veto_file],
        segment_name=args.segment_name, ifo=ifo)
    snr, chisq = snr[keep], chisq[keep]

# Optional SNR cut: keep only triggers strictly louder than --min-snr.
if args.min_snr is not None:
    loud = snr > args.min_snr
    snr, chisq = snr[loud], chisq[loud]

fig = pylab.figure(1)

# Sample the chisq axis logarithmically between the smallest and largest
# trigger values. numpy.logspace takes base-10 exponents, so the end points
# must come from log10; a natural log here would produce a grid that does
# not match the range of the data.
r = numpy.logspace(numpy.log10(chisq.min()), numpy.log10(chisq.max()), 300)

# Draw one contour of constant NewSNR per requested value.
for i, cval in enumerate(args.newsnr_contours):
    snrv = snr_from_chisq(r, cval)
    pylab.plot(snrv, r, color='black', lw=0.5)

    # Only the first label carries the symbol; the others show the bare value.
    if i == 0:
        label = "$\\hat{\\rho} = %s$" % cval
    else:
        label = "$%s$" % cval

    # Put the label at the first point where the contour passes 80% of the
    # loudest trigger SNR; fall back to the start of the contour if it never
    # gets that far.
    try:
        label_pos_idx = numpy.where(snrv > snr.max() * 0.8)[0][0]
    except IndexError:
        label_pos_idx = 0
    pylab.text(snrv[label_pos_idx], r[label_pos_idx], label, fontsize=6,
               horizontalalignment='center', verticalalignment='center',
               bbox=dict(facecolor='white', lw=0, pad=0, alpha=0.9))

# Two-dimensional log-log histogram of the triggers, with log colour scale.
pylab.hexbin(snr, chisq, gridsize=300, xscale='log', yscale='log', lw=0.04,
             mincnt=1, norm=matplotlib.colors.LogNorm())

ax = pylab.gca()
pylab.grid()
ax.set_xscale('log')
cb = pylab.colorbar()
pylab.xlim(snr.min(), snr.max() * 1.1)
pylab.ylim(chisq.min(), chisq.max() * 1.1)
cb.set_label('Trigger Density')
pylab.xlabel('Signal-to-Noise Ratio')
pylab.ylabel('Reduced $\\chi^2$')

# Write the figure together with its title, caption and command line.
pycbc.results.save_fig_with_metadata(fig, args.output_file,
     title="%s :SNR vs Reduced %s Chisq" % (ifo, args.chisq_choice),
     caption="Distribution of SNR and %s Chisq for single detector triggers: "
             "Black lines show contours of constant NewSNR." \
              %(args.chisq_choice,),
     cmd=' '.join(sys.argv),
     fig_kwds={'dpi':300})
