#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os, sys, argparse, logging, re, h5py, pycbc.workflow as wf
from pycbc.results import layout
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer, coinc
from pycbc.io import get_all_subkeys
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version

def to_file(path, ifo=None):
    """ Wrap a file on disk in a pegasus workflow File object.

    Parameters
    ----------
    path : str
        Location of the file. The basename of this path is used to construct
        the File, and its absolute form is passed to ``PFN``.
    ifo : str, optional
        Value stored in the ``ifo`` attribute of the returned object.

    Returns
    -------
    wdax.File
        The new File object, with ``ifo`` set and ``PFN`` called with the
        absolute path and ``'local'``.
    """
    local_file = wdax.File(os.path.basename(path))
    local_file.ifo = ifo
    local_file.PFN(os.path.abspath(path), 'local')
    return local_file

# Command line interface.
# The module docstring (minus its leading space) doubles as the description.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--workflow-name', default='my_unamed_run')
parser.add_argument("-d", "--output-dir", default=None,
                    help="Path to output directory.")

# Input files. The bank, statmap and segment files are always converted to
# workflow File objects below, so they are marked as required: leaving one out
# is then reported by argparse instead of failing later with a TypeError.
parser.add_argument('--bank-file', required=True,
                    help="HDF format template bank file")
parser.add_argument('--statmap-file', required=True,
                    help="HDF format clustered coincident trigger result file")
parser.add_argument('--single-detector-triggers', nargs='+', action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--inspiral-segments', required=True,
                    help="xml segment files containing the inspiral analysis times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")

# Which events to follow up and how to rank them
parser.add_argument('--analysis-category', type=str, required=False,
                    default='background_exc',
                    choices = ['foreground', 'background', 'background_exc'],
                    help='Designates whether to look at foreground triggers '
                         'background triggers (including "little dogs") '
                         'or background triggers (with "little dogs" removed)')
parser.add_argument('--sort-variable', default='ifar',
                    help='Which subgroup of --analysis-category to use for '
                         'sorting. Default=ifar')
parser.add_argument('--sort-order', default='descending',
                    choices=['ascending','descending'],
                    help='Which direction to use when sorting on '
                         '--sort-variable. Default=descending')

# Outputs of this workflow generator
parser.add_argument('--output-map', type=str, required=False,
                    help='Supply the location of the pegasus output map file. '
                         'Only needed if running this job within a workflow.')
parser.add_argument('--output-file', type=str, required=True,
                    help='Name of the output dax file to be generated.')
parser.add_argument('--transformation-catalog', type=str, required=False,
                    help='Supply the location of the pegasus executable '
                         'transformation catalog file. '
                         'Only needed if running this job within a workflow.')
parser.add_argument('--tags', nargs='+', default=[])
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s:%(levelname)s : %(message)s')

workflow = wf.Workflow(args, args.workflow_name)
wf.makedir(args.output_dir)

# Collects the layout entries added for each followed-up event; handed to
# layout.two_column_layout once the workflow has been saved.
layouts = []

# Workflow File objects for the inputs that are not detector specific
tmpltbank_file = to_file(args.bank_file)
coinc_file = to_file(args.statmap_file)
insp_segs = to_file(args.inspiral_segments)

# Per-detector inputs, all keyed by ifo except for the list of File objects
single_triggers = []
fsdt = {}
insp_data_seglists = {}
insp_analysed_seglists = {}
for ifo in args.single_detector_triggers:
    fname = args.single_detector_triggers[ifo]
    single_triggers.append(to_file(fname, ifo=ifo))
    fsdt[ifo] = h5py.File(fname, 'r')

    # NOTE: make_singles_timefreq needs a coalesced set of segments. If this is
    #       being used to determine command-line options for other codes,
    #       please think if that code requires coalesced, or not, segments.
    for seglists, seg_name in (
            (insp_data_seglists, args.inspiral_data_read_name),
            (insp_analysed_seglists, args.inspiral_data_analyzed_name)):
        segs = select_segments_by_definer(args.inspiral_segments,
                                          segment_name=seg_name,
                                          ifo=ifo)
        segs.coalesce()
        seglists[ifo] = segs

# Number of events to follow up, read from the [workflow-minifollowups]
# section of the workflow configuration
num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups',
                                          'num-events', ''))

f = h5py.File(args.statmap_file, 'r')
file_val = args.analysis_category
stat = f['{}/stat'.format(file_val)][:]

