#!/bin/env python
""" Apply a trials factor based on the number of available detector
combinations at the time of coincidence. This clusters to find the most
significant foreground, but leaves the background triggers alone.
"""

import h5py, numpy, argparse, logging, pycbc, pycbc.events, pycbc.io, lal
import pycbc.version
from ligo import segments

# Command-line interface. The input statmap files and the output file are
# needed unconditionally further down, so they are marked as required to get
# a clear argparse error instead of a failure when the files are opened.
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version", version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--statmap-files', nargs='+', required=True,
                    help="List of coinc files to be redistributed")
parser.add_argument('--cluster-window', type=float,
    help="Window passed to the multi-ifo coinc clustering when selecting "
         "the most significant foreground candidates")
# Note the trailing space before the line break: adjacent string literals
# are concatenated verbatim.
parser.add_argument('--censor-ifar-threshold', type=float, default=0.003,
    help="If provided, only window out foreground triggers with IFAR (years) "
         "above the threshold [default=0.003yr]")
parser.add_argument('--veto-window', type=float, default=0.1,
    help="Time around each zerolag trigger to window out [default=.1s]")
parser.add_argument('--output-file', required=True,
                    help="name of output file")
args = parser.parse_args()

pycbc.init_logging(args.verbose)

# Open every input statmap file read-only and create the combined output.
files = [h5py.File(fname, 'r') for fname in args.statmap_files]

f = h5py.File(args.output_file, "w")

logging.info('Copying segments and attributes to %s', args.output_file)
# Move segments information into the final file - remove some duplication
# in earlier files
for fi in files:
    # Whether this input carries a foreground veto does not depend on the
    # segment key, so look it up once per file.
    has_fg_veto = 'segments/foreground_veto' in fi
    for key in fi['segments']:
        # Groups named foreground*/background* are not copied as-is
        if key.startswith(('foreground', 'background')):
            continue

        seg_path = 'segments/%s' % key
        f[seg_path + '/end'] = fi[seg_path + '/end'][:]
        f[seg_path + '/start'] = fi[seg_path + '/start'][:]

        # The input file's foreground veto is stored underneath the segment
        # key it is being copied alongside
        if has_fg_veto:
            f[seg_path + '/foreground_veto/end'] = \
                                         fi['segments/foreground_veto/end'][:]
            f[seg_path + '/foreground_veto/start'] = \
                                       fi['segments/foreground_veto/start'][:]

        # File-level attributes of the input are kept on a top-level group
        # named after the segment key (created on first use)
        for attr_name in fi.attrs:
            if key not in f:
                f.create_group(key)
            f[key].attrs[attr_name] = fi.attrs[attr_name]

# Set up dictionaries to contain segments from the individual statmap files
indiv_segs = segments.segmentlistdict({})
# loop through statmap files and put segments into segmentlistdicts
for fi in files:
    key = fi.attrs['ifos'].replace(' ','')
    # get analysed segments from individual statmap files
    starts = fi['segments/{}/start'.format(key)][:]
    ends = fi['segments/{}/end'.format(key)][:]
    indiv_segs[key] = pycbc.events.veto.start_end_to_segments(starts, ends)

# Union of the analysed segments over all detector combinations.
# This is built explicitly instead of via numpy.sum(list_of_segmentlists):
# numpy only falls back to adding the segmentlist objects when the lists have
# different lengths (ragged input, which recent numpy rejects); when they
# have equal lengths it builds a numeric array and adds the GPS times
# element-wise, which is not a union.
foreground_segs = segments.segmentlist()
for seglist in indiv_segs.values():
    foreground_segs.extend(seglist)
# Sort and merge overlapping/adjacent segments so abs() counts each instant
# only once
foreground_segs.coalesce()

# Total zerolag analysis time
f.attrs['foreground_time'] = abs(foreground_segs)

# obtain list of all ifos involved in the coinc_statmap files
all_ifos = numpy.unique([ifo for fi in files for ifo in \
                         fi.attrs['ifos'].split(' ')])
f.attrs['ifos'] = ' '.join(sorted(all_ifos))

logging.info('Generating list of datasets in input files')
key_set = pycbc.io.name_all_datasets(files)

logging.info('Copying foreground non-ifo-specific data')
# Concatenate every foreground column across the input files and write it to
# the output, skipping anything that belongs to a per-IFO group (those are
# assembled separately because not every file contains every IFO).
for key in key_set:
    if not key.startswith('foreground'):
        continue
    # A dataset is treated as IFO-specific if any IFO name occurs in its path
    if any(ifo in key for ifo in all_ifos):
        continue
    pycbc.io.combine_and_copy(f, files, key)

logging.info('Collating triggers into single structure')

# Per-ifo accumulators for the trigger times and trigger ids of every
# foreground coincidence, in the order the input files are read. They start
# empty; numpy.concatenate promotes the dtype as data is appended.
all_trig_times = {}
all_trig_ids = {}
for ifo in all_ifos:
    all_trig_times[ifo] = numpy.array([], dtype=numpy.uint32)
    all_trig_ids[ifo] = numpy.array([], dtype=numpy.uint32)

# For each file, append the trigger time and id data for each ifo
# If an ifo does not participate in any given coinc then fill with -1 values
for f_in in files:
    for ifo in all_ifos:
        if ifo in f_in['foreground']:
            new_times = f_in['foreground/{}/time'.format(ifo)][:]
            new_ids = f_in['foreground/{}/trigger_id'.format(ifo)][:]
        else:
            # One -1 placeholder per coincidence in this file. A signed dtype
            # is required: multiplying an unsigned array by -1 raises
            # OverflowError under NumPy 2 promotion rules, and on older NumPy
            # it silently produced this same int64 result.
            new_times = numpy.full_like(f_in['foreground/fap'][:], -1,
                                        dtype=numpy.int64)
            new_ids = new_times
        all_trig_times[ifo] = numpy.concatenate([all_trig_times[ifo],
                                                 new_times])
        all_trig_ids[ifo] = numpy.concatenate([all_trig_ids[ifo], new_ids])
    # Nothing further is read from this input file in this section
    f_in.close()

