#!/usr/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events """

import os, argparse, numpy, logging, h5py, copy
import pycbc.workflow as wf
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer, coinc
from pycbc.results import layout
from pycbc.detector import Detector
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version, pycbc.pnutils
from pycbc.io.hdf import SingleDetTriggers


def to_file(path, ifo=None):
    """Wrap a path on disk as a workflow ``File`` object.

    Parameters
    ----------
    path : str
        Location of the file. It is converted to an absolute path before
        being registered.
    ifo : str, optional
        Value stored on the ``ifo`` attribute of the returned object.

    Returns
    -------
    wdax.File
        Object named after the basename of ``path``, with the absolute path
        registered through ``PFN(..., 'local')``.
    """
    workflow_file = wdax.File(os.path.basename(path))
    workflow_file.ifo = ifo
    workflow_file.PFN(os.path.abspath(path), 'local')
    return workflow_file

# Command line interface. ``__doc__[1:]`` drops the first character of the
# module docstring before it is used as the program description.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--workflow-name', default='my_unamed_run')
parser.add_argument("-d", "--output-dir", default=None,
                    help="Path to output directory.")
parser.add_argument('--bank-file',
                    help="HDF format template bank file")
parser.add_argument('--injection-file',
                    help="HDF format injection results file")
parser.add_argument('--injection-xml-file',
                    help="XML format injection file")
parser.add_argument('--single-detector-triggers', nargs='+',
                    action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--inspiral-segments',
                    help="xml segment files containing the inspiral analysis times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")
# Parsed as a float: the default is fractional, and with type=int a value
# such as "0.5" could not be given on the command line at all.
parser.add_argument('--inj-window', type=float, default=0.5,
                    help="Time window in which to look for injection triggers")
parser.add_argument('--ifar-threshold', type=float, default=None,
                    help="If given also followup injections with ifar smaller "
                         "than this threshold.")
parser.add_argument('--maximum-decisive-snr', type=float, default=None,
                    help="If given, only followup injections where the "
                         "decisive SNR is smaller than this value.")
parser.add_argument('--output-map')
parser.add_argument('--output-file')
parser.add_argument('--transformation-catalog')
parser.add_argument('--tags', nargs='+', default=[])
# Let the workflow module register its own options on the same parser
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s:%(levelname)s : %(message)s')

workflow = wf.Workflow(args, args.workflow_name)

wf.makedir(args.output_dir)

# Rows handed to the page layout at the end of the script; each entry is a
# tuple of output files.
layouts = []

# Wrap the input paths given on the command line as workflow files
tmpltbank_file = to_file(args.bank_file)
injection_file = to_file(args.injection_file)
injection_xml_file = to_file(args.injection_xml_file)
insp_segs = to_file(args.inspiral_segments)

# One trigger file and two segment lists (data read / data analysed) per ifo
single_triggers = []
insp_data_seglists = {}
insp_analysed_seglists = {}
for ifo in args.single_detector_triggers:
    single_triggers.append(
        to_file(args.single_detector_triggers[ifo], ifo=ifo))

    data_read_segs = select_segments_by_definer(
        args.inspiral_segments,
        segment_name=args.inspiral_data_read_name,
        ifo=ifo)
    data_analysed_segs = select_segments_by_definer(
        args.inspiral_segments,
        segment_name=args.inspiral_data_analyzed_name,
        ifo=ifo)

    # NOTE: make_singles_timefreq needs a coalesced set of segments. If this is
    #       being used to determine command-line options for other codes,
    #       please think if that code requires coalesced, or not, segments.
    data_read_segs.coalesce()
    data_analysed_segs.coalesce()

    insp_data_seglists[ifo] = data_read_segs
    insp_analysed_seglists[ifo] = data_analysed_segs

# Open the injection results; the handle ``f`` is also read by the followup
# loop further down the script.
f = h5py.File(args.injection_file, 'r')

# Start from the injections missed after vetoes and, if requested, add the
# found injections whose IFAR is below the threshold.
missed = f['missed/after_vetoes'][:]
if args.ifar_threshold is not None:
    try:  # injections may not have (inclusive) IFAR present
        ifars = f['found_after_vetoes']['ifar'][:]
    except KeyError:
        ifars = f['found_after_vetoes']['ifar_exc'][:]
        # logging.warn is a deprecated alias of logging.warning
        logging.warning('Inclusive IFAR not found, using exclusive')
    lgc_arr = ifars < args.ifar_threshold
    missed = numpy.append(missed,
                          f['found_after_vetoes']['injection_index'][lgc_arr])

num_events = int(workflow.cp.get_opt_tags('workflow-injection_minifollowups',
                                          'num-events', ''))