# Fail early, listing the valid choices, if the sort variable does not exist
if args.sort_variable not in f[file_val]:
    all_datasets = [re.sub(file_val, '', ds).strip('/')
                    for ds in get_all_subkeys(f, file_val)]
    raise KeyError('Sort variable {0} not in {1}: sort choices in '
                   '{1} are {2}'.format(args.sort_variable, file_val,
                                        ', '.join(all_datasets)))

n_available = len(stat)
if n_available == 0:
    # There are no triggers, make no-op job and exit
    workflow += mini.create_noop_node()
    workflow.save(filename=args.output_file,
                  output_map_path=args.output_map,
                  transformation_catalog_path=args.transformation_catalog)
    sys.exit(0)

# Consider up to 100 candidates per requested event, since the followup loop
# may skip some of them; neither count can exceed what the file contains.
events_to_read = min(num_events * 100, n_available)
num_events = min(num_events, n_available)

# Indices of the candidates, ordered by the sort variable
sorting = f['{}/{}'.format(file_val, args.sort_variable)][:].argsort()
if args.sort_order == 'descending':
    sorting = sorting[::-1]
event_idx = sorting[:events_to_read]
stat = stat[event_idx]

# Trigger times and trigger ids of the selected candidates, keyed by detector
times = {}
tids = {}
if 'detector_1' in f.attrs:
    # Input files from old (2-ifo) coinc code: datasets carry a 1/2 suffix
    for attr_name, suffix in (('detector_1', '1'), ('detector_2', '2')):
        det = f.attrs[attr_name]
        times[det] = f['{}/time{}'.format(file_val, suffix)][:][event_idx]
        tids[det] = \
            f['{}/trigger_id{}'.format(file_val, suffix)][:][event_idx]
else:
    # Version used for multi-ifo coinc code: one subgroup per detector
    ifo_list = f.attrs['ifos'].split(' ')
    for ifo in ifo_list:
        group = '{}/{}'.format(file_val, ifo)
        times[ifo] = f[group + '/time'][:][event_idx]
        tids[ifo] = f[group + '/trigger_id'][:][event_idx]

bank_data = h5py.File(args.bank_file, 'r')

# Template id of every selected candidate, in the same order as the entries
# of `times` and `tids`. Read once here rather than re-reading and
# re-ordering the full dataset for each followed-up event.
template_ids = f['{}/template_id'.format(file_val)][:][event_idx]

# The statmap format does not change while looping, so test for it once
two_ifo_input = 'detector_1' in f.attrs

