#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz, Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os, argparse, logging
import numpy
from ligo.lw import table
from ligo.lw import utils as ligolw_utils
from pycbc.results import layout
from pycbc.events import select_segments_by_definer
import pycbc.workflow.minifollowups as mini
import pycbc.version
import pycbc.workflow as wf
import pycbc.events
from pycbc.workflow.core import resolve_url_to_file
from pycbc.events import stat
from pycbc.io import hdf

def add_wiki_row(outfile, cols):
    """
    Append a single wiki-table row to ``outfile``.

    Every entry of ``cols`` (a list, numpy array or any other iterable) is
    converted with ``str``. The cells are separated by ``||``, wrapped in a
    leading and trailing ``||`` and terminated by a newline. The file is
    opened in append mode, so any existing content is preserved.
    """
    with open(outfile, 'a') as handle:
        cells = [str(entry) for entry in cols]
        handle.write('||' + '||'.join(cells) + '||\n')

# Build the command line interface. The module docstring (minus its leading
# space) doubles as the program description.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument(
    '--version',
    action='version',
    version=pycbc.version.git_verbose_msg,
)

# Input trigger/bank files and the detector being followed up
parser.add_argument(
    '--bank-file',
    help="HDF format template bank file",
)
parser.add_argument(
    '--single-detector-file',
    help="HDF format merged single detector trigger files",
)
parser.add_argument(
    '--instrument',
    help="Name of interferometer e.g. H1",
)

# Optional vetoes
parser.add_argument(
    '--veto-file',
    help="The veto file to be used if vetoing triggers (optional).",
)
parser.add_argument(
    '--veto-segment-name',
    help="If using veto file must also provide the name of the "
         "segment to use as a veto.",
)

# Inspiral analysis segments
parser.add_argument(
    '--inspiral-segments',
    help="xml segment file containing the inspiral analysis times",
)
parser.add_argument(
    '--inspiral-data-read-name',
    help="Name of inspiral segmentlist containing data read "
         "in by each analysis job.",
)
parser.add_argument(
    '--inspiral-data-analyzed-name',
    help="Name of inspiral segmentlist containing data analyzed "
         "by each analysis job.",
)

# Trigger selection
parser.add_argument(
    '--min-snr',
    type=float,
    default=6.5,
    help="Minimum SNR to consider for loudest triggers",
)
parser.add_argument(
    '--non-coinc-time-only',
    action='store_true',
    default=False,
    help="If given remove (veto) single-detector triggers that "
         "occur during a time when at least one other instrument "
         "is taking science data.",
)
parser.add_argument(
    '--minimum-duration',
    type=float,
    default=None,
    help="If given only consider single-detector triggers with "
         "template duration larger than this.",
)
parser.add_argument(
    '--maximum-duration',
    type=float,
    default=None,
    help="If given only consider single-detector triggers with "
         "template duration smaller than this.",
)

# Option groups supplied by the workflow and ranking-statistic modules
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
stat.insert_statistic_option_group(
    parser,
    default_ranking_statistic='single_ranking_only',
)

args = parser.parse_args()

logging.basicConfig(
    format='%(asctime)s:%(levelname)s : %(message)s',
    level=logging.INFO,
)

# The workflow is restricted to the single requested instrument
workflow = wf.Workflow(args)
workflow.ifos = [args.instrument]
workflow.ifo_string = args.instrument

wf.makedir(args.output_dir)

# Collects the entries passed to layout.two_column_layout at the end
layouts = []

# Turn the input paths into workflow file objects
tmpltbank_file = resolve_url_to_file(os.path.abspath(args.bank_file))
sngl_file = resolve_url_to_file(os.path.abspath(args.single_detector_file),
                                attrs={'ifos': args.instrument})

veto_file = None
if args.veto_file is not None:
    veto_file = resolve_url_to_file(os.path.abspath(args.veto_file),
                                    attrs={'ifos': args.instrument})

insp_segs = resolve_url_to_file(os.path.abspath(args.inspiral_segments))

# Coalesced segments of data read by the inspiral jobs for this instrument
insp_data_seglists = select_segments_by_definer(
    args.inspiral_segments,
    segment_name=args.inspiral_data_read_name,
    ifo=args.instrument,
)
insp_data_seglists.coalesce()

# Number of events to follow up, as set in the workflow configuration
num_events_opt = workflow.cp.get_opt_tags('workflow-sngl_minifollowups',
                                          'num-sngl-events', '')
num_events = int(num_events_opt)

# Pre-select triggers above the SNR threshold. Passing the resulting boolean
# mask as ``premask`` lets the trigger loader ignore a large fraction of the
# triggers, which speeds up the processing.
snr_mask = None
if args.min_snr:
    logging.info('Calculating Prefilter')
    snr_dataset = '{}/snr'.format(args.instrument)
    # Open the trigger file in a context manager so the handle is closed
    # again before the same path is handed to SingleDetTriggers below.
    with hdf.HFile(args.single_detector_file, 'r') as f:
        idx, _ = f.select(lambda snr: snr > args.min_snr,
                          snr_dataset,
                          return_indices=True)
        # One entry per trigger in the file; True where SNR passes the cut
        snr_mask = numpy.zeros(len(f[snr_dataset]), dtype=bool)
    snr_mask[idx] = True