# Build ``sorting``: indices into ``missed`` giving the followup order.
try:
    if 'optimal_snr_1' in f['injections'].keys():  # 2-ifo pipeline case
        optimal_snr = [f['injections/optimal_snr_1'][:][missed],
                       f['injections/optimal_snr_2'][:][missed]]
        # 'decisive' opt SNR is 2nd largest
        dec_snr = numpy.minimum(optimal_snr[0], optimal_snr[1])
    else:
        # Do not reuse the name ``os`` here: it would shadow the os module
        optstrings = [name for name in f['injections'].keys()
                      if name.startswith('optimal_snr_')]
        if not optstrings:
            # Without this, no KeyError is ever raised when the optimal SNR
            # datasets are absent, so the fallback below would never run
            # and ``sorting`` would silently end up empty.
            raise KeyError('optimal_snr_*')
        optimal_snr = [f['injections/%s' % name][:][missed]
                       for name in optstrings]
        # 2nd largest opt SNR
        dec_snr = numpy.array([sorted(snrs)[-2]
                               for snrs in zip(*optimal_snr)])

    if args.maximum_decisive_snr is not None:
        # Setting these to 0 moves them to the end of the descending sort
        # below, behind every injection under the limit.
        dec_snr[dec_snr > args.maximum_decisive_snr] = 0
    sorting = dec_snr.argsort()
    sorting = sorting[::-1]  # descending order of dec opt SNR
except KeyError:
    # Fall back to effective distance if optimal SNR not available
    eff_dist = {'H1': f['injections/eff_dist_h'][:][missed],
                'L1': f['injections/eff_dist_l'][:][missed],
                'V1': f['injections/eff_dist_v'][:][missed]
                }
    # Only the first two entries of single_triggers are used here
    dec_dist = numpy.maximum(eff_dist[single_triggers[0].ifo],
                             eff_dist[single_triggers[1].ifo])
    # The second return value (eta) is not needed
    mchirp, _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(
        f['injections/mass1'][:][missed],
        f['injections/mass2'][:][missed])
    dec_chirp_dist = pycbc.pnutils.chirp_distance(dec_dist, mchirp)
    sorting = dec_chirp_dist.argsort()  # ascending order of dec chirp distance

# Never try to follow up more injections than were selected
num_events = min(num_events, len(missed))

# Loop over the injections selected for followup, in the order given by
# ``sorting``.
found_inj_idxes = f['found_after_vetoes/injection_index'][:]
# Apply the ordering once instead of re-indexing ``missed`` on every pass
ordered_missed = missed[sorting]

# (tag, description) pairs for the plots that use the exact injection
# parameters as the template; they differ only in these two values.
inj_template_plots = [
    ('INJ_PARAMS',
     'injection parameters as template, '
     'here the injection is made as normal'),
    ('INJ_PARAMS_INVERTED',
     'injection parameters as template, '
     'here the injection is made inverted'),
    ('INJ_PARAMS_NOINJ',
     'injection parameters, here no '
     'injection was actually performed'),
]