# Loop over the sorted candidates until `num_events` of them have been
# followed up or the candidates that were read in are exhausted.
event_times = {}    # integer trigger times already followed up, per ifo
skipped_data = []   # (ifo, time) of candidates skipped since the last followup
event_count = 0     # number of events followed up so far
curr_idx = -1       # position of the current candidate in the sorted arrays
while event_count < num_events and curr_idx < (events_to_read - 1):
    curr_idx += 1
    files = wf.FileList([])

    # Check if using input files from old (2-ifo) coinc code
    if two_ifo_input:
        key0, key1 = tuple(times.keys())[:2]
        ifo_times = '%s:%s %s:%s' % (key0, times[key0][curr_idx],
                                     key1, times[key1][curr_idx])
        key0, key1 = tuple(tids.keys())[:2]
        ifo_tids = '%s:%s %s:%s' % (key0, tids[key0][curr_idx],
                                    key1, tids[key1][curr_idx])
        # For background do not want to follow up 10 background events caused
        # by the same event in ifo 1 and different, quiet, events in ifo 2,
        if key0 not in event_times:
            event_times[key0] = []
        if int(times[key0][curr_idx]) in event_times[key0]:
            skipped_data.append((key0, int(times[key0][curr_idx])))
            continue
        if key1 not in event_times:
            event_times[key1] = []
        if int(times[key1][curr_idx]) in event_times[key1]:
            skipped_data.append((key1, int(times[key1][curr_idx])))
            continue
        event_times[key0].append(int(times[key0][curr_idx]))
        event_times[key1].append(int(times[key1][curr_idx]))

    else:
        # Version used for multi-ifo coinc code
        ifo_times_strings = []
        ifo_tids_strings = []
        duplicate = False
        for ifo in times:
            ifo_times_strings.append('%s:%s' % (ifo, times[ifo][curr_idx]))
            ifo_tids_strings.append('%s:%s' % (ifo, tids[ifo][curr_idx]))

            # For background do not want to follow up 10 background coincs with
            # the same event in ifo 1 and different events in ifo 2
            if ifo not in event_times:
                event_times[ifo] = []
            # Don't skip coincs in zerolag or due to the sentinel time -1
            if 'background' in args.analysis_category and \
                    times[ifo][curr_idx] != -1 and \
                    int(times[ifo][curr_idx]) in event_times[ifo]:
                skipped_data.append((ifo, int(times[ifo][curr_idx])))
                duplicate = True
                break

        if duplicate:
            continue
        for ifo in times:
            event_times[ifo].append(int(times[ifo][curr_idx]))

        ifo_times = ' '.join(ifo_times_strings)
        ifo_tids = ' '.join(ifo_tids_strings)

    # This candidate is going to be followed up
    event_count += 1
    event_tags = args.tags + [str(event_count)]

    # Report the candidates that were skipped on the way to this one
    if skipped_data:
        layouts.append(mini.make_skipped_html(
                           workflow,
                           skipped_data,
                           args.output_dir,
                           tags=['SKIP_{}'.format(event_count)]))
        skipped_data = []

    bank_id = template_ids[curr_idx]

    layouts.append(mini.make_coinc_info(
                       workflow, single_triggers, tmpltbank_file,
                       coinc_file, args.output_dir, n_loudest=curr_idx,
                       sort_order=args.sort_order,
                       sort_var=args.sort_variable,
                       tags=event_tags))
    files += mini.make_trigger_timeseries(workflow, single_triggers,
                             ifo_times, args.output_dir, special_tids=ifo_tids,
                             tags=event_tags)

    # Parameters passed on to the single template plots
    params = {}
    for ifo in times:
        params['%s_end_time' % ifo] = times[ifo][curr_idx]
        try:
            # Only present for precessing case
            params['u_vals_%s'%ifo] = \
                                 fsdt[ifo][ifo]['u_vals'][tids[ifo][curr_idx]]
        except (KeyError, IndexError, ValueError):
            # No trigger file / dataset for this ifo, or no valid trigger id:
            # the value is optional, so carry on without it. Anything else
            # (including KeyboardInterrupt) is no longer silently swallowed.
            pass

    params['mass1'] = bank_data['mass1'][bank_id]
    params['mass2'] = bank_data['mass2'][bank_id]
    params['spin1z'] = bank_data['spin1z'][bank_id]
    params['spin2z'] = bank_data['spin2z'][bank_id]
    params['f_lower'] = bank_data['f_lower'][bank_id]
    # don't require precessing template info if not present
    try:
        params['spin1x'] = bank_data['spin1x'][bank_id]
        params['spin1y'] = bank_data['spin1y'][bank_id]
        params['spin2x'] = bank_data['spin2x'][bank_id]
        params['spin2y'] = bank_data['spin2y'][bank_id]
        params['inclination'] = bank_data['inclination'][bank_id]
    except KeyError:
        pass

    files += mini.make_single_template_plots(workflow, insp_segs,
                                    args.inspiral_data_read_name,
                                    args.inspiral_data_analyzed_name, params,
                                    args.output_dir,
                                    tags=event_tags)

    for single in single_triggers:
        trig_time = times[single.ifo][curr_idx]
        if trig_time == -1:
            # No trigger in this detector (sentinel time): fall back on the
            # value computed from the times in all detectors
            trig_time = coinc.mean_if_greater_than_zero(
                [times[sngl.ifo][curr_idx] for sngl in single_triggers])[0]
        # Only make the singles plots if the time falls in analysed data
        for seg in insp_analysed_seglists[single.ifo]:
            if trig_time in seg:
                files += mini.make_singles_timefreq(
                    workflow, single, tmpltbank_file,
                    trig_time, args.output_dir,
                    data_segments=insp_data_seglists[single.ifo],
                    tags=event_tags)
                files += mini.make_qscan_plot(
                    workflow, single.ifo, trig_time, args.output_dir,
                    data_segments=insp_data_seglists[single.ifo],
                    tags=event_tags)
                break
        else:
            logging.info('Trigger time {} is not valid in {}, ' \
                         'skipping singles plots'.format(trig_time,
                                                         single.ifo))

    # Arrange the plots for this event in rows of two
    layouts += list(layout.grouper(files, 2))

workflow.save(filename=args.output_file, output_map_path=args.output_map,
              transformation_catalog_path=args.transformation_catalog)
layout.two_column_layout(args.output_dir, layouts)
