#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz, Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os, sys, argparse, logging, h5py, datetime
import numpy
from lal import GPSToUTC
from glue.ligolw import lsctables, table
from glue.ligolw import utils as ligolw_utils
from pycbc.results import layout
from pycbc.events import select_segments_by_definer
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version
import pycbc.workflow as wf
import pycbc.events
from pycbc.io import hdf

def to_file(path, ifo=None):
    """Wrap an existing file on disk in a workflow ``File`` object.

    Parameters
    ----------
    path : str
        Location of the file. Only its basename is used as the name of the
        workflow file; the absolute form of the path is registered as the
        physical file name under the 'local' label.
    ifo : str, optional
        Value stored on the returned object's ``ifo`` attribute.

    Returns
    -------
    wdax.File
        The newly created workflow file.
    """
    name = os.path.basename(path)
    location = os.path.abspath(path)

    workflow_file = wdax.File(name)
    workflow_file.ifo = ifo
    workflow_file.PFN(location, 'local')
    return workflow_file

def add_wiki_row(outfile, cols):
    """
    Append one wiki-table row to a file.

    Every entry of ``cols`` (a list, numpy array or other iterable) is
    converted with ``str``; the entries are separated by '||' and the row is
    wrapped in a leading and trailing '||' followed by a newline. The file is
    opened in append mode, so existing content is kept.
    """
    cells = '||'.join([str(entry) for entry in cols])
    row = '||%s||\n' % cells
    with open(outfile, 'a') as wiki:
        wiki.write(row)

# Command line interface. The module docstring (minus its leading space) is
# used as the program description.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--workflow-name', default='my_unamed_run')
parser.add_argument("-d", "--output-dir", default=None,
                    help="Path to output directory.")
# The bank, trigger and segment files are passed straight to to_file() below,
# which cannot handle None, and the instrument name is used to build HDF
# dataset paths. Mark them as required so a missing option produces a usage
# error instead of a traceback further down.
parser.add_argument('--bank-file', required=True,
                    help="HDF format template bank file")
parser.add_argument('--single-detector-file', required=True,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--instrument', required=True,
                    help="Name of interferometer e.g. H1")
parser.add_argument('--veto-file',
    help="The veto file to be used if vetoing triggers (optional).")
parser.add_argument('--veto-segment-name',
    help="If using veto file must also provide the name of the segment to use "
         "as a veto.")
parser.add_argument('--inspiral-segments', required=True,
                    help="xml segment file containing the inspiral analysis "
                         "times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")
parser.add_argument('--min-snr', type=float, default=6.5,
                    help="Minimum SNR to consider for loudest triggers")
parser.add_argument('--ranking-statistic',
                help="How to rank triggers when determining loudest triggers.")
parser.add_argument('--statistic-files', nargs='+',
                    action='append', default=[],
                    help='Use this to supply supplementary files that are '
                         'necessary for computing the ranking statistic. For '
                         'example the phase-time consistency file.')
parser.add_argument('--non-coinc-time-only', default=False,
                    action='store_true',
                    help="If given remove (veto) single-detector triggers "
                         "that occur during a time when at least one other "
                         "instrument is taking science data.")
parser.add_argument('--minimum-duration', default=None, type=float,
                    help="If given only consider single-detector triggers "
                         "with template duration larger than this.")
parser.add_argument('--maximum-duration', default=None, type=float,
                    help="If given only consider single-detector triggers "
                         "with template duration smaller than this.")
parser.add_argument('--wiki-file',
                    help="Name of file to save wiki-formatted table in")
parser.add_argument('--output-map')
parser.add_argument('--transformation-catalog')
parser.add_argument('--output-file')
parser.add_argument('--tags', nargs='+', default=[])
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()
# --statistic-files combines nargs='+' with action='append', so repeated use
# of the option yields a list of lists. Flatten it to a single list of paths.
args.statistic_files = [stat_file for group in args.statistic_files
                        for stat_file in group]

logging.basicConfig(format='%(asctime)s:%(levelname)s : %(message)s',
                    level=logging.INFO)

# The workflow only ever deals with the single requested detector
workflow = wf.Workflow(args, args.workflow_name)
workflow.ifos = [args.instrument]
workflow.ifo_string = args.instrument

wf.makedir(args.output_dir)

