#!/usr/bin/env python
# Copyright (C) 2020 Josh Willis
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from glob import glob
from os import path
import argparse
import numpy as np
from h5py import File
from pycbc.pnutils import mass1_mass2_to_mchirp_eta
import pycbc
import logging


# Globals
# Datasets for which, when an injection is found in both runs, the
# comparison/reference ratio is stored alongside the two raw arrays.
# These must stay lists: they are concatenated with other lists below.
ratio_keys = [
    'ifar_exc',
    'ifar',
    'stat',
    'minimum_single_detector_statistic',
]
# Datasets that are carried through unchanged (no ratio is formed)
other_keys = [
    'injection_index',
    'template_hash',
]


# Functions
def update_found_one_only(filep, groupp, fileno, data_dict):
    """Append one file's "found in only one run" injections to `data_dict`.

    Parameters
    ----------
    filep : open comparison file
        Must contain an 'injections' group with 'mass1' and 'mass2'
        datasets.
    groupp : group within `filep`
        Must contain one dataset for every name in `ratio_keys` and
        `other_keys`.
    fileno : int
        Number identifying `filep`; repeated once per injection in the
        'file_number' entry.
    data_dict : dict of lists
        Accumulator, modified in place. One array is appended to the list
        under each of the copied keys, 'chirp_mass' and 'file_number'.
    """
    for name in ratio_keys + other_keys:
        data_dict[name].append(groupp[name][:])

    inj_idx = groupp['injection_index'][:]
    injections = filep['injections']
    # Read each mass dataset in full, then pick out the rows for the
    # injections in this group
    m1 = injections['mass1'][:][inj_idx]
    m2 = injections['mass2'][:][inj_idx]
    data_dict['chirp_mass'].append(mass1_mass2_to_mchirp_eta(m1, m2)[0])

    # One copy of the file number per injection, so that every row of the
    # combined arrays can be traced back to its input file
    file_numbers = np.full(len(inj_idx), fileno, dtype=np.int32)
    data_dict['file_number'].append(file_numbers)

def update_missed(filep, groupp, fileno, data_dict):
    """Append one file's missed-injection data to `data_dict`.

    Parameters
    ----------
    filep : open comparison file
        Must contain an 'injections' group with 'mass1' and 'mass2'
        datasets.
    groupp : group within `filep`
        Must contain 'injection_index' and 'loudest_rank' datasets.
    fileno : int
        Number identifying `filep`; repeated once per injection in the
        'file_number' entry.
    data_dict : dict of lists
        Accumulator, modified in place. One array is appended to the list
        under each of 'injection_index', 'loudest_rank', 'chirp_mass' and
        'file_number'.
    """
    for name in ('injection_index', 'loudest_rank'):
        data_dict[name].append(groupp[name][:])

    inj_idx = groupp['injection_index'][:]
    injections = filep['injections']
    # Read each mass dataset in full, then pick out the rows for the
    # injections in this group
    m1 = injections['mass1'][:][inj_idx]
    m2 = injections['mass2'][:][inj_idx]
    data_dict['chirp_mass'].append(mass1_mass2_to_mchirp_eta(m1, m2)[0])

    # One copy of the file number per injection
    file_numbers = np.full(len(inj_idx), fileno, dtype=np.int32)
    data_dict['file_number'].append(file_numbers)

