#!/bin/env python
import numpy, argparse, pycbc.version, pycbc.pnutils, logging
from pycbc.events import veto
from glue.ligolw import ligolw, table, lsctables, utils as ligolw_utils

# Name of the sim_inspiral column that holds the effective distance for each
# detector; looked up by detector name when applying the chirp distance cut.
effd = {
    "H1": "eff_dist_h",
    "L1": "eff_dist_l",
    "V1": "eff_dist_v",
}
def remove(l, i):
    """Delete, in place, the elements of ``l`` found at the positions ``i``.

    Elements are selected strictly by position.  Deleting through
    ``list.remove()`` would instead drop the *first* element that compares
    equal to the one looked up, which removes the wrong entry whenever ``l``
    holds equal-valued elements.

    Parameters
    ----------
    l : list
        Sequence to modify in place (must support slice assignment).
    i : iterable of int
        Positions to delete, e.g. a list or a numpy integer array.  Negative
        positions count from the end; a position given more than once is
        deleted once.

    Raises
    ------
    IndexError
        If a position is out of range for ``l``.  ``l`` is left unmodified
        in that case.
    """
    # Indexing a range normalises negative positions and raises IndexError
    # for out-of-range ones, exactly as indexing ``l`` itself would.
    valid = range(len(l))
    doomed = set(valid[t] for t in i)
    if doomed:
        # Single pass rebuild; slice assignment keeps the same list object,
        # so callers holding a reference to ``l`` see the change.
        l[:] = [item for pos, item in enumerate(l) if pos not in doomed]

class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
    """Dummy content handler needed for loading LIGOLW files.

    Adds nothing to the base handler.  It is passed through
    ``lsctables.use_in`` below and then given to the LIGOLW loader as the
    ``contenthandler`` argument.
    """


# NOTE(review): presumably this teaches the handler about the lsctables table
# definitions so that they are recognised while parsing -- confirm against
# the glue.ligolw documentation.
lsctables.use_in(LIGOLWContentHandler)

# ---------------------------------------------------------------------------
# Command line
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--injection-file')
parser.add_argument('--veto-file',
                      help="File containing segments used to veto injections")
parser.add_argument('--segment-name',
                      help="Name of segmentlist within the veto file to veto injections")
parser.add_argument('--ifos', nargs='+')
parser.add_argument('--max-effective-chirp-distance', type=float)
parser.add_argument('--output-file')
args = parser.parse_args()

# Both filters below loop over --ifos.  Report a missing or unusable value as
# a usage error up front, instead of failing with a TypeError (iterating None)
# or a KeyError part-way through, after the injection file has been loaded.
if (args.veto_file or args.max_effective_chirp_distance) and not args.ifos:
    parser.error('--ifos is required with --veto-file or '
                 '--max-effective-chirp-distance')
if args.max_effective_chirp_distance:
    unknown_ifos = [ifo for ifo in args.ifos if ifo not in effd]
    if unknown_ifos:
        parser.error('no effective distance column is known for: %s'
                     % ', '.join(unknown_ifos))

pycbc.init_logging(args.verbose)

# ---------------------------------------------------------------------------
# Load the injections
# ---------------------------------------------------------------------------
logging.info('File: %s', args.injection_file)
indoc = ligolw_utils.load_filename(args.injection_file, False, contenthandler=LIGOLWContentHandler)
sim_table = table.get_table(indoc, 'sim_inspiral')

logging.info('%s Injections in file', len(sim_table))

# ---------------------------------------------------------------------------
# Optional cut 1: veto segments
# ---------------------------------------------------------------------------
if args.veto_file:
    logging.info('Removing injections outside coincident time')
    for ifo in args.ifos:
        # Recomputed for every detector: the table shrinks after each pass,
        # so indices from a previous pass would no longer line up.
        inj_time = numpy.array(sim_table.get_column('geocent_end_time') + 1e-9 * sim_table.get_column('geocent_end_time_ns'), dtype=numpy.float64)
        # Only the indices are needed; the second return value is unused.
        idx, _ = veto.indices_outside_segments(inj_time, [args.veto_file], ifo, args.segment_name)
        remove(sim_table, idx)
    logging.info('We now have %s injections', len(sim_table))

# ---------------------------------------------------------------------------
# Optional cut 2: maximum effective chirp distance
# ---------------------------------------------------------------------------
# NOTE: a value of 0 leaves this cut disabled (truthiness test kept on
# purpose so existing invocations behave as before).
if args.max_effective_chirp_distance:
    logging.info('Removing injections that are quiet in one detector')
    for ifo in args.ifos:
        # As above, every quantity is re-read from the (possibly shrunken)
        # table on each pass so that indices match the current rows.
        eff_distance = numpy.array(sim_table.get_column(effd[ifo]), dtype=numpy.float32)
        m1 = numpy.array(sim_table.get_column('mass1'))
        m2 = numpy.array(sim_table.get_column('mass2'))
        mchirp, eta = pycbc.pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
        chirp_distance = pycbc.pnutils.chirp_distance(eff_distance, mchirp)
        idx = numpy.where(chirp_distance > args.max_effective_chirp_distance)[0]
        remove(sim_table, idx)
    logging.info('We now have %s injections', len(sim_table))

# ---------------------------------------------------------------------------
# Write the surviving injections
# ---------------------------------------------------------------------------
# Only the sim_inspiral table is carried over into the output document.
outdoc = ligolw.Document()
outdoc.appendChild(ligolw.LIGO_LW()).appendChild(sim_table)
ligolw_utils.write_filename(outdoc, args.output_file)
logging.info('Done')
