#!/usr/bin/env python
"""
Make a table of foreground coincidence information.
This code is meant for the multiifo statmap files as an
alternative to pycbc_page_foreground.
"""
import argparse, h5py, numpy as np, pycbc.results, sys
import pycbc.pnutils, pycbc.events
import pycbc.version
from pycbc import init_logging
import logging
from itertools import combinations


# Command-line interface. Options are registered in their original order so
# the generated --help output is unchanged.
hierarchical_level_help = (
    "Indicate which FARs to write to the table based on the number of "
    "hierarchical removals done. "
    "Choosing None defaults to giving the FARs after all hierarchical "
    "removals were done depending on previous configuration choices. "
    "Choosing 0 selects writing the FARs prior to any hierarchical removals. "
    "Choosing 1 means writing the FARs after doing 1 hierarchical removal. "
    "The program will fail if the user selects a number above the number of "
    "hierarchical removals done. "
    "[default=None]"
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--version",
    action="version",
    version=pycbc.version.git_verbose_msg,
)
# Input files: the statmap results, the per-detector trigger files and the
# template bank.
parser.add_argument(
    "--trigger-file",
    required=True,
    help="HDF File containing found trigger information",
)
parser.add_argument("--single-detector-triggers", nargs="+", default=None)
parser.add_argument("--bank-file", required=True)
# Table content and ordering.
parser.add_argument(
    "--use-hierarchical-level",
    type=int,
    default=None,
    help=hierarchical_level_help,
)
parser.add_argument(
    "--sort-param",
    default="ifar",
    help="What to use to sort table, choose ifar, ifar_exc, stat.",
)
# Logging verbosity (counted flag) and the destination of the HTML table.
parser.add_argument("--verbose", action="count")
parser.add_argument("--output-file")
args = parser.parse_args()

init_logging(args.verbose)

logging.info("Opening results file")
f = h5py.File(args.trigger_file, 'r')

# Hierarchical removal level requested on the command line (None if unset).
h_num_rm = args.use_hierarchical_level

# Number of hierarchical removals recorded in the results file; None when the
# file carries no such attribute.
try:
    h_iterations = f.attrs['hierarchical_removal_iterations']
except KeyError:
    h_iterations = None

# If the user asked for a removal level that the file cannot provide, write a
# small table stating requested vs. performed removals and stop here.
if h_num_rm is not None and (h_iterations is None or h_num_rm > h_iterations):
    # Nothing else is read from the results file on this path, so release the
    # handle explicitly instead of leaving it open across sys.exit().
    f.close()

    logging.info("Hierarchical removal requested is more than performed")
    col_one = np.array([h_num_rm])
    col_two = np.array([h_iterations])
    columns = [col_one, col_two]
    names = ["Hierarchical Removals Requested",
             "Hierarchical Removals Performed"]
    format_strings = ['#', '#']
    html_table = pycbc.results.html_table(columns, names,
                                   format_strings=format_strings, page_size=10)

    kwds = { 'title' :
        'No more events louder than all background at this removal level.',
        'cmd' :' '.join(sys.argv), }
    pycbc.results.save_fig_with_metadata(str(html_table), args.output_file,
                                         **kwds)
    sys.exit(0)

title = "Loudest Event Table"

# Pick the foreground group to tabulate.
# NOTE(review): a requested level of 0 is falsy, so it selects the default
# 'foreground' group exactly as None does -- confirm this matches the
# behaviour advertised for --use-hierarchical-level 0.
if h_num_rm:
    logging.info("Using hierarchical removal level %d", h_num_rm)
    foreground = f['foreground_h%d' % h_num_rm]
else:
    foreground = f['foreground']

# Indices that order the events by the chosen parameter, largest first.
logging.info("Sorting on %s", args.sort_param)
ii_sort = np.argsort(foreground[args.sort_param][:])[::-1]

# Significance columns, each reordered with the same sort indices. The
# dataset keys are listed in the same order as the headers in `names`.
logging.info("Getting coinc statistic and significance details")
cols = [foreground[key][:][ii_sort]
        for key in ('ifar', 'ifar_exc', 'fap', 'fap_exc', 'stat')]
names = ['Inc. IFAR (YR)', 'Exc. IFAR (YR)', 'Inc. FAP', 'Exc. FAP ',
         'Ranking Statistic']
formats = ['#.###E0', '#.###E0', '#.##E0', '#.##E0', '##.###']

# Detector names are stored as one space-separated string attribute.
detectors = f.attrs['ifos'].split(' ')

logging.info("Getting event time and number of IFOs")
# One sorted time array per detector.
tc = [foreground[ifo + '/time'][:][ii_sort] for ifo in detectors]

# Walk the events one at a time; only strictly positive times are counted and
# averaged (presumably non-positive values flag detectors that did not take
# part in the coincidence -- TODO confirm).
n_ifos = []
mean_time = []
for event_times in zip(*tc):
    event_times = np.array(event_times)
    has_time = event_times > 0
    n_ifos.append(np.count_nonzero(has_time))
    mean_time.append(np.mean(event_times[has_time]))

# Add GPS event time and n_ifos to the table
cols += [mean_time, n_ifos]
names += ["End Time (GPS)", "No. IFOs in coinc"]
formats += ["#.##", "#"]