logging.info('Clustering triggers for loudest ifar value')

# Store the collated per-ifo columns in the output file
for ifo in all_ifos:
    f['foreground/{}/time'.format(ifo)] = all_trig_times[ifo]
    f['foreground/{}/trigger_id'.format(ifo)] = all_trig_ids[ifo]

# Record array pairing each foreground candidate's ifar with its stat.
# numpy.rec is the public location of fromarrays; numpy.core is a private
# namespace whose use is deprecated in recent NumPy releases.
ifar_stat = numpy.rec.fromarrays([f['foreground/ifar'][:],
                                  f['foreground/stat'][:]],
                                 names='ifar,stat')

# all_times is a tuple of trigger time arrays, one per ifo (a real tuple
# rather than a one-shot generator, so it can be iterated more than once)
all_times = tuple(f['foreground/%s/time' % ifo][:] for ifo in all_ifos)

def argmax(v):
    """Return the index of the element of ``v`` that sorts last.

    The index is taken from the end of ``numpy.argsort(v)``, so ``v`` must be
    non-empty. NOTE(review): presumably argsort is used rather than
    numpy.argmax so that this also works on the (ifar, stat) record array,
    which sorts field by field -- confirm against the clustering code.
    """
    ascending_order = numpy.argsort(v)
    return ascending_order[-1]

# Only the zerolag (i.e. foreground) candidates are clustered at present, so
# every candidate is given a timeslide id of zero.
timeslide_ids = numpy.zeros(len(ifar_stat))
cidx = pycbc.events.cluster_coincs_multiifo(ifar_stat,
                                            all_times,
                                            timeslide_ids,
                                            0,
                                            args.cluster_window,
                                            argmax)

def filter_dataset(h5file, name, idx):
    """Replace the dataset ``name`` in ``h5file`` by its entries at ``idx``.

    Parameters
    ----------
    h5file : container supporting item get, delete and set by name (the
        open output file in this script)
    name : key/path of the dataset to shrink
    idx : index or index array applied to the fully-read dataset

    Returns
    -------
    ``idx``, unchanged.
    """
    # Read everything into memory first, then select the wanted entries
    full_column = h5file[name][:]
    kept = full_column[idx]
    # The selection has a different size, so the stored dataset cannot be
    # overwritten in place: remove it and store the selection under the
    # same name.
    del h5file[name]
    h5file[name] = kept
    return idx

# Downsample the foreground columns to only the loudest ifar between the
# multiple files
# filter_dataset deletes and re-creates members of the groups being walked,
# so iterate over snapshots of the key lists rather than the live views.
for key in list(f['foreground'].keys()):
    if key not in all_ifos:
        filter_dataset(f, 'foreground/%s' % key, cidx)
    else:  # key is an ifo
        for k in list(f['foreground/%s' % key].keys()):
            filter_dataset(f, 'foreground/{}/{}'.format(key, k), cidx)

logging.info('Applying trials factor')

# Re-read the per-ifo event times now that the columns have been clustered;
# these are what the trials factor is evaluated at
clustered_times = (f['foreground/%s/time' % ifo][:] for ifo in all_ifos)

# The trials factor of a candidate is the number of ifo-combination segment
# lists stored in the output file whose segments contain the candidate time
trials_factors = numpy.zeros_like(f['foreground/ifar'][:])

# One representative time per candidate, taken from the first element
# returned by mean_if_greater_than_zero for that candidate's per-ifo times
mean_times = [pycbc.events.mean_if_greater_than_zero(tc)[0]
              for tc in zip(*clustered_times)]
test_times = numpy.array(mean_times)

# Iterate over different ifo combinations
for key in f['segments']:
    if key.startswith(('foreground', 'background')):
        continue
    end_times = numpy.array(f['segments/%s/end' % key][:])
    start_times = numpy.array(f['segments/%s/start' % key][:])
    idx_within_segment = pycbc.events.indices_within_times(test_times,
                                                           start_times,
                                                           end_times)
    # Each combination covering the candidate time counts as one trial
    trials_factors[idx_within_segment] += 1

# Scale both the inclusive and exclusive ifar columns down by the trials
# factor
for ifar_name in ('foreground/ifar', 'foreground/ifar_exc'):
    f[ifar_name][:] = f[ifar_name][:] / trials_factors

# The exclusive foreground time starts from the full foreground time; the
# censored time is subtracted below
f.attrs['foreground_time_exc'] = f.attrs['foreground_time']

# Construct the foreground censor veto from the clustered candidate times
# above the ifar threshold
thr = test_times[f['foreground/ifar'][:] > args.censor_ifar_threshold]
vstart = thr - args.veto_window
vend = thr + args.veto_window
vtime = segments.segmentlist([segments.segment(s, e)
                              for s, e in zip(vstart, vend)])
# Merge overlapping veto windows before measuring them: abs() of a
# segmentlist adds up the length of every entry, so without this any overlap
# between neighbouring windows would be subtracted twice.
vtime.coalesce()
logging.info('Censoring %.2f seconds', abs(vtime))
f.attrs['foreground_time_exc'] -= abs(vtime)
# The stored veto keeps one (start, end) pair per censored candidate
f['segments/foreground_veto/start'] = vstart
f['segments/foreground_veto/end'] = vend

#TODO: add in background combinations
# If there is a background set (full_data as opposed to injection run), then
# recalculate the values for its triggers as well

f.close()
logging.info('done')
