#!/usr/bin/env python
""" Ingest and process DQ data for use in the ranking statistic. 
    Any data that can be read by PyCBC as a timeseries can 
    be read in as a DQ timeseries. 
    This DQ timeseries is then processed to 
    lower the sample rate, either by maximizing
    or calculating the BLRMS over each timestep.
"""
import logging, argparse, numpy, copy
import pycbc, pycbc.strain
from pycbc.version import git_verbose_msg as version
from pycbc.workflow import resolve_td_option
from pycbc.types.timeseries import TimeSeries
from ligo.segments import segment

# Build the command line interface. The module docstring doubles as the
# program description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version', version=version)
parser.add_argument('--verbose', action="store_true")
parser.add_argument("--dq-sample-rate", default=1, type=int,
                    help="Sample rate to calculate dq timeseries, "
                         "in Hz [default=1Hz]")
parser.add_argument("--dq-sample-method", required=True,
                    choices=['max', 'blrms'],
                    help="Method used to calculate dq timeseries. "
                         "Options are 'max' or 'blrms'")
parser.add_argument("--gps-start-time", type=int, required=True,
                    help="The gps start time of the data "
                          "(integer seconds)")
parser.add_argument("--gps-end-time", type=int, required=True,
                    help="The gps end time of the data "
                         "(integer seconds)")
parser.add_argument("--output-file", required=True,
                    help="name of output file")
parser.add_argument("--output-channel", required=False,
                    help="name of output channel")

# Add the standard PyCBC strain-reading options. gps_times=False is passed
# because this script declares its own --gps-start-time / --gps-end-time
# options above.
pycbc.strain.insert_strain_option_group(parser, gps_times=False)

# Default pad_data to 0 (presumably the default of the --pad-data option
# registered by the strain option group -- TODO confirm).
parser.set_defaults(pad_data=0)

args = parser.parse_args()

# Reject values that can only fail later with an obscure error:
# a non-positive output rate leads to a division by zero or a negative
# step size, and an empty or reversed GPS interval contains no data.
# parser.error() prints the usage message and exits with status 2.
if args.dq_sample_rate <= 0:
    parser.error("--dq-sample-rate must be a positive integer")
if args.gps_end_time <= args.gps_start_time:
    parser.error("--gps-end-time must be later than --gps-start-time")

pycbc.init_logging(args.verbose)

# Load the DQ timeseries for the requested GPS interval.
# `seg` is kept as a (start, end) tuple because later code reads it.
seg = (args.gps_start_time, args.gps_end_time)
gps_start, gps_end = seg

logging.info('Getting dq values for %.1f-%.1f (%.1f s)',
             gps_start, gps_end, abs(gps_end - gps_start))

# Work on a deep copy so the originally parsed arguments stay untouched.
load_opts = copy.deepcopy(args)
load_opts.gps_start_time = int(gps_start)
load_opts.gps_end_time = int(gps_end)

# The channel name is passed through resolve_td_option together with the
# segment being read (presumably to pick the value valid for that time
# span -- TODO confirm against pycbc.workflow).
span = segment([load_opts.gps_start_time, load_opts.gps_end_time])
load_opts.channel_name = resolve_td_option(args.channel_name, span)

dq_data = pycbc.strain.from_cli(load_opts)

# Reduce the loaded DQ timeseries to the requested output sample rate and
# save the result. Each output sample is either the maximum ('max') or the
# root-mean-square ('blrms') of the input samples that fall in its bin.
logging.info('Normalizing dq values for %.1f-%.1f (%.1f s)', seg[0],
             seg[1], abs(seg[1]-seg[0]))

dq_sample_rate = float(args.dq_sample_rate)

# An output rate above the input rate would give bins holding no samples
# (and a zero or negative rate has no meaning), so fail with a clear message
# instead of an "empty sequence" error or silently written NaN values.
if dq_sample_rate <= 0 or dq_sample_rate > dq_data.sample_rate:
    raise ValueError("--dq-sample-rate (%s Hz) must be positive and no "
                     "higher than the sample rate of the input data (%s Hz)"
                     % (dq_sample_rate, dq_data.sample_rate))

# Number of input samples per output sample. This may be non-integer when
# the input rate is not a multiple of the output rate, in which case the
# bins alternate between neighbouring integer lengths.
dq_step_size = dq_data.sample_rate/dq_sample_rate

dq_method = args.dq_sample_method

# Fetch the underlying array once instead of once per output sample.
samples = dq_data.numpy()

# Output sample n covers input indices [int(n*step), int((n+1)*step)),
# clipped to the end of the data (a final partial bin is kept).
bin_index = numpy.arange(0, len(samples) / dq_step_size)
bin_starts = (bin_index * dq_step_size).astype(int)
bin_stops = numpy.minimum(((bin_index + 1) * dq_step_size).astype(int),
                          len(samples))

# Floating-point rounding in len/step can create a spurious final bin that
# starts at or beyond the end of the data; drop any such empty bin.
nonempty = bin_starts < bin_stops
bin_starts = bin_starts[nonempty]
bin_stops = bin_stops[nonempty]

if dq_method == "max":
    # Maximize dq_data over each bin. numpy.max is used rather than the
    # builtin max: it is much faster on arrays and always propagates NaN,
    # whereas the builtin's result depends on where the NaN sits in the bin.
    def reduce_bin(chunk):
        return numpy.max(chunk)
elif dq_method == "blrms":
    # Root-mean-square of dq_data over each bin.
    def reduce_bin(chunk):
        return numpy.mean(chunk**2)**0.5

val_dq = numpy.array([reduce_bin(samples[lo:hi])
                      for lo, hi in zip(bin_starts, bin_stops)])

# The output keeps the start time of the input data and uses the new,
# lower sample rate.
dq = TimeSeries(val_dq, epoch=dq_data.start_time,
                delta_t=1./dq_sample_rate)

dq.save(args.output_file, group=args.output_channel)

logging.info('Done!')

