#!/usr/bin/env python
""" Make a table of found (or, with --show-missed, missed) injection information
"""
import argparse, h5py, numpy as np, pycbc.results, pycbc.detector, sys
from pycbc.types import MultiDetOptionAction
import pycbc.pnutils, pycbc.events
import pycbc.version
from itertools import combinations


# Command-line interface.
# The injection file and the output path are used unconditionally further
# down, so argparse enforces them here; a missing option then produces a usage
# error instead of a traceback from the HDF5 or file-writing layer.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--injection-file', required=True,
                    help='HDF File containing the matched injections')
parser.add_argument('--single-trigger-files', nargs='*',
                    action=MultiDetOptionAction,
                    help="HDF format single detector trigger files")
parser.add_argument('--verbose', action='count')
parser.add_argument('--show-missed', action='store_true',
                    help='Tabulate the missed injections instead of the '
                         'found ones')
parser.add_argument('--output-file', required=True,
                    help='Path the resulting table is written to')
args = parser.parse_args()

# Open the injection file read-only; `inj` is its 'injections' group.
f = h5py.File(args.injection_file, 'r')
inj = f['injections']

# Columns, headings and format strings describing the recovered events.
# These are only filled in when found injections are tabulated.
found_cols = []
found_names = []
found_formats = []

# Detector names come from the 'detector_1'/'detector_2' attributes when the
# file carries them, otherwise from the space-separated 'ifos' attribute.
if 'detector_1' not in f.attrs:
    ifos = f.attrs['ifos'].split(' ')
else:
    ifos = [f.attrs['detector_1'], f.attrs['detector_2']]

# Choose which injections are tabulated (`idx` indexes into the injection
# columns) and, for found injections, collect the event columns that are
# appended after the injection parameters.
if args.show_missed:
    title = "Missed Injections"
    idx = f['missed/after_vetoes'][:]
else:
    title = "Found Injections"
    found = f['found_after_vetoes']
    idx = found['injection_index'][:]

    # Determine the layout of the injection file exactly once. The
    # single-detector loop further down opens other HDF files, so this test
    # must not be repeated against whichever file is open at that point.
    two_ifo_layout = 'detector_1' in f.attrs

    if two_ifo_layout:
        det1, det2 = f.attrs['detector_1'], f.attrs['detector_2']
        # Arrival-time difference between the two detectors, scaled by 1000
        # to match the "(ms)" heading.
        tdiff = (found['time1'][:] - found['time2'][:]) * 1000
        tdiff_str = '%s - %s time (ms)' % (det1, det2)
        ids = {det1: found['trigger_id1'][:],
               det2: found['trigger_id2'][:]}

        found_cols = [found['stat'], found['ifar'], found['ifar_exc'], tdiff]
        found_names = ['Ranking Stat.', 'Inc. IFAR (yrs)', 'Exc. IFAR',
                       tdiff_str]
        found_formats = ['##.##', '##.##', '##.##', '##.##']
    else:
        # Only detectors that have their own group under 'found_after_vetoes'
        # contribute columns.
        group_names = found.keys()
        detectors_used = [det for det in f.attrs['ifos'].split(' ')
                          if det in group_names]

        tdiff, tdiff_str, tdiff_format = [], [], []
        for det_a, det_b in combinations(detectors_used, 2):
            time_1 = np.array(found[det_a + '/time'][:])
            time_2 = np.array(found[det_b + '/time'][:])
            tdiff_vals = (time_1 - time_2) * 1000
            # A negative time is treated as "no trigger in that detector":
            # the difference is blanked rather than shown.
            tdiff_vals[np.logical_or(time_1 < 0, time_2 < 0)] = np.nan
            tdiff.append(['%.2f' % td if not np.isnan(td) else ' '
                          for td in tdiff_vals])
            tdiff_str.append('%s - %s time (ms)' % (det_a, det_b))
            tdiff_format.append('##.##')
        ids = {detector: found[detector + '/trigger_id'][:]
               for detector in detectors_used}

        found_cols = [found['stat'], found['ifar_exc']] + tdiff
        found_names = ['Ranking Stat.', 'Exc. IFAR'] + tdiff_str
        found_formats = ['##.##', '##.##'] + tdiff_format

    if args.single_trigger_files:
        for ifo in args.single_trigger_files:
            # Read everything needed into memory, then close the trigger file.
            # A separate name is used so the injection file `f` stays bound.
            with h5py.File(args.single_trigger_files[ifo], 'r') as trig_file:
                trigs = trig_file[ifo]
                ids_ifo = np.array(ids[ifo])
                snr_vals = trigs['snr'][:][ids_ifo]
                chisq_vals = (trigs['chisq'][:][ids_ifo]
                              / (2 * trigs['chisq_dof'][:][ids_ifo] - 2))

            # A trigger id of -1 marks "no trigger"; indexing with it picked
            # up the last element, so overwrite those entries with NaN.
            no_trigger = ids_ifo == -1
            snr_vals[no_trigger] = np.nan
            chisq_vals[no_trigger] = np.nan
            newsnr_vals = pycbc.events.ranking.newsnr(snr_vals, chisq_vals)

            found_names += [ifo + " SNR", ifo + " CHISQ", ifo + " NewSNR"]
            if two_ifo_layout:
                found_cols += [snr_vals, chisq_vals, newsnr_vals]
            else:
                # Pre-format as text so that NaN entries show up blank.
                snr = ['%.2f' % s if not np.isnan(s) else ' '
                       for s in snr_vals]
                chisq = ['%.2f' % c if not np.isnan(c) else ' '
                         for c in chisq_vals]
                newsnr = ['%.2f' % s if not np.isnan(s) else ' '
                          for s in newsnr_vals]
                found_cols += [snr, chisq, newsnr]
            found_formats += ['##.##', '##.##', '##.##']