# wiki_file is tested again inside the per-event loop, so it must be bound
# even when no wiki table was requested.
wiki_file = None
if args.wiki_file:
    # initialize a wiki table and add the column headers
    wiki_file = os.path.join(args.output_dir, args.wiki_file)
    add_wiki_row(wiki_file, ['GPS time', 'UTC time', 'newSNR', 'SNR',
                      'm1', 'm2', 's1', 's2', 'duration', 'omega', 'notes',
                      'veto'])

# plain list collecting the result-page layout entries of every followed-up
# event; it is handed to layout.two_column_layout at the end of the script
layouts = []

# Wrap the input files given on the command line as workflow files
tmpltbank_file = to_file(args.bank_file)
sngl_file = to_file(args.single_detector_file, ifo=args.instrument)
if args.veto_file is not None:
    veto_file = to_file(args.veto_file, ifo=args.instrument)
else:
    veto_file = None
insp_segs = to_file(args.inspiral_segments)

# Segments named by --inspiral-data-read-name for this instrument, merged
# in place with coalesce()
insp_data_seglists = select_segments_by_definer(
    args.inspiral_segments, segment_name=args.inspiral_data_read_name,
    ifo=args.instrument)
insp_data_seglists.coalesce()

# Number of events to follow up, read from the workflow configuration
num_events = int(workflow.cp.get_opt_tags('workflow-sngl_minifollowups',
                 'num-sngl-events', ''))

# This helps speed up the processing to ignore a large fraction of triggers:
# build a boolean mask over all triggers of this instrument that is True only
# where the SNR exceeds --min-snr, and hand it to SingleDetTriggers as premask.
snr_mask = None
if args.min_snr:
    logging.info('Calculating Prefilter')
    snr_dataset = '{}/snr'.format(args.instrument)
    # Use a context manager so the trigger file is closed again once the
    # indices and the dataset length have been read.
    with hdf.HFile(args.single_detector_file, 'r') as f:
        idx, _ = f.select(lambda snr: snr > args.min_snr,
                          snr_dataset,
                          return_indices=True)
        num_trigs = len(f[snr_dataset])
    snr_mask = numpy.zeros(num_trigs, dtype=bool)
    snr_mask[idx] = True

trigs = hdf.SingleDetTriggers(args.single_detector_file, args.bank_file,
                              args.veto_file, args.veto_segment_name,
                              None, args.instrument, premask=snr_mask)

if args.non_coinc_time_only:
    # Find every instrument named in the segment_definer table of the
    # inspiral segment file, then, for each one other than ours, keep only
    # the triggers whose end time lies outside that instrument's
    # --inspiral-data-analyzed-name segments.
    from glue.ligolw.ligolw import LIGOLWContentHandler
    lsctables.use_in(LIGOLWContentHandler)
    segment_xml = ligolw_utils.load_filename(
        args.inspiral_segments, contenthandler=LIGOLWContentHandler)
    definer_table = table.get_table(segment_xml, 'segment_definer')
    other_ifos = list({str(name)
                       for name in definer_table.getColumnByName('ifos')})
    other_ifos.remove(args.instrument)
    for other_ifo in other_ifos:
        # second return value is not needed here
        keep_idx, _ = pycbc.events.veto.indices_outside_segments(
            trigs.end_time, [args.inspiral_segments],
            ifo=other_ifo, segment_name=args.inspiral_data_analyzed_name)
        keep_idx.sort()
        trigs.apply_mask(keep_idx)

