#!/usr/bin/python
""" Histogram of templates where injections are found.
"""
import logging, h5py, argparse, numpy as np
from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
from pycbc.events import triggers

# Command line interface; the module docstring doubles as the --help summary.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--verbose', action='count',
                    help='enable progress logging')
parser.add_argument('--output', required=True,
                    help='path of the plot file to write')
# The found-injection files and the histogram parameter are mandatory: the
# rest of the script cannot proceed without them, so let argparse reject the
# command line up front instead of failing later with an unrelated error.
parser.add_argument('--found-injection-files', dest='found', nargs='+',
                    required=True,
                    help='hdf file(s) with found injections')
parser.add_argument('--inspiral-trigger-files', dest='trig', nargs='*',
                    default=[],
                    help='hdf file(s) with injection single-ifo triggers, '
                         'if supplied must be one per found injection file')
parser.add_argument('--bank-files', nargs='*', default=[],
                    help='hdf file(s) with template parameters')
parser.add_argument('--x-param', required=True,
                    help='parameter to histogram over')
parser.add_argument('--num-bins', type=int, default=30)
parser.add_argument('--log-x', action='store_true',
                    help='use log bins in parameter')
parser.add_argument('--min-stat', type=float,
                    help='only plot injections above given stat value')
parser.add_argument('--min-ifar', type=float,
                    help='only plot injections above given ifar value')
args = parser.parse_args()

# Informational log messages are only configured when --verbose is given.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')

# Sanity-check the file lists against the injection files:
#  * trigger files: either none at all, or exactly one per injection file
#  * bank files: none, a single shared one, or exactly one per injection file
if args.trig and len(args.trig) != len(args.found):
    raise RuntimeError('If trigger files are given, must be one per injection '
                       'file!')
if len(args.bank_files) >= 2 and len(args.bank_files) != len(args.found):
    raise RuntimeError('If multiple bank files are used, must be one per '
                       'injection file!')

# Expand the bank and trigger lists to one entry per injection file so the
# three lists can be iterated in lockstep: a single bank file is shared by
# every injection file, and absent files are represented by None.
if not args.bank_files:
    args.bank_files = [None] * len(args.found)
elif len(args.bank_files) == 1:
    args.bank_files = args.bank_files * len(args.found)
if not args.trig:
    args.trig = [None] * len(args.found)
# Collect the per-file arrays in lists and join them once after the loop.
# Each list is seeded with an empty float array so the joined result has the
# same starting point as an incrementally grown array would.
stat_chunks = [np.array([])]
ifar_chunks = [np.array([])]
param_chunks = [np.array([])]
# cycle over injection files
for inj_name, bank_name, trig_name in zip(args.found, args.bank_files,
                                          args.trig):
    logging.info('Processing %s, %s, %s', inj_name, bank_name, trig_name)
    f = b = t = None
    try:
        f = h5py.File(inj_name, 'r')
        # trigger and bank files are optional (None when not supplied)
        t = h5py.File(trig_name, 'r') if trig_name else None
        b = h5py.File(bank_name, 'r') if bank_name else None
        stat_chunks.append(f['found_after_vetoes/stat'][:])
        ifar_chunks.append(f['found_after_vetoes/ifar'][:])
        # the first top-level key of the trigger file is taken as the ifo
        ifo = tuple(t.keys())[0] if t is not None else None
        # NOTE(review): the second return value (found_in_ifo) is not used
        # anywhere in this script - confirm whether the parameter values
        # should be masked by it.
        fparam, _found_in_ifo = triggers.get_found_param(f, b, t,
                                                         args.x_param, ifo)
        # materialise the values before the files are closed
        param_chunks.append(np.asarray(fparam))
    finally:
        # always release the hdf handles, even if reading failed part-way
        for handle in (f, t, b):
            if handle is not None:
                handle.close()
foundstat = np.concatenate(stat_chunks)
foundifar = np.concatenate(ifar_chunks)
foundparam = np.concatenate(param_chunks)

# Apply the optional thresholds one after the other. Each cut is an inclusive
# lower bound, and the same boolean selection is applied to all three arrays
# so that they stay aligned entry by entry.
if args.min_stat is not None:
    keep = foundstat >= args.min_stat
    foundstat, foundifar, foundparam = (
        values[keep] for values in (foundstat, foundifar, foundparam))
if args.min_ifar is not None:
    keep = foundifar >= args.min_ifar
    foundstat, foundifar, foundparam = (
        values[keep] for values in (foundstat, foundifar, foundparam))
logging.info('%i found injections above threshold(s)', len(foundparam))

# plot
# hbins / lims stay None when nothing survived the cuts; in that case an
# empty but labelled figure is written instead of failing on min()/max() of
# a zero-length array.
hbins = lims = None
if foundparam.size == 0:
    logging.warning('No found injections above threshold(s), '
                    'writing an empty plot')
else:
    pmin, pmax = foundparam.min(), foundparam.max()
    # N bins are delimited by N + 1 edges
    nedges = args.num_bins + 1
    if args.log_x:
        if pmin <= 0:
            raise ValueError('--log-x requires strictly positive %s values '
                             '(minimum is %s)' % (args.x_param, pmin))
        # widen the range slightly to make sure to pick up the last point in
        # either direction
        hbins = np.logspace(np.log10(0.999 * pmin), np.log10(1.001 * pmax),
                            num=nedges, endpoint=True)
        lims = 0.98 * pmin, 1.02 * pmax
    else:
        extent = pmax - pmin
        if extent == 0:
            # all values are identical: use a nominal width so that the bin
            # edges and the axis limits are not degenerate
            extent = abs(pmin) if pmin != 0 else 1.0
        hbins = np.linspace(pmin - 0.001 * extent, pmax + 0.001 * extent,
                            num=nedges, endpoint=True)
        lims = pmin - 0.02 * extent, pmax + 0.02 * extent
    n, bins, _ = plt.hist(foundparam, bins=hbins)
    # only emitted when logging has been configured at INFO level
    logging.info('Bin edges: %s, counts: %s', bins, n)
if args.log_x:
    plt.semilogx()
if lims is not None:
    plt.xlim(lims)
plt.xlabel(args.x_param.replace('_', ' '))
plt.ylabel('Number of found injections above threshold')
plt.savefig(args.output)
plt.close()

