#!/bin/env python

# Copyright (C) 2021 Francesco Pannarale
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Temporary script to produce qscans of loudest triggers/missed injections
"""

# =============================================================================
# Preamble
# =============================================================================
import os
import argparse
import logging
import h5py
import numpy as np
from pycbc import init_logging
import pycbc.workflow as wf
from pycbc.workflow.core import resolve_url_to_file
import pycbc.workflow.minifollowups as mini
import pycbc.version
import pycbc.events
from pycbc.results import layout
from pycbc.results import pygrb_postprocessing_utils as ppu
from pycbc.workflow.plotting import PlotExecutable

# Script metadata; version string and date are read from pycbc.version
__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_minifollowups"


# =============================================================================
# Functions
# =============================================================================
def add_wiki_row(outfile, cols):
    """
    Appends one wiki-formatted table row to an output file.

    Every entry of cols (a list or a numpy array) is converted to a string
    and the entries are written on a single line, delimited by '||'.
    """
    separator = '||'
    with open(outfile, 'a') as wiki:
        cells = [str(entry) for entry in cols]
        wiki.write(separator + separator.join(cells) + separator + '\n')


def make_timeseries_plot(workflow, trig_file, snr_type, central_time,
                         shift_time, out_dir, ifo=None, tags=None):
    """Sets up a PyGRB SNR timeseries plotting job in the workflow.

    The job options are read from the [workflow] section of the workflow
    configuration.  The node is added to the workflow and its output files
    are returned.
    """

    if tags is None:
        tags = []
    single_ifo = ifo is not None

    # Tags identifying this job: GRB name, SNR type and, if given, the ifo
    grb_name = workflow.cp.get('workflow', 'trigger-name')
    job_tags = ['GRB' + grb_name, snr_type]
    if single_ifo:
        job_tags.append(ifo)

    # Create the job node
    plot_exe = PlotExecutable(workflow.cp, 'pygrb_plot_snr_timeseries',
                              ifos=workflow.ifos, out_dir=out_dir,
                              tags=tags + job_tags)
    node = plot_exe.create_node()
    node.add_input_opt('--trig-file', trig_file)

    # Veto files first, then segment files
    node.add_input_list_opt('--veto-files',
                            ppu.build_veto_filelist(workflow))
    node.add_input_list_opt('--seg-files',
                            ppu.build_segment_filelist(workflow))

    # Tuning values copied verbatim from the [workflow] section
    shared_opts = ('chisq-index', 'chisq-nhigh', 'null-snr-threshold',
                   'veto-category', 'snr-threshold', 'newsnr-threshold',
                   'sngl-snr-threshold', 'null-grad-thresh', 'null-grad-val')
    for name in shared_opts:
        node.add_opt('--' + name, workflow.cp.get('workflow', name))

    node.new_output_file_opt(workflow.analysis_time, '.png',
                             '--output-file', tags=job_tags)

    # Quantity shown on the y-axis, restricted to one ifo when requested
    node.add_opt('--y-variable', snr_type)
    if single_ifo:
        node.add_opt('--ifo', ifo)

    # Ten-second wide x-axis window around the shifted time, expressed
    # relative to the central time.  It is passed as one '--x-lims=lo,hi'
    # token: the '=' prevents errors with negative times.
    offset = shift_time - central_time
    window = ','.join([str(-5. + offset), str(offset + 5.)])
    node.add_opt('--x-lims=' + window)

    # Plot title: labelled by the ifo if given, else by the SNR type
    label = ifo if single_ifo else snr_type.capitalize()
    title = "'%s SNR at %.3f (s)'" % (label, central_time)
    node.add_opt('--central-time', central_time)
    node.add_opt('--plot-title', title)

    # Register the node with the workflow
    workflow += node

    return node.output_files


# =============================================================================
# Main script starts here
# =============================================================================
# Command line interface: script-specific options first, followed by the
# generic workflow options provided by pycbc.workflow
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("-v", "--verbose", action="store_true", default=False,
                    help="Verbose output")
parser.add_argument('--trig-file',
                    help="xml file containing the triggers found by PyGRB")
parser.add_argument('--followups-file',
                    help="HDF format file containing trigger/injections "
                         "to follow up")
parser.add_argument('--wiki-file',
                    help="Name of file to save wiki-formatted table in")
parser.add_argument('--veto-files', action="store", nargs='+', default=None,
                    help="The location of the CATX veto files provided as "
                         "a list of space-separated values.")
parser.add_argument("-a", "--seg-files", action="store", nargs="+",
                    default=None,
                    help="The location of the buffer, onsource and "
                         "offsource segment files.")
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
args = parser.parse_args()

init_logging(args.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

workflow = wf.Workflow(args)

wf.makedir(args.output_dir)

# List collecting the groups of output files to arrange on the results page
layouts = []

# Path of the wiki table. It stays None when --wiki-file is not given, so
# that the event loop below can test it without raising a NameError.
wiki_file = None
if args.wiki_file:
    wiki_file = os.path.join(args.output_dir, args.wiki_file)

# Read the file with the triggers/injections to follow up. The context
# manager guarantees the HDF file is closed once all events are processed.
logging.info('Reading list of triggers/injections to followup')
with h5py.File(args.followups_file, "r") as fp:
    column_names = list(fp.keys())

    # Initialize the wiki table with the column headers
    if wiki_file:
        add_wiki_row(wiki_file, column_names)

    # Establish the number of follow-ups to perform: the configured number,
    # capped by the number of events available in the file
    num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups',
                                              'num-events', ''))
    num_events = min(num_events, len(fp['BestNR'][:]))

    # Determine ifos used in the analysis
    trig_file = resolve_url_to_file(os.path.abspath(args.trig_file))
    ifos = ppu.extract_ifos(os.path.abspath(args.trig_file))
    num_ifos = len(ifos)

    # (Loudest) off/on-source events are on time-slid data, so their
    # follow-up file contains per-ifo time shift columns. The absence of
    # such a column is taken to mean that injections are being followed up.
    is_injection_followup = (
        num_ifos == 0 or ifos[0] + ' time shift (s)' not in fp)

    # Loop over triggers/injections to be followed up
    for num_event in range(num_events):
        files = wf.FileList([])
        logging.info('Processing event: %s', num_event)
        gps_time = fp['GPS time'][num_event]
        gps_time = gps_time.astype(float)
        if wiki_file:
            # One table row per event, holding the value of every column
            row = [fp[key][num_event] for key in column_names]
            add_wiki_row(wiki_file, row)
        # Handle off/on-source (loudest) triggers follow-up
        # (which are on slid data)
        if not is_injection_followup:
            for ifo in ifos:
                # Undo the slide to recover the time in this ifo's data
                time_shift = fp[ifo+' time shift (s)'][num_event]
                ifo_time = gps_time - time_shift
                files += make_timeseries_plot(
                    workflow, trig_file, 'single', gps_time, ifo_time,
                    args.output_dir, ifo=ifo,
                    tags=args.tags + [str(num_event)])
                files += mini.make_qscan_plot(
                    workflow, ifo, ifo_time, args.output_dir,
                    tags=args.tags + [str(num_event)])
        # Handle injections (which are on unslid data)
        else:
            for snr_type in ['reweighted', 'coherent']:
                files += make_timeseries_plot(
                    workflow, trig_file, snr_type, gps_time, gps_time,
                    args.output_dir, ifo=None,
                    tags=args.tags + [str(num_event)])
            for ifo in ifos:
                files += mini.make_qscan_plot(
                    workflow, ifo, gps_time, args.output_dir,
                    tags=args.tags + [str(num_event)])

        # Arrange the plots of this event two per row
        layouts += list(layout.grouper(files, 2))

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