# Candidate injection column -> table heading. The column name is 'eff_dist_'
# followed by the lower-cased first character of the detector name.
eff_dist = dict(('eff_dist_%s' % i[0].lower(), 'Eff Dist (%s)' % i)
                for i in ifos)

# Keep only the candidates that actually exist in the injection group, then
# build the three parallel lists (values, headings, formats) from them.
keys = inj.keys()
present = [name for name in eff_dist if name in keys]

eff_distance = [inj[name][:][idx] for name in present]
eff_dist_str = [eff_dist[name] for name in present]
eff_dist_format = ['##.##' for _ in present]
                
def _inj_values(name):
    """Return injection column `name` restricted to the entries in `idx`."""
    return inj[name][:][idx]


# Reduce the per-detector effective distances along axis 0 (across the
# detector list), keeping the largest value at each position.
dec_dist = np.max(eff_distance, 0)
m1 = _inj_values('mass1')
m2 = _inj_values('mass2')
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
dec_chirp_dist = pycbc.pnutils.chirp_distance(dec_dist, mchirp)

# Injection-parameter part of the table as (heading, values) pairs, in the
# order in which the columns appear.
inj_table = [
    ('DChirp Dist', dec_chirp_dist),
    ('Inj Time', _inj_values('end_time')),
    ('Mass1', m1),
    ('Mass2', m2),
    ('Mchirp', mchirp),
    ('Eta', eta),
]
inj_table += [(heading, _inj_values(column)) for heading, column in [
    ('s1x', 'spin1x'), ('s1y', 'spin1y'), ('s1z', 'spin1z'),
    ('s2x', 'spin2x'), ('s2y', 'spin2y'), ('s2z', 'spin2z'),
]]
inj_table.append(('Dist', _inj_values('distance')))
inj_names, inj_cols = zip(*inj_table)

# Full table: injection parameters, then effective distances, then the
# found-event columns. Every column uses the same format string.
names = list(inj_names) + eff_dist_str + found_names
format_strings = ['##.##'] * len(inj_table) + eff_dist_format + found_formats
columns = [np.array(col)
           for col in list(inj_cols) + eff_distance + found_cols]

html_table = pycbc.results.html_table(columns, names,
                                      format_strings=format_strings,
                                      page_size=20)

# Pass the rendered table to pycbc.results.save_fig_with_metadata together
# with the title, a caption and the command line.
caption = "A table of %s and their coincident statistic information." % title.lower()
kwds = dict(title=title, caption=caption, cmd=' '.join(sys.argv))
pycbc.results.save_fig_with_metadata(str(html_table), args.output_file, **kwds)