def update_found_both(filep, groupp, fileno, data_dict):
    """Append one file's "found in both runs" injections to `data_dict`.

    Parameters
    ----------
    filep : open comparison file
        Must contain an 'injections' group with 'mass1' and 'mass2'
        datasets.
    groupp : group within `filep`
        Must contain 'comparison' and 'reference' subgroups, each holding
        one dataset for every name in `ratio_keys` plus 'template_hash'
        and 'injection_index'.
    fileno : int
        Number identifying `filep`; repeated once per injection in the
        'file_number' entry.
    data_dict : dict
        Accumulator, modified in place. For each name in `ratio_keys` the
        comparison array, the reference array and their elementwise ratio
        (comparison/reference) are appended; 'template_hash',
        'injection_index' and 'chirp_mass' get one array per run; and
        'file_number' gets a single array.
    """
    for key in ratio_keys:
        carr = groupp['comparison'][key][:]
        rarr = groupp['reference'][key][:]
        data_dict[key]['comparison'].append(carr)
        data_dict[key]['reference'].append(rarr)
        data_dict[key]['ratio'].append(carr/rarr)

    # The full mass arrays do not depend on the run, so read them from the
    # file once rather than once per run
    injs = filep['injections']
    mass1 = injs['mass1'][:]
    mass2 = injs['mass2'][:]

    nfound = 0
    for run in ['comparison', 'reference']:
        run_grp = groupp[run]
        data_dict['template_hash'][run].append(run_grp['template_hash'][:])
        idx = run_grp['injection_index'][:]
        data_dict['injection_index'][run].append(idx)
        cm_arr = mass1_mass2_to_mchirp_eta(mass1[idx], mass2[idx])[0]
        data_dict['chirp_mass'][run].append(cm_arr)
        # Tracked explicitly rather than relying on `idx` outliving the
        # loop; as before, the length from the last run ('reference') is
        # the one used below
        nfound = len(idx)
    data_dict['file_number'].append(fileno * np.ones(nfound, dtype=np.int32))

def concat_one_only(data_dict):
    """Replace every list of arrays in `data_dict` by one concatenated array.

    Parameters
    ----------
    data_dict : dict of lists of arrays
        Modified in place: each value is passed to `numpy.concatenate` and
        the result is stored back under the same key.
    """
    # Only values are rebound (no keys added or removed), so it is safe to
    # assign while iterating over the items
    for name, chunks in data_dict.items():
        data_dict[name] = np.concatenate(chunks)

def concat_both(data_dict):
    """Concatenate the accumulated "found in both" arrays in place.

    Parameters
    ----------
    data_dict : dict
        For every name in `ratio_keys`, `other_keys` and 'chirp_mass' the
        value is itself a dict of lists of arrays; each of those lists is
        replaced by a single concatenated array. The top-level
        'file_number' list is concatenated as well.
    """
    for name in ratio_keys + other_keys + ['chirp_mass']:
        per_run = data_dict[name]
        # Only values are rebound, so assigning while iterating is safe
        for label, chunks in per_run.items():
            per_run[label] = np.concatenate(chunks)
    data_dict['file_number'] = np.concatenate(data_dict['file_number'])

def write_combined_results(outgrp, both_dict, ref_dict, com_dict):
    """Write the combined found-injection arrays beneath `outgrp`.

    Parameters
    ----------
    outgrp : HDF5 group
        Parent group. Three subgroups are created in it: 'found_in_both',
        'found_reference_only' and 'found_comparison_only'.
    both_dict : dict
        For every name in `ratio_keys`, `other_keys` and 'chirp_mass', a
        dict of arrays that is written as a subgroup of 'found_in_both'
        with one dataset per entry. Its 'file_number' array is written
        directly into 'found_in_both'.
    ref_dict, com_dict : dict of arrays
        Every entry is written as a dataset in 'found_reference_only' and
        'found_comparison_only' respectively.
    """
    both_grp = outgrp.create_group('found_in_both')
    for name in ratio_keys + other_keys + ['chirp_mass']:
        sub_grp = both_grp.create_group(name)
        for label, values in both_dict[name].items():
            sub_grp.create_dataset(label, data=values)
    both_grp.create_dataset('file_number', data=both_dict['file_number'])

    single_run = (('found_reference_only', ref_dict),
                  ('found_comparison_only', com_dict))
    for group_name, found_dict in single_run:
        run_grp = outgrp.create_group(group_name)
        for name, values in found_dict.items():
            run_grp.create_dataset(name, data=values)

def write_combined_missed(outgrp, ref_dict, com_dict):
    """Write the combined missed-injection arrays beneath `outgrp`.

    Parameters
    ----------
    outgrp : HDF5 group
        Parent group. Subgroups 'missed_only_reference' and
        'missed_only_comparison' are created in it.
    ref_dict, com_dict : dict of arrays
        Must each provide 'injection_index', 'loudest_rank', 'chirp_mass'
        and 'file_number'; these are written as datasets into the
        reference and comparison subgroup respectively.
    """
    targets = ((outgrp.create_group('missed_only_reference'), ref_dict),
               (outgrp.create_group('missed_only_comparison'), com_dict))
    dataset_names = ('injection_index', 'loudest_rank',
                     'chirp_mass', 'file_number')
    for name in dataset_names:
        for grp, missed_dict in targets:
            grp.create_dataset(name, data=missed_dict[name])


