#!/usr/bin/env python

# Copyright (C) 2019 Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Reduce a MERGE triggers file to a reduced template bank
"""

from __future__ import print_function
from __future__ import division
import logging
import imp
import argparse
import numpy
import h5py
import pycbc
import pycbc.version

# Build the command line interface. The module docstring doubles as the
# description printed by --help.
cli = argparse.ArgumentParser(description=__doc__)
cli.add_argument(
    "--version",
    action="version",
    version=pycbc.version.git_verbose_msg,
)
cli.add_argument(
    "-V", "--verbose",
    action="store_true",
    default=False,
    help="print extra debugging information",
)

# The three plain file options only differ by flag name and help text.
required_file_options = (
    ("--input-file", "Input merge triggers HDF file."),
    ("--output-file", "Output merge triggers HDF file."),
    ("--full-template-bank", "The original full template bank HDF file."),
)
for flag, help_text in required_file_options:
    cli.add_argument(flag, required=True, help=help_text)

# Python file providing filter_tmpltbank(bank) -> accept/reject array.
filter_func_help = (
    "This can be provided to give a function to define "
    "which points are covered by the template bank "
    "bounds, and which are not. The file should contain "
    "a function called filter_tmpltbank, which should "
    "take as call profile the template bank HDF object "
    "and return a boolean (accept=1/reject=0) array."
)
cli.add_argument("--filter-func-file", required=True, help=filter_func_help)

opt = cli.parse_args()
pycbc.init_logging(opt.verbose)

# Open the full template bank only for as long as it is needed; everything
# read from it is materialised as in-memory numpy arrays before it is closed.
with h5py.File(opt.full_template_bank, 'r') as bank_fd:
    # Load the user-supplied Python file and fetch its filter function.
    # NOTE(review): the ``imp`` module is deprecated and was removed in
    # Python 3.12; this needs porting to ``importlib`` to run there.
    modl = imp.load_source('filter_func', opt.filter_func_file)
    func = modl.filter_tmpltbank

    # Coerce the result to a real boolean mask. The option help advertises
    # an "accept=1/reject=0" array; an integer 0/1 array used directly as an
    # index would be treated as fancy indices (selecting templates 0 and 1
    # repeatedly) instead of as a mask.
    bool_arr = numpy.asarray(func(bank_fd), dtype=bool)

    hashes = bank_fd['template_hash'][:]

if len(bool_arr) != len(hashes):
    raise ValueError("filter_tmpltbank returned %d values but the template "
                     "bank contains %d templates."
                     % (len(bool_arr), len(hashes)))

logging.info("Downselecting templates. Started with %d templates, now have "
             "%d after downselecting.", len(bool_arr), bool_arr.sum())

# Indices (within the full bank) of the templates that are kept.
tids = numpy.flatnonzero(bool_arr)
if len(tids) == 0:
    # Fail early with a clear message; the later copy steps cannot handle an
    # empty selection.
    raise ValueError("filter_tmpltbank rejected every template; there is "
                     "nothing to write to the output file.")

# bank_tids[k] is the bank index of the template with the k-th smallest hash;
# unsort[tid] is the position of template ``tid`` in that hash-sorted order.
bank_tids = hashes.argsort()
unsort = bank_tids.argsort()

# Names of the per-IFO datasets that are copied for the retained templates.
copy_params = ['bank_chisq', 'bank_chisq_dof', 'chisq', 'chisq_dof',
               'coa_phase', 'cont_chisq', 'cont_chisq_dof', 'end_time',
               'sg_chisq', 'sigmasq', 'snr', 'template_duration']

# Open the input trigger file. Exactly one top-level group (the IFO) is
# supported.
ifd = h5py.File(opt.input_file, 'r')
ifos = list(ifd.keys())
if len(ifos) != 1:
    # Explicit check rather than ``assert`` so that it still fires when
    # Python is run with -O.
    raise ValueError("Expected exactly one top-level group in %s, found %d: "
                     "%s" % (opt.input_file, len(ifos), ifos))
ifo = ifos[0]
ofd = h5py.File(opt.output_file, 'w')
ofd.create_group(ifo)

# Loop invariants: read the boundary table and the total trigger count once
# instead of going back to the HDF5 file for every template.
in_boundaries = ifd[ifo + '/template_boundaries'][:]
num_triggers = len(ifd[ifo + '/template_duration'])
in_template_ids = ifd[ifo + '/template_id']

# For every retained template record the [start, end) trigger range in the
# input file (old_boundaries) and the range the same triggers will occupy in
# the output file (new_boundaries), where they are packed back to back.
new_boundaries = []
old_boundaries = []
next_start = 0
for tid_count, tid in enumerate(tids, 1):
    # WHICH TEMPLATE DO WE HAVE
    if not tid_count % 1000:
        logging.info("Processing template %d of %d", tid_count, len(tids))
    # Where is its lower boundary
    boundary1 = in_boundaries[tid]
    # Upper boundary is harder: it is the lower boundary of whichever
    # template follows this one in hash-sorted order. This presumably relies
    # on the input file storing triggers grouped by template in order of
    # increasing template hash; the sanity check below catches violations.
    pos = unsort[tid]
    if pos == len(bool_arr) - 1:
        # If it's the last one, then go to the end
        boundary2 = num_triggers
    else:
        boundary2 = in_boundaries[bank_tids[pos + 1]]
    # Check this is sane: every trigger in the range must belong to tid
    test_tids = in_template_ids[boundary1:boundary2]
    if (test_tids != tid).any():
        raise ValueError("Triggers %d to %d of %s do not all belong to "
                         "template %d; cannot determine its boundaries."
                         % (boundary1, boundary2, opt.input_file, tid))
    old_boundaries.append((boundary1, boundary2))
    num_in_template = boundary2 - boundary1
    new_boundaries.append((next_start, next_start + num_in_template))
    next_start = new_boundaries[-1][1]

# The output stores only the start index of each template's trigger range.
template_boundaries = [start for start, _ in new_boundaries]
ofd[ifo]['template_boundaries'] = template_boundaries

# Copy each listed per-trigger dataset, keeping only the trigger ranges of
# the retained templates and packing them contiguously in the output.
for param in copy_params:
    logging.info("Copying parameter " + param)
    source = ifd[ifo][param]
    # A tiny read is enough to learn the dtype of the stored values.
    param_dtype = source[:2].dtype
    packed = numpy.zeros([new_boundaries[-1][1]], dtype=param_dtype)
    for (old_lo, old_hi), (new_lo, new_hi) in zip(old_boundaries,
                                                  new_boundaries):
        packed[new_lo:new_hi] = source[old_lo:old_hi]

    ofd[ifo][param] = packed

    # '<param>_template' holds one region reference per retained template,
    # each pointing at that template's slice of the dataset just written.
    written = ofd[ifo][param]
    region_refs = [written.regionref[lo:hi] for lo, hi in new_boundaries]
    ofd[ifo].create_dataset(
        param + '_template',
        data=region_refs,
        dtype=h5py.special_dtype(ref=h5py.RegionReference),
    )

# Write two template-id datasets: 'template_id' numbers the retained
# templates 0..N-1 in output order, while 'template_id_orig' keeps the ids
# the triggers carried in the input file.
logging.info("Updating template IDs")
id_name = 'template_id'
orig_name = id_name + '_orig'
source_ids = ifd[ifo][id_name]
id_dtype = source_ids[:2].dtype
renumbered_ids = numpy.zeros([new_boundaries[-1][1]], dtype=id_dtype)
original_ids = numpy.zeros([new_boundaries[-1][1]], dtype=id_dtype)
for new_tid, (old_range, new_range) in enumerate(zip(old_boundaries,
                                                     new_boundaries)):
    old_lo, old_hi = old_range
    new_lo, new_hi = new_range
    original_ids[new_lo:new_hi] = source_ids[old_lo:old_hi]
    renumbered_ids[new_lo:new_hi] = new_tid

ofd[ifo][id_name] = renumbered_ids
ofd[ifo][orig_name] = original_ids

# One region reference per retained template into each of the two datasets.
renumbered_dset = ofd[ifo][id_name]
original_dset = ofd[ifo][orig_name]
renumbered_refs = [renumbered_dset.regionref[lo:hi]
                   for lo, hi in new_boundaries]
original_refs = [original_dset.regionref[lo:hi]
                 for lo, hi in new_boundaries]

ofd[ifo].create_dataset(
    id_name + '_template',
    data=renumbered_refs,
    dtype=h5py.special_dtype(ref=h5py.RegionReference),
)
ofd[ifo].create_dataset(
    id_name + '_orig_template',
    data=original_refs,
    dtype=h5py.special_dtype(ref=h5py.RegionReference),
)

# Copy the groups that are carried over wholesale from the input file
ifd.copy(ifo + '/gating', ofd[ifo])
ifd.copy(ifo + '/search', ofd[ifo])

# Copy the attributes of the IFO group. Iterating the attribute manager
# directly yields attribute names only, so .items() is required to obtain
# (name, value) pairs.
logging.info("Copying attributes")
for key, value in ifd[ifo].attrs.items():
    ofd[ifo].attrs[key] = value

# Flush the output and release both file handles.
ofd.close()
ifd.close()