# Optional cuts on template duration; each is skipped when its bound was not
# given on the command line.
if args.minimum_duration is not None:
    logging.info('applying minimum duration')
    # keep templates strictly longer than the lower bound
    trigs.apply_mask(trigs.template_duration > args.minimum_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

if args.maximum_duration is not None:
    logging.info('applying maximum duration')
    # keep templates strictly shorter than the upper bound
    trigs.apply_mask(trigs.template_duration < args.maximum_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

if not len(trigs.snr):
    # Nothing survived the cuts: write a workflow holding only a no-op job
    # and stop here.
    workflow += mini.create_noop_node()
    workflow.save(filename=args.output_file, output_map_path=args.output_map,
                  transformation_catalog_path=args.transformation_catalog)
    sys.exit(0)

logging.info('finding loudest clustered events')
trigs.mask_to_n_loudest_clustered_events(
    n_loudest=num_events,
    ranking_statistic=args.ranking_statistic,
    statistic_files=args.statistic_files)
# Fewer events than requested may remain
num_events = min(num_events, len(trigs.stat))

# Per-event arrays of the selected triggers
times = trigs.end_time
tids = trigs.template_id

# loop over the selected events in order of decreasing ranking statistic;
# `rank` (0 = loudest) is appended to the tags of every job made for the event
order = trigs.stat.argsort()[::-1]
for rank, num_event in enumerate(order):
    logging.info('Processing event: %s', num_event)

    files = wf.FileList([])
    time = times[num_event]
    ifo_time = '%s:%s' % (args.instrument, str(time))
    # NOTE(review): the mask entry is what gets passed on as the trigger id --
    # presumably the trigger's index in the single-detector file; confirm
    # against hdf.SingleDetTriggers.
    tid = trigs.mask[num_event]
    ifo_tid = '%s:%s' % (args.instrument, str(tid))

    # make_sngl_ifo's return value is a single layout entry
    layouts.append(mini.make_sngl_ifo(workflow, sngl_file, tmpltbank_file,
                                      tid, args.output_dir, args.instrument,
                                      tags=args.tags + [str(rank)]))
    files += mini.make_trigger_timeseries(workflow, [sngl_file],
                              ifo_time, args.output_dir, special_tids=ifo_tid,
                              tags=args.tags + [str(rank)])

    # Template parameters of this event, handed to the plotting jobs below
    curr_params = {}
    curr_params['mass1'] = trigs.mass1[num_event]
    curr_params['mass2'] = trigs.mass2[num_event]
    curr_params['spin1z'] = trigs.spin1z[num_event]
    curr_params['spin2z'] = trigs.spin2z[num_event]
    curr_params['f_lower'] = trigs.f_lower[num_event]
    curr_params[args.instrument + '_end_time'] = time
    # don't require precessing template info if not present
    try:
        curr_params['spin1x'] = trigs.spin1x[num_event]
        curr_params['spin2x'] = trigs.spin2x[num_event]
        curr_params['spin1y'] = trigs.spin1y[num_event]
        curr_params['spin2y'] = trigs.spin2y[num_event]
        curr_params['inclination'] = trigs.inclination[num_event]
    except KeyError:
        pass
    try:
        # Only present for precessing search
        curr_params['u_vals'] = trigs.u_vals[num_event]
    except Exception:
        # Still best-effort, but unlike a bare except this lets
        # KeyboardInterrupt and SystemExit propagate.
        pass

    # Test the command line option rather than the wiki_file variable, which
    # is only bound when the option was given.
    if args.wiki_file:
        # The last three cells (omega, notes, veto) are left blank
        add_wiki_row(wiki_file, [time,
                          str(datetime.datetime(*GPSToUTC(int(time))[0:6])),
                          trigs.stat[num_event], trigs.snr[num_event],
                          curr_params['mass1'], curr_params['mass2'],
                          curr_params['spin1z'], curr_params['spin2z'],
                          trigs.template_duration[num_event], ' ',' ',' '])

    files += mini.make_single_template_plots(workflow, insp_segs,
                            args.inspiral_data_read_name,
                            args.inspiral_data_analyzed_name, curr_params,
                            args.output_dir,
                            tags=args.tags + [str(rank)])

    files += mini.make_plot_waveform_plot(workflow, curr_params,
                                        args.output_dir, [args.instrument],
                                        tags=args.tags + [str(rank)])

    files += mini.make_singles_timefreq(workflow, sngl_file, tmpltbank_file,
                            time, args.output_dir,
                            data_segments=insp_data_seglists,
                            tags=args.tags + [str(rank)])

    files += mini.make_qscan_plot(workflow, args.instrument, time,
                                  args.output_dir,
                                  data_segments=insp_data_seglists,
                                  tags=args.tags + [str(rank)])

    # Arrange this event's output files two per layout row
    layouts += list(layout.grouper(files, 2))

# Write out the workflow and the results-page layout
workflow.save(filename=args.output_file, output_map_path=args.output_map,
              transformation_catalog_path=args.transformation_catalog)
layout.two_column_layout(args.output_dir, layouts)