# Initialize
# Program description passed to the argument parser below. The parser is
# built with argparse.RawDescriptionHelpFormatter, so the line breaks and
# blank lines in this string matter for how the description is displayed.
long_description = """
This program combines several injection comparison files. Each will have been
created by running 'pycbc_injection_set_comparison' on one injection set that
was performed in each of two runs. Typically, in a script one would loop over
each injection set, running 'pycbc_injection_set_comparison' on each, and then
finally run this program.


'pycbc_combine_injection_comparisons' produces summary information across
injection runs, and in its output stores summary information on the ratio of
both IFAR and coincident ranking statistic for injections that were found in
both, as well as the separate IFAR and ranking statistic. For injections found
in only one run, it saves just the IFAR and ranking statistic.


If the specified category of found injections is 'found_after_vetoes', then
this program will also in its output store information about loud injections
that were found in only one of the two runs.


For every output HDF5 group, one member will be a dataset named 'file_number'.
For an entry, this integer n corresponds to the 'injfile_n' attribute at the
top-level of the file, which is a string naming the specific injection
comparison file where more detailed information about this injection may be
found.
"""
# Build the command-line parser. The raw-description formatter keeps the
# paragraph layout of `long_description` intact in the --help output.
formatter = argparse.RawDescriptionHelpFormatter
parser = argparse.ArgumentParser(formatter_class=formatter,
                                 description=long_description)
parser.add_argument("--input-files", nargs='+', required=True,
                    help="List of comparison files created by running"
                    " 'pycbc_injection_set_comparison' on several injection"
                    " sets")
parser.add_argument("--output-file", type=str, required=True,
                    help="Name of HDF output file in which to store results")
parser.add_argument("--found-type", type=str, required=True,
                    choices=['found', 'found_after_vetoes'],
                    help="Which class of found injections to collate")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print extra debugging information")

args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Open the output file for writing; it is filled in and closed at the end
# of the script
outfp = File(args.output_file, "w")

# Accumulators. Every leaf is a list that receives one array per input file
# and is later collapsed into a single array.

# Injections found in both runs: values are kept per run, and for the
# ratio keys the comparison/reference ratio is kept as well
both_dict = {
    key: {'comparison': [], 'reference': [], 'ratio': []}
    for key in ratio_keys
}
for key in ('template_hash', 'injection_index', 'chirp_mass'):
    both_dict[key] = {'comparison': [], 'reference': []}
both_dict['file_number'] = []

# Injections found only in the reference run / only in the comparison run.
# The key order is kept because these dictionaries are iterated when the
# output datasets are created.
_found_only_keys = ratio_keys + ['injection_index', 'chirp_mass',
                                 'file_number', 'template_hash']
ref_dict = {key: [] for key in _found_only_keys}
com_dict = {key: [] for key in _found_only_keys}

# Injections recorded under 'missed_only_reference' and
# 'missed_only_comparison' (used for 'found_after_vetoes' only)
_missed_keys = ['injection_index', 'loudest_rank', 'file_number',
                'chirp_mass']
missed_ref = {key: [] for key in _missed_keys}
missed_com = {key: [] for key in _missed_keys}

# Loop over the files, parsing each one into the appropriate dictionaries
# initialized above

# Attribute values of the first input file; every later file has to agree
# with them. After the loop this also holds the values that are copied to
# the output file's attributes.
same_dict = {
    'detectors' : None,
    'single_detector_statistic' : None,
    'reference_dir' : None,
    'comparison_dir' : None
}

