#!/usr/bin/env python
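# The interpreter path was previously /usr/bin/python
"""Merge single-detector HDF trigger files into one output file.

Triggers are sorted by template hash, and per-template boundaries and region
references are stored alongside the merged trigger columns.
"""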
import numpy, argparse, h5py, logging
import pycbc.version
from numpy import unique

def changes(arr):
    """Return the indices that delimit runs of equal values in ``arr``.

    ``arr`` is expected to be a one-dimensional numpy array (the script passes
    the sorted trigger hashes). The result is a sorted array of unique
    indices made of 0, every position whose value differs from the one just
    before it, and ``len(arr)``.
    """
    # Index i is selected when arr[i] != arr[i + 1]; the next run starts at i + 1.
    last_of_run = numpy.nonzero(arr[:-1] != arr[1:])[0]
    edges = numpy.concatenate(([0], last_of_run + 1, [len(arr)]))
    # For an empty input both end markers are 0; unique() merges them.
    return numpy.unique(edges)

def collect(key, files):
    """Concatenate the dataset ``key`` from every file that contains it.

    Parameters
    ----------
    key : str
        HDF5 path of the dataset to read, e.g. ``'<ifo>/template_hash'``.
    files : iterable of str
        Paths of the HDF5 files to read; the output follows this order.

    Returns
    -------
    numpy.ndarray
        The per-file arrays joined end to end. Files that do not contain
        ``key`` are skipped.

    Raises
    ------
    ValueError
        If no file contains ``key``, so there is nothing to concatenate.
    """
    data = []
    for fname in files:
        # Open read-only explicitly so the inputs can never be modified or
        # created, whatever default mode the installed h5py uses. The context
        # manager closes the file even if reading the dataset fails.
        with h5py.File(fname, 'r') as fin:
            if key in fin:
                data.append(fin[key][:])
    if not data:
        raise ValueError("No input file contains the dataset '%s'" % key)
    return numpy.concatenate(data)

def region(f, key, boundaries, ids):
    """Store per-template region references for the dataset ``f[key]``.

    For every ``j`` in ``range(len(boundaries) - 1)`` a region reference to
    ``f[key][boundaries[ids[j]]:boundaries[ids[j] + 1]]`` is built, and the
    list of references is written to a new dataset named ``key + '_template'``.

    Parameters
    ----------
    f : h5py.File
        Open, writable file that already contains the dataset ``key``.
    key : str
        Path of the dataset the references point into.
    boundaries : sequence of int
        Slice edges; entry ``i`` and entry ``i + 1`` delimit one region.
    ids : sequence of int
        For each output position ``j``, the index into ``boundaries`` of the
        region to reference.
    """
    source = f[key]
    n_regions = len(boundaries) - 1
    refs = [source.regionref[boundaries[ids[j]]:boundaries[ids[j] + 1]]
            for j in range(n_regions)]
    ref_dtype = h5py.special_dtype(ref=h5py.RegionReference)
    f.create_dataset(key + '_template', data=refs, dtype=ref_dtype)

# Command line interface. The three file options are marked as required
# because the script cannot run without them; argparse then reports a missing
# option directly instead of the script failing later when it uses None.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-files', nargs='+', required=True,
                    help='Single-detector HDF trigger files to merge')
parser.add_argument('--output-file', required=True,
                    help='Path of the merged HDF file to create')
parser.add_argument('--bank-file', required=True,
                    help="HDF template bank file containing 'template_hash'")
parser.add_argument('--verbose', '-v', action='count',
                    help='Currently has no effect: messages are always '
                         'logged at INFO level')
args = parser.parse_args()

# The log level is fixed at INFO regardless of --verbose.
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.INFO)

# Create (or truncate) the merged output file; it is closed at the end of the
# script.
f = h5py.File(args.output_file, 'w')

logging.info("getting the list of columns from a representative file")
# Members of the '<ifo>' group that are not per-trigger columns; they are
# merged separately from the column loop.
non_column_keys = ('search', 'template_hash', 'gating')
trigger_columns = []
for fname in args.trigger_files:
    try:
        f2 = h5py.File(fname, 'r')
    except IOError:
        logging.error("Cannot open %s" % fname)
        raise
    # The context manager closes the input on every exit path (break,
    # continue or an exception).
    with f2:
        # The first top-level group name is used as the detector key
        # (presumably the interferometer name); ``ifo`` is reused by the rest
        # of the script.
        ifo = tuple(f2.keys())[0]
        columns = list(f2[ifo].keys())
        # A file without 'template_hash' (including one with an empty group)
        # cannot serve as the representative file, so keep looking. Its keys
        # are deliberately NOT stored in trigger_columns, so that non-column
        # groups never leak into the column list.
        if 'template_hash' not in columns:
            continue
        trigger_columns = [name for name in columns
                           if name not in non_column_keys]
        break

for col in trigger_columns:
    logging.info("trigger column: %s" % col)

logging.info('reading the metadata from the files')
# Optional datasets under '<ifo>/search'. Each one is merged across the input
# files and written to the output only if at least one file provides values.
optional_search_keys = ('templates_per_core', 'filter_rate_per_core',
                        'setup_time_fraction', 'run_time')
start = numpy.array([], dtype=numpy.float64)
end = numpy.array([], dtype=numpy.float64)
search_extras = dict((name, numpy.array([], dtype=numpy.float64))
                     for name in optional_search_keys)
# Maps each dataset path below '<ifo>/gating' to its merged values.
gating = {}
for filename in args.trigger_files:
    try:
        data = h5py.File(filename, 'r')
    except IOError:
        logging.error('Cannot open %s' % filename)
        raise
    # Everything is copied into memory inside the with-block, so the input
    # can be closed (also on error) before the next file is opened.
    with data:
        ifo_data = data[ifo]
        search = ifo_data['search']
        start = numpy.append(start, search['start_time'][:])
        end = numpy.append(end, search['end_time'][:])

        for name in optional_search_keys:
            if name in search:
                search_extras[name] = numpy.append(search_extras[name],
                                                   search[name][:])

        if 'gating' in ifo_data:
            gating_group = ifo_data['gating']
            gating_keys = []
            # visit() hands every member name below the gating group to the
            # callback; list.append returns None, so the walk is not cut short.
            gating_group.visit(gating_keys.append)
            for gk in gating_keys:
                gk_data = gating_group[gk]
                # Only datasets carry values; anything else is skipped.
                if isinstance(gk_data, h5py.Dataset):
                    if gk not in gating:
                        gating[gk] = numpy.array([], dtype=numpy.float64)
                    gating[gk] = numpy.append(gating[gk], gk_data[:])

search_path = '%s/search/' % ifo
f[search_path + 'start_time'] = start
f[search_path + 'end_time'] = end

for name in optional_search_keys:
    if len(search_extras[name]) > 0:
        f[search_path + name] = search_extras[name]

for gk, gv in gating.items():
    f[ifo + '/gating/' + gk] = gv

logging.info('set up sorting of triggers and template ids')
# For fast lookup we need the templates in hash order
# Read the bank hashes into memory and close the bank file straight away.
with h5py.File(args.bank_file, 'r') as bank:
    hashes = bank['template_hash'][:]
# bank_tids[i] is the bank index of the template with the i-th smallest hash;
# unsort is the inverse permutation (bank index -> position in hash order).
bank_tids = hashes.argsort()
unsort = bank_tids.argsort()
hashes = hashes[bank_tids]

# Sort all triggers by template hash so that the triggers of one template are
# contiguous; trigger_sort is reused later to reorder every trigger column.
trigger_hashes = collect('%s/template_hash' % ifo, args.trigger_files)
trigger_sort = trigger_hashes.argsort()
trigger_hashes = trigger_hashes[trigger_sort]
template_boundaries = changes(trigger_hashes)

# Bank index for each run of equal trigger hashes.
# NOTE(review): searchsorted does not check that each trigger hash actually
# occurs in the bank; this lookup assumes it does - confirm against the inputs.
run_hashes = trigger_hashes[template_boundaries[:-1]]
template_ids = bank_tids[numpy.searchsorted(hashes, run_hashes)]

# get the full boundaries in hash order: for every bank template, the index of
# its first trigger in the sorted trigger list, followed by the total number
# of triggers as the closing edge.
full_boundaries = numpy.searchsorted(trigger_hashes, hashes)
full_boundaries = numpy.concatenate([full_boundaries, [len(trigger_hashes)]])
del trigger_hashes

# Expand the per-run template ids to one id per trigger.
idlen = (template_boundaries[1:] - template_boundaries[:-1])
f.create_dataset('%s/template_id' % ifo, data=numpy.repeat(template_ids, idlen),
                 compression='gzip', shuffle=True, compression_opts=9)
# Start index of each template's triggers, stored in bank order.
f['%s/template_boundaries' % ifo] = full_boundaries[unsort]

logging.info('reading the trigger columns from the input files')
try:
    for col in trigger_columns:
        key = '%s/%s' % (ifo, col)
        logging.info('reading %s' % col)
        # Gather the column from all inputs and put it in the same
        # hash-sorted trigger order that was used for 'template_id'.
        data = collect(key, args.trigger_files)[trigger_sort]
        logging.info('writing %s to file' % col)
        f.create_dataset(key, data=data, compression='gzip',
                         compression_opts=9, shuffle=True)
        # Drop the in-memory copy before the next column is loaded.
        del data
        # Add '<key>_template' with one region reference per bank template.
        region(f, key, full_boundaries, unsort)
finally:
    # Close the output even if a column could not be read or written.
    f.close()
logging.info('done')