for num_event in range(num_events):
    files = wf.FileList([])
    event_tag = str(num_event)

    injection_index = ordered_missed[num_event]
    time = f['injections/end_time'][injection_index]

    # Collect the injection parameters and the per-ifo end times
    ifo_times = ''
    inj_params = {}
    for val in ['mass1', 'mass2', 'spin1z', 'spin2z', 'end_time']:
        inj_params[val] = f['injections/%s' % (val,)][injection_index]
    for single in single_triggers:
        ifo = single.ifo
        for seg in insp_data_seglists[ifo]:
            if time in seg:
                det = Detector(ifo)
                lon = f['injections/longitude'][injection_index]
                lat = f['injections/latitude'][injection_index]
                ifo_time = time + det.time_delay_from_earth_center(lon, lat,
                                                                   time)
                break
        else:
            # The injection time is in none of this ifo's data segments;
            # -1.0 marks the end time as unavailable (checked below).
            ifo_time = -1.0

        ifo_times += ' %s:%s ' % (ifo, ifo_time)
        inj_params[ifo + '_end_time'] = ifo_time

    layouts += [(mini.make_inj_info(workflow, injection_file, injection_index,
                                    num_event, args.output_dir,
                                    tags=args.tags + [event_tag])[0],)]
    # Injections that were found also get a coincident-trigger info table
    if injection_index in found_inj_idxes:
        trig_id = numpy.where(found_inj_idxes == injection_index)[0][0]
        layouts += [(mini.make_coinc_info
                     (workflow, single_triggers, tmpltbank_file,
                      injection_file, args.output_dir, trig_id=trig_id,
                      file_substring='found_after_vetoes',
                      tags=args.tags + [event_tag])[0],)]
    files += mini.make_trigger_timeseries(workflow, single_triggers,
                                          ifo_times, args.output_dir,
                                          tags=args.tags + [event_tag])

    for single in single_triggers:
        checkedtime = time
        if inj_params[single.ifo + '_end_time'] == -1.0:
            # No end time for this ifo: use the value that
            # coinc.mean_if_greater_than_zero returns for all ifos' end
            # times (presumably the mean of the valid ones -- confirm).
            all_times = [inj_params[sngl.ifo + '_end_time']
                         for sngl in single_triggers]
            checkedtime = coinc.mean_if_greater_than_zero(all_times)[0]
        for seg in insp_analysed_seglists[single.ifo]:
            if checkedtime in seg:
                files += mini.make_singles_timefreq(
                    workflow, single, tmpltbank_file,
                    checkedtime, args.output_dir,
                    data_segments=insp_data_seglists[single.ifo],
                    tags=args.tags + [event_tag])
                files += mini.make_qscan_plot(
                    workflow, single.ifo, checkedtime, args.output_dir,
                    data_segments=insp_data_seglists[single.ifo],
                    injection_file=injection_xml_file,
                    tags=args.tags + [event_tag])
                break
        else:
            logging.info('Trigger time {} is not valid in {}, '
                         'skipping singles plots'.format(checkedtime,
                                                         single.ifo))

    # Plots made with the exact injection parameters as the template
    for plot_tag, plot_desc in inj_template_plots:
        files += mini.make_single_template_plots(
            workflow, insp_segs,
            args.inspiral_data_read_name,
            args.inspiral_data_analyzed_name, inj_params,
            args.output_dir, inj_file=injection_xml_file,
            tags=args.tags + [plot_tag, event_tag],
            params_str=plot_desc,
            use_exact_inj_params=True)

    # Plots made with the loudest template near the injection in each ifo
    for curr_ifo in args.single_detector_triggers:
        single_fname = args.single_detector_triggers[curr_ifo]
        hd_sngl = SingleDetTriggers(single_fname, args.bank_file, None, None,
                                    None, curr_ifo)
        end_times = hd_sngl.end_time
        # Use SNR here or NewSNR??
        snr = hd_sngl.snr
        lgc_mask = abs(end_times - inj_params['end_time']) < args.inj_window

        if len(snr[lgc_mask]) == 0:
            # No trigger inside the window in this ifo
            continue

        # Index, in the full trigger list, of the highest-SNR trigger that
        # lies inside the window
        snr_idx = numpy.arange(len(lgc_mask))[lgc_mask][snr[lgc_mask].argmax()]
        hd_sngl.mask = [snr_idx]
        curr_params = copy.deepcopy(inj_params)
        curr_params['mass1'] = hd_sngl.mass1[0]
        curr_params['mass2'] = hd_sngl.mass2[0]
        curr_params['spin1z'] = hd_sngl.spin1z[0]
        curr_params['spin2z'] = hd_sngl.spin2z[0]
        curr_params['f_lower'] = hd_sngl.f_lower[0]
        # don't require precessing template info if not present
        try:
            curr_params['spin1x'] = hd_sngl.spin1x[0]
            curr_params['spin2x'] = hd_sngl.spin2x[0]
            curr_params['spin1y'] = hd_sngl.spin1y[0]
            curr_params['spin2y'] = hd_sngl.spin2y[0]
            curr_params['inclination'] = hd_sngl.inclination[0]
        except KeyError:
            pass
        try:
            # Only present for precessing search
            curr_params['u_vals'] = hd_sngl.u_vals[0]
        except Exception:
            # Still best-effort, but no longer swallows KeyboardInterrupt
            # or SystemExit as the bare ``except:`` did.
            pass

        curr_tags = ['TMPLT_PARAMS_%s' % (curr_ifo,), event_tag]
        files += mini.make_single_template_plots(
            workflow, insp_segs,
            args.inspiral_data_read_name,
            args.inspiral_data_analyzed_name, curr_params,
            args.output_dir, inj_file=injection_xml_file,
            tags=args.tags + curr_tags,
            params_str='loudest template in %s' % curr_ifo)

    # Arrange this injection's plots in rows of two
    layouts += list(layout.grouper(files, 2))

# Nothing below reads the injection file, so release the handle
f.close()

workflow.save(filename=args.output_file, output_map_path=args.output_map,
              transformation_catalog_path=args.transformation_catalog)
layout.two_column_layout(args.output_dir, layouts)