nfiles = len(args.input_files)
outfp.attrs['nfiles'] = nfiles
for i, f in enumerate(args.input_files):
    # The context manager closes the input file even when the
    # compatibility check below (or any read) raises
    with File(f, "r") as fp:
        curr_dict = {
            'detectors' : [fp.attrs['detector_1'], fp.attrs['detector_2']]
        }
        for key in ['single_detector_statistic', 'reference_dir',
                    'comparison_dir', 'number_missed', 'ifar_threshold']:
            curr_dict[key] = fp.attrs[key]
        # The next if/else block checks that as we read each new
        # comparison file, we are only combining them if they
        # have the same values of our parameters. This should almost
        # certainly be either done better, or skipped.  In particular,
        # checking that the path of the reference directories and the
        # path of the comparison directories do not change is inadequate.
        if i == 0:
            # First file: remember its values as the reference point
            same_dict.update(curr_dict)
        else:
            for key in curr_dict.keys():
                cval = curr_dict[key]
                sval = same_dict[key]
                if isinstance(cval, list):
                    # The detector pair is compared without regard to order
                    is_same = (set(cval) == set(sval))
                else:
                    is_same = (cval == sval)
                if not is_same:
                    raise RuntimeError("Incompatible injection sets: file"
                                       " {0}.attrs[{1}] does not match"
                                       " previous injection sets".format(f, key))
        # Record which input file the number i refers to, so that the
        # 'file_number' datasets can be mapped back to a file
        file_key = "injfile_{0}".format(i)
        outfp.attrs[file_key] = fp.attrs['injection_label']
        found_group = fp[args.found_type]
        update_found_both(fp, found_group['found_in_both'], i, both_dict)
        update_found_one_only(fp, found_group['found_reference_only'], i,
                              ref_dict)
        update_found_one_only(fp, found_group['found_comparison_only'], i,
                              com_dict)
        if args.found_type == 'found_after_vetoes':
            missed_group = fp['missed_after_vetoes']
            update_missed(fp, missed_group['missed_only_reference'], i,
                          missed_ref)
            update_missed(fp, missed_group['missed_only_comparison'], i,
                          missed_com)

# Copy the attribute values shared by all input files to the output file
for attr_name, detector in zip(('detector_1', 'detector_2'),
                               same_dict['detectors']):
    outfp.attrs[attr_name] = detector
for attr_name in ('single_detector_statistic', 'reference_dir',
                  'comparison_dir'):
    outfp.attrs[attr_name] = same_dict[attr_name]

# Each accumulator currently holds one array per input file; merge those
# into a single array per entry
concat_both(both_dict)
for found_only in (ref_dict, com_dict):
    concat_one_only(found_only)
# The missed-injection accumulators are only filled for this found type
if args.found_type == 'found_after_vetoes':
    for missed_only in (missed_ref, missed_com):
        concat_one_only(missed_only)

# Summary statistics. For injections found in both runs, count how many
# were recovered with the same template in each run.
cth = both_dict['template_hash']['comparison']
rth = both_dict['template_hash']['reference']
same = (cth == rth).sum()
nboth = len(cth)
nref = len(ref_dict['template_hash'])
ncom = len(com_dict['template_hash'])
# All percentages are relative to the number found in both runs. If that
# number is zero the percentages are undefined: report NaN rather than
# aborting with ZeroDivisionError before any results have been written.
pct_denom = float(nboth) if nboth > 0 else float('nan')
same_pct = 100.0*float(same)/pct_denom
ref_pct = 100.0*float(nref)/pct_denom
com_pct = 100.0*float(ncom)/pct_denom
# Now we print out some summary information that is useful
print("Reference run directory path: {0}".format(same_dict['reference_dir']))
print("Comparison run directory path: {0}".format(same_dict['comparison_dir']))
print("Comparing {0} distinct injection sets between"
      " the two searches\n".format(nfiles))
print("Number of triggers found in both: {0}".format(len(rth)))
print("Number from same template: {0} ({1:.2f} % of"
      " found in both)\n".format(same, same_pct))
print("Number found in reference only: {0} ({1:.2f} % of"
      " found in both)".format(nref, ref_pct))
print("Number found in comparison only: {0} ({1:.2f} % of"
      " found in both)".format(ncom, com_pct))


# Write everything and finish
try:
    outgrp = outfp.create_group(args.found_type)
    write_combined_results(outgrp, both_dict, ref_dict, com_dict)
    # Missed-injection data is only collected for 'found_after_vetoes'
    if args.found_type == 'found_after_vetoes':
        outgrp = outfp.create_group('missed_after_vetoes')
        write_combined_missed(outgrp, missed_ref, missed_com)
finally:
    # Close the output file even if one of the writes above fails
    outfp.close()