trigs = hdf.SingleDetTriggers(args.single_detector_file, args.bank_file,
                              args.veto_file, args.veto_segment_name,
                              None, args.instrument, premask=snr_mask)

# Optionally restrict to times when no other instrument was analysing data.
if args.non_coinc_time_only:
    from pycbc.io.ligolw import LIGOLWContentHandler as h

    segs_doc = ligolw_utils.load_filename(args.inspiral_segments,
                                          contenthandler=h)
    seg_def_table = table.Table.get_table(segs_doc, 'segment_definer')

    # Unique instrument names found in the segment_definer table, with the
    # instrument being followed up taken out (it must be present).
    ifo_list = list(set(
        str(ifo) for ifo in seg_def_table.getColumnByName('ifos')
    ))
    ifo_list.remove(args.instrument)

    for ifo in ifo_list:
        # Indices of triggers whose end time is outside this instrument's
        # analysed segments; the second return value is not needed here.
        curr_veto_mask, _ = pycbc.events.veto.indices_outside_segments(
            trigs.end_time,
            [args.inspiral_segments],
            ifo=ifo,
            segment_name=args.inspiral_data_analyzed_name,
        )
        curr_veto_mask.sort()
        trigs.apply_mask(curr_veto_mask)

# Optional cuts on template duration. Both comparisons are strict, so a
# trigger whose duration equals a bound is removed.
if args.minimum_duration is not None:
    logging.info('applying minimum duration')
    trigs.apply_mask(trigs.template_duration > args.minimum_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

if args.maximum_duration is not None:
    logging.info('applying maximum duration')
    trigs.apply_mask(trigs.template_duration < args.maximum_duration)
    logging.info('remaining triggers: %s', trigs.mask.sum())

logging.info('finding loudest clustered events')
rank_method = stat.get_statistic_from_opts(args, [args.instrument])
trigs.mask_to_n_loudest_clustered_events(rank_method, n_loudest=num_events)

# Fewer events than requested may be left after masking
num_events = min(num_events, len(trigs.stat))

# End times and template ids of the selected triggers
times, tids = trigs.end_time, trigs.template_id

# Loop over the selected events from highest to lowest ranking statistic.
# ``rank`` (0 = loudest) tags every job of an event; ``num_event`` is the
# position of that event within the masked trigger arrays.
order = trigs.stat.argsort()[::-1]
for rank, num_event in enumerate(order):
    logging.info('Processing event: %s', num_event)

    rank_tag = str(rank)
    files = wf.FileList([])
    time = times[num_event]
    ifo_time = '%s:%s' % (args.instrument, str(time))
    # NOTE(review): this reads the id from trigs.mask, so the mask is
    # presumably an array of trigger indices at this point - confirm.
    tid = trigs.mask[num_event]
    ifo_tid = '%s:%s' % (args.instrument, str(tid))

    # The single-ifo summary gets a layout entry of its own
    layouts.append(mini.make_sngl_ifo(workflow, sngl_file, tmpltbank_file,
                                      tid, args.output_dir, args.instrument,
                                      tags=args.tags + [rank_tag]))
    files += mini.make_trigger_timeseries(workflow, [sngl_file],
                                          ifo_time, args.output_dir,
                                          special_tids=ifo_tid,
                                          tags=args.tags + [rank_tag])

    # Template parameters used for the single-template and waveform plots
    curr_params = {
        'mass1': trigs.mass1[num_event],
        'mass2': trigs.mass2[num_event],
        'spin1z': trigs.spin1z[num_event],
        'spin2z': trigs.spin2z[num_event],
        'f_lower': trigs.f_lower[num_event],
        args.instrument + '_end_time': time,
    }
    # don't require precessing template info if not present
    try:
        curr_params['spin1x'] = trigs.spin1x[num_event]
        curr_params['spin2x'] = trigs.spin2x[num_event]
        curr_params['spin1y'] = trigs.spin1y[num_event]
        curr_params['spin2y'] = trigs.spin2y[num_event]
        curr_params['inclination'] = trigs.inclination[num_event]
    except KeyError:
        pass
    try:
        # Only present for precessing search
        curr_params['u_vals'] = trigs.u_vals[num_event]
    except Exception:
        # Best effort: u_vals is optional. Catching Exception rather than
        # using a bare except lets KeyboardInterrupt/SystemExit propagate.
        pass

    files += mini.make_single_template_plots(
        workflow, insp_segs,
        args.inspiral_data_read_name,
        args.inspiral_data_analyzed_name, curr_params,
        args.output_dir,
        tags=args.tags + [rank_tag])

    files += mini.make_plot_waveform_plot(workflow, curr_params,
                                          args.output_dir, [args.instrument],
                                          tags=args.tags + [rank_tag])

    files += mini.make_singles_timefreq(workflow, sngl_file, tmpltbank_file,
                                        time, args.output_dir,
                                        data_segments=insp_data_seglists,
                                        tags=args.tags + [rank_tag])

    files += mini.make_qscan_plot(workflow, args.instrument, time,
                                  args.output_dir,
                                  data_segments=insp_data_seglists,
                                  tags=args.tags + [rank_tag])

    # Arrange this event's plots two per row
    layouts += list(layout.grouper(files, 2))

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
