#!/usr/bin/env python

"""
Splits injections created using pycbc_create_injections into smaller sets.
Split sets are organized to maximize time between injections.
"""

import argparse
from pycbc.inject import InjectionSet
import h5py
import numpy as np
import pycbc.version


# Parse command line
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
# nargs='+' (rather than '*') so that at least one output file must be given:
# the number of output files is used as the number of splits, and an empty
# list would otherwise make that number zero.
parser.add_argument("-f", "--output-files", nargs='+', required=True,
                    help="Names of output files")
parser.add_argument("-i", "--input-file", required=True,
                    help="Injection file to be split")

args = parser.parse_args()

# InjectionSet.write() requires static params as a dictionary, so collect
# them from the attributes of the input hdf file. The file is opened in a
# `with` block so the handle is closed as soon as the attributes are read.
# The 'static_args' attribute itself is skipped (presumably it only lists the
# names of the static parameters -- confirm against InjectionSet.write).
with h5py.File(args.input_file, 'r') as inj_file:
    static_params = {key: value for key, value in inj_file.attrs.items()
                     if key != 'static_args'}

# Read the same input file as an InjectionSet object
inj_set = InjectionSet(args.input_file)

# Define table of injection info
inj_table = inj_set.table

# Also get the names of variable params as write_args: every field of the
# table that is not one of the static params collected above.
write_args = [arg for arg in inj_table.fieldnames
              if arg not in static_params]

# One split is produced per requested output file
num_injs = len(inj_table)
num_splits = len(args.output_files)

# Ideally, the number of injections is divisible by the number of splits
# with no remainder, but that is not always true.
ideal_split, remainder = divmod(num_injs, num_splits)

# Start every split at the ideal size, then hand one leftover injection to
# each of the first `remainder` splits (a no-op when remainder is zero).
injs_per_split = np.full(num_splits, ideal_split, dtype=int)
injs_per_split[:remainder] += 1

# Sanity check: did we account for all injections?
assert injs_per_split.sum() == num_injs, "Not all injections were accounted for!"

# Sort injections by time; sort() is called for its in-place effect on the
# table, ordering the rows by the 'tc' field.
inj_table.sort(order='tc')

# Write one output file per split. Split number `offset` takes rows
# offset, offset + num_splits, offset + 2*num_splits, ... of the sorted
# table, so the injections within a file are spread out in time rather
# than being neighbours.
for offset, (out_name, count) in enumerate(zip(args.output_files,
                                               injs_per_split)):
    # Exactly `count` row indices, starting at `offset`, in steps of
    # num_splits
    rows = list(range(offset, offset + num_splits * count, num_splits))
    InjectionSet.write(out_name, inj_table[rows], write_args, static_params)