logging.info("Getting template information")
# Template index of every event, in table order; used to look up the
# matching rows of the bank file.
tid = foreground['template_id'][:][ii_sort]
with h5py.File(args.bank_file, 'r') as bank_f:
    mass1, mass2, spin1z, spin2z = [
        bank_f[param][:][tid]
        for param in ('mass1', 'mass2', 'spin1z', 'spin2z')
    ]

# Only mchirp goes into the table; eta is returned alongside it by the helper.
mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(mass1, mass2)

template_columns = (
    ('mchirp', mchirp),
    ('m1', mass1),
    ('m2', mass2),
    ('s1z', spin1z),
    ('s2z', spin2z),
)
for header, values in template_columns:
    cols.append(values)
    names.append(header)
    formats.append('##.##')

logging.info("Getting time difference information")
# Every unordered pair of detectors, in the order combinations() yields them.
det_two_combo = np.array(list(combinations(detectors, 2)))
tdiff = []
tdiff_str = []
tdiff_format = []
for ifo_a, ifo_b in det_two_combo:
    time_1 = np.array(foreground[ifo_a + '/time'][:][ii_sort])
    time_2 = np.array(foreground[ifo_b + '/time'][:][ii_sort])
    # Scaled by 1000 for the "(ms)" column header (assumes the stored times
    # are in seconds -- TODO confirm).
    tdiff_vals = (time_1 - time_2) * 1000
    # A negative time in either detector means there is no difference to show.
    either_negative = np.logical_or(time_1 < 0, time_2 < 0)
    tdiff_vals[either_negative] = np.nan
    # Undefined differences are rendered as a blank cell.
    tdiff.append(['%.2f' % td if not np.isnan(td) else ' '
                  for td in tdiff_vals])
    tdiff_str.append('%s - %s time (ms)' % (ifo_a, ifo_b))
    tdiff_format.append('##.##')

cols += tdiff
names += tdiff_str
formats += tdiff_format


# Per-detector trigger_id arrays in table order; these index into the
# single-detector trigger files opened below.
ids = {detector: foreground[detector + '/trigger_id'][:][ii_sort]
       for detector in detectors}
# Everything needed from the results file has been read by now.
f.close()

# Without trigger files the per-detector columns cannot be built; report that
# as a usage error rather than failing while iterating over None.
if not args.single_detector_triggers:
    parser.error('--single-detector-triggers must be given to build the '
                 'single-detector columns of the table')

# Open each trigger file exactly once, then map every detector to the first
# file (in command-line order) that contains a group named after it.
sngl_trig_files = [h5py.File(sngl_trig_fname, 'r')
                   for sngl_trig_fname in args.single_detector_triggers]
sngl_trig_filedict = {}
for detector in detectors:
    for sngl_trig_file in sngl_trig_files:
        if detector in sngl_trig_file:
            sngl_trig_filedict[detector] = sngl_trig_file
            break
    else:
        raise ValueError('None of the single-detector trigger files '
                         'contains a group for detector %s' % detector)

# Release the handles that no detector ended up using; the selected ones stay
# open because the single-detector columns are read from them afterwards.
for sngl_trig_file in sngl_trig_files:
    if not any(sngl_trig_file is selected
               for selected in sngl_trig_filedict.values()):
        sngl_trig_file.close()

# Append SNR, reduced chi-squared and NewSNR columns for every detector.
for ifo in detectors:
    logging.info("Getting %s single-detector trigger information", ifo)
    ids_ifo = np.array(ids[ifo])
    # Rows whose trigger id is -1 are blanked out in all three columns
    # (presumably the detector had no trigger in that event -- TODO confirm).
    no_trigger = ids_ifo == -1
    trigfile_ifo = sngl_trig_filedict[ifo][ifo]

    snr_vals = trigfile_ifo['snr'][:][ids_ifo]
    snr_vals[no_trigger] = np.nan

    chisq_dof = trigfile_ifo['chisq_dof'][:][ids_ifo]
    chisq_vals = trigfile_ifo['chisq'][:][ids_ifo] / (2 * chisq_dof - 2)
    # The blanked rows carry inf through the newsnr call and are switched to
    # nan only afterwards, together with the newsnr values themselves.
    chisq_vals[no_trigger] = np.inf
    newsnr_vals = pycbc.events.ranking.newsnr(snr_vals, chisq_vals)
    chisq_vals[no_trigger] = np.nan
    newsnr_vals[no_trigger] = np.nan

    # Two-decimal strings, with nan shown as a blank cell.
    per_ifo_columns = (
        (" SNR", snr_vals),
        (" Red. Chisq", chisq_vals),
        (" NewSNR", newsnr_vals),
    )
    for suffix, values in per_ifo_columns:
        names.append(ifo + suffix)
        cols.append(['%.2f' % v if not np.isnan(v) else ' ' for v in values])
        formats.append('##.##')

# Convert every accumulated column to an array before building the table.
columns = list(map(np.array, cols))
logging.info("Putting results into html table")
html_table = pycbc.results.html_table(
    columns,
    names,
    format_strings=formats,
    page_size=20,
)

# Metadata stored with the output: title, caption and the exact command line.
caption = ("A table of %s and their coincident statistic information."
           % title.lower())
kwds = dict(title=title, caption=caption, cmd=' '.join(sys.argv))
logging.info("Saving table with metadata")
pycbc.results.save_fig_with_metadata(
    str(html_table),
    args.output_file,
    **kwds
)
