#!/usr/bin/env python

# Copyright (C) 2015 Andrew R. Williamson
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Make workflow for the archival, targeted, coherent inspiral pipeline.
"""

import pycbc.version

# Module metadata; the version string and date are taken from pycbc.version
__author__ = "Andrew Williamson <andrew.williamson@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_make_offline_grb_workflow"

import shutil
import sys
import os
import argparse
import logging
import pycbc.workflow as _workflow
from ligo.segments import segment, segmentlist, segmentlistdict
import matplotlib
matplotlib.use('agg')
from pycbc.results.legacy_grb import make_grb_segments_plot

# Name given to the workflow object and echoed in the log output
workflow_name = "pygrb_offline"
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s:%(levelname)s : %(message)s")

# Parse command line options (pycbc.workflow registers its own option group
# on the parser) and instantiate the pycbc workflow object
parser = argparse.ArgumentParser()
parser.add_argument("--version", action="version", version=__version__)
_workflow.add_workflow_command_line_group(parser)
args = parser.parse_args()
wflow = _workflow.Workflow(args, workflow_name)

# Bookkeeping used while the workflow is assembled
all_files = _workflow.FileList([])
tags = []
initDir = os.getcwd()

logging.info("Generating %s workflow", workflow_name)

# Setup run directory: use [workflow] output-directory when it is given,
# otherwise the current working directory
baseDir = (wflow.cp.get("workflow", "output-directory")
           if wflow.cp.has_option("workflow", "output-directory")
           else os.getcwd())
triggername = str(wflow.cp.get("workflow", "trigger-name"))
runDir = os.path.join(baseDir, "GRB" + triggername)
logging.info("Workflow will be generated in %s", runDir)
if not os.path.exists(runDir):
    os.makedirs(runDir)
# Everything below is generated relative to the run directory
os.chdir(runDir)

# SEGMENTS
# Begin with a window of +/- max-duration around the trigger time
triggertime = int(wflow.cp.get("workflow", "trigger-time"))
max_duration = int(wflow.cp.get("workflow-exttrig_segments", "max-duration"))
start = triggertime - max_duration
end = triggertime + max_duration
wflow.cp = _workflow.set_grb_start_end(wflow.cp, start, end)

# Retrieve science segments
currDir = os.getcwd()
segDir = os.path.join(currDir, "segments")
sciSegs, segsFileList = _workflow.setup_segment_generation(wflow, segDir)


# Set the "min-before" and "min-after" options of [workflow-exttrig_segments]
# from the [inspiral] configuration. Which formula is used depends on which
# of the "segment-start-pad" / "analyse-segment-end" options is present.
if wflow.cp.has_option("inspiral", "segment-start-pad"):
    pad_data = int(wflow.cp.get("inspiral", "pad-data"))
    start_pad = int(wflow.cp.get("inspiral", "segment-start-pad"))
    end_pad = int(wflow.cp.get("inspiral", "segment-end-pad"))
    wflow.cp.set("workflow-exttrig_segments", "min-before",
                 str(start_pad + pad_data))
    wflow.cp.set("workflow-exttrig_segments", "min-after",
                 str(end_pad + pad_data))
elif wflow.cp.has_option("inspiral", "analyse-segment-end"):
    safety = 1
    # Floor division keeps these values integral. Under Python 3 true
    # division would produce floats, writing strings such as "64.0" into
    # the configuration and float bounds into the analysis segment that is
    # later built from deadtime/spec_len.
    deadtime = int(wflow.cp.get("inspiral", "segment-length")) // 2
    spec_len = int(wflow.cp.get("inspiral", "inverse-spec-length")) // 2
    wflow.cp.set("workflow-exttrig_segments", "min-before",
                 str(deadtime - spec_len - safety))
    wflow.cp.set("workflow-exttrig_segments", "min-after",
                 str(spec_len + safety))
else:
    # Integer number of seconds, for the same reason as above
    deadtime = int(wflow.cp.get("inspiral", "segment-length")) // 4
    wflow.cp.set("workflow-exttrig_segments", "min-before", str(deadtime))
    wflow.cp.set("workflow-exttrig_segments", "min-after", str(deadtime))

# Do checks for no/single IFO case
single_ifo = wflow.cp.has_option("workflow", "allow-single-ifo-search")
num_sci_ifos = len(sciSegs.keys())
if num_sci_ifos == 0:
    # Nothing to analyse: make the segments plot with an empty segment
    # dictionary, then stop
    plot_met = make_grb_segments_plot(wflow, segmentlistdict(), triggertime,
                                      triggername, segDir)
    logging.warning("No science segments available.")
    sys.exit()
elif num_sci_ifos < 2 and not single_ifo:
    # A single IFO has data but single IFO running was not enabled
    plot_met = make_grb_segments_plot(wflow, segmentlistdict(sciSegs),
                                      triggertime, triggername, segDir)
    lone_ifo = tuple(sciSegs.keys())[0]
    msg = "".join([
        "Science segments exist only for %s. " % lone_ifo,
        "If you wish to enable single IFO running add the option ",
        "'allow-single-ifo-search' to the [workflow] section of your ",
        "configuration file.",
    ])
    logging.warning(msg)
    sys.exit()
else:
    onSrc, offSrc = _workflow.generate_triggered_segment(wflow, segDir,
                                                         sciSegs)

sciSegs = segmentlistdict(sciSegs)
if onSrc is not None:
    # Plot the science segments together with the first segment of the
    # first IFO in offSrc, and register the plot as a workflow file
    plot_ifo = tuple(offSrc.keys())[0]
    plot_met = make_grb_segments_plot(wflow, sciSegs, triggertime,
                                      triggername, segDir,
                                      coherent_seg=offSrc[plot_ifo][0])
    segs_plot = _workflow.File(plot_met[0], plot_met[1], plot_met[2],
                               file_url=plot_met[3])
    segs_plot.PFN(segs_plot.cache_entry.path, site="local")
    # From here on only the off-source segments are used
    sciSegs = offSrc
    all_files.extend(_workflow.FileList([segs_plot]))
else:
    # No on-source segment: offSrc is handed to the plot as the failure
    # criterion and the script stops
    plot_met = make_grb_segments_plot(wflow, sciSegs, triggertime,
                                      triggername, segDir,
                                      fail_criterion=offSrc)
    logging.info("Making segment plot and exiting.")
    sys.exit()

# Tag the matched filter jobs by the number of IFOs with segments
num_ifos = len(sciSegs)
if num_ifos > 1:
    mf_tag = "coherent"
elif num_ifos == 1:
    logging.info("Generating a single IFO search.")
    mf_tag = "sngl"

# Update analysis time after coherent segment calculation
ifo = tuple(sciSegs.keys())[0]
seg_start = int(sciSegs[ifo][0][0])
seg_end = int(sciSegs[ifo][0][1])
wflow.cp = _workflow.set_grb_start_end(wflow.cp, seg_start, seg_end)

# Total padding: the inspiral pad-data plus, when gating is enabled, the
# pad-data of the [condition_strain] section
padding = int(wflow.cp.get("inspiral", "pad-data"))
if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    gate_pad = int(wflow.cp.get("condition_strain", "pad-data"))
    padding += gate_pad

# Shrink the segment to the analysed time, using the same three-way choice
# of [inspiral] options that set min-before/min-after earlier in the script
if wflow.cp.has_option("inspiral", "segment-start-pad"):
    start_pad = int(wflow.cp.get("inspiral", "segment-start-pad"))
    end_pad = int(wflow.cp.get("inspiral", "segment-end-pad"))
    analysis_start = seg_start + start_pad + padding
    analysis_end = seg_end - padding - end_pad
elif wflow.cp.has_option("inspiral", "analyse-segment-end"):
    analysis_start = seg_start + deadtime - spec_len + padding - safety
    analysis_end = seg_end - spec_len - padding - safety
else:
    analysis_start = seg_start + deadtime + padding
    analysis_end = seg_end - deadtime - padding
wflow.analysis_time = segment(analysis_start, analysis_end)

ext_file = None

# DATAFIND
dfDir = os.path.join(currDir, "datafind")
datafind_files, _, sciSegs, _ = _workflow.setup_datafind_workflow(
    wflow, sciSegs, dfDir, segsFileList)
if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    # Remove the gating padding from both ends of the reference segment and
    # give every IFO that same trimmed segment
    ref_seg = sciSegs[ifo][0]
    new_seg = segment(ref_seg[0] + gate_pad, ref_seg[1] - gate_pad)
    for iifo in sciSegs:
        sciSegs[iifo][0] = new_seg
    trimmed = sciSegs[ifo][0]
    wflow.cp = _workflow.set_grb_start_end(wflow.cp, int(trimmed[0]),
                                           int(trimmed[1]))

# At this point "ifos" is a sorted list of the IFO names
ifos = sorted(sciSegs.keys())
ifo = ifos[0]
wflow.ifos = ifos
datafind_veto_files = _workflow.FileList([])

# GATING
# When [workflow-condition_strain] do-gating is set, create the gating nodes,
# either add them to the workflow or execute them immediately, and replace
# the datafind file list with one cache file per IFO listing the gated frames.
if wflow.cp.has_option("workflow-condition_strain", "do-gating"):
    logging.info("Creating gating jobs.")
    wflow.cp = _workflow.set_grb_start_end(wflow.cp, int(sciSegs[ifo][0][0]),
                                           int(sciSegs[ifo][0][1]))
    gating_nodes, gated_files = _workflow.make_gating_node(wflow,
            datafind_files, outdir=dfDir)
    gating_method = wflow.cp.get("workflow-condition_strain",
                                 "gating-method")
    for gating_node in gating_nodes:
        if gating_method == "IN_WORKFLOW":
            wflow.add_node(gating_node)
        elif gating_method == "AT_RUNTIME":
            logging.info("Executing gating node...")
            wflow.execute_node(gating_node)
        else:
            msg = "[workflow-condition_strain] option 'gating-method' can "
            msg += "only have one of the values 'IN_WORKFLOW' or 'AT_RUNTIME'. "
            msg += "You have provided the value %s." % gating_method
            logging.error(msg)
            # An invalid configuration is a failure, so report it through a
            # non-zero exit status rather than exiting "successfully"
            sys.exit(1)
    datafind_files = _workflow.FileList([])
    for ifo in ifos:
        gated_frames = _workflow.FileList([gated_frame for gated_frame in
                                           gated_files
                                           if gated_frame.ifo == ifo])
        gated_cache = _workflow.File(ifo, "gated",
                segment(int(wflow.cp.get("workflow", "start-time")),
                        int(wflow.cp.get("workflow", "end-time"))),
                extension="lcf", directory=dfDir)
        gated_cache.PFN(gated_cache.cache_entry.path, site="local")
        # Build the cache before opening the file (as the original did), and
        # use a context manager so the handle is always closed and flushed
        gated_cache_entries = gated_frames.convert_to_lal_cache()
        with open(gated_cache.storage_path, "w") as cache_fp:
            gated_cache_entries.tofile(cache_fp)
        datafind_files.extend(_workflow.FileList([gated_cache]))

datafind_veto_files.extend(datafind_files)
# From here on "ifos" is the concatenated string of sorted IFO names,
# while "ifo_list" keeps them as a list
ifo_list = sorted(sciSegs.keys())
ifo, ifos = ifo_list[0], "".join(ifo_list)
wflow.ifos = ifos

# Is this an IPN GRB?
# IPN running requires both 'ipn-search-points' under [workflow-inspiral] and
# 'ipn-sim-points' under [workflow-injections]; giving only one is an error.
has_ipn_search = wflow.cp.has_option("workflow-inspiral",
                                     "ipn-search-points")
has_ipn_sim = wflow.cp.has_option("workflow-injections", "ipn-sim-points")
if has_ipn_search and has_ipn_sim:
    wflow.cp.set("injections", "ipn-gps-time",
                 wflow.cp.get("workflow", "trigger-time"))
    IPN = True
elif has_ipn_search or has_ipn_sim:
    msg = "You have provided only one of 'ipn-search-points' under "
    msg += "[workflow-inspiral] and 'ipn-sim-points' under "
    msg += "[workflow-injections] in your configuration files. If this is an "
    msg += "IPN GRB please provide both, otherwise provide neither."
    logging.error(msg)
    # An inconsistent configuration is a failure, so report it through a
    # non-zero exit status rather than exiting "successfully"
    sys.exit(1)
else:
    IPN = False

# If using pycbc_multi_inspiral we need bank_veto_bank.xml
inspiral_exe = os.path.basename(wflow.cp.get("executables", "inspiral"))
if inspiral_exe == "pycbc_multi_inspiral":
    bank_veto_file = _workflow.get_coh_PTF_files(wflow.cp, ifos, runDir,
                                                 bank_veto=True)
    datafind_veto_files.extend(bank_veto_file)

    if IPN:
        # The IPN search points file is passed on with the veto files
        ipn_search_pts = wflow.cp.get("workflow-inspiral",
                                      "ipn-search-points")
        search_pts_file = _workflow.get_ipn_sky_files(wflow, ipn_search_pts,
                                                      tags=["SEARCH"])
        datafind_veto_files.extend(_workflow.FileList([search_pts_file]))

    # Make ExtTrig xml file (needed for lalapps_inspinj and summary pages)
    ext_file = _workflow.make_exttrig_file(wflow.cp, ifos, sciSegs[ifo][0],
                                           baseDir)
    all_files.extend(_workflow.FileList([ext_file]))

all_files.extend(datafind_veto_files)

# TEMPLATE BANK AND SPLIT BANK
bank_files = _workflow.setup_tmpltbank_workflow(
    wflow, sciSegs, datafind_files, dfDir)
splitbank_files = _workflow.setup_splittable_workflow(
    wflow, bank_files, dfDir, tags=["inspiral"])
# Record both sets of files, full banks first
for bank_outputs in (bank_files, splitbank_files):
    all_files.extend(bank_outputs)

# INJECTIONS
# These defaults are what the post-processing call receives when the
# configuration has no [workflow-injections] section.
injs = None
inj_tags = []
inj_files = None
inj_caches = None
inj_insp_files = None
inj_insp_caches = None
if wflow.cp.has_section("workflow-injections"):
    injDir = os.path.join(currDir, "injections")
    inj_caches = _workflow.FileList([])
    inj_insp_caches = _workflow.FileList([])

    # Generate injection files
    if IPN:
        sim_pts_file = _workflow.get_ipn_sky_files(wflow,
                wflow.cp.get("workflow-injections", "ipn-sim-points"),
                tags=["SIM"])
        all_files.extend(_workflow.FileList([sim_pts_file]))
        inj_files, inj_tags = _workflow.setup_injection_workflow(wflow, injDir,
                exttrig_file=sim_pts_file)
    else:
        inj_files, inj_tags = _workflow.setup_injection_workflow(wflow, injDir,
                exttrig_file=ext_file)
    all_files.extend(inj_files)
    injs = inj_files

    # Either split template bank for injections jobs or use same split banks
    # as for standard matched filter jobs
    if wflow.cp.has_section("workflow-splittable-injections"):
        inj_splitbank_files = _workflow.setup_splittable_workflow(wflow,
                bank_files, injDir, tags=["injections"])
        for inj_split in inj_splitbank_files:
            split_str = [s for s in inj_split.tagged_description.split("_")
                         if ("BANK" in s and s[-1].isdigit())]
            if split_str:
                # Append the tag string and the numeric index taken from the
                # first "BANK<n>" component of the description
                inj_split.tagged_description += "%s_%d" % (
                    inj_split.tag_str,
                    int(split_str[0].replace("BANK", "")))
        all_files.extend(inj_splitbank_files)
    else:
        inj_splitbank_files = _workflow.FileList([])
        inj_splitbank_files.extend(splitbank_files)

    # Split the injection files
    if wflow.cp.has_section("workflow-splittable-split_inspinj"):
        inj_split_files = _workflow.FileList([])
        for inj_file, inj_tag in zip(inj_files, inj_tags):
            # Named so as not to shadow the Python 2 builtin "file"
            single_inj = _workflow.FileList([inj_file])
            inj_splits = _workflow.setup_splittable_workflow(wflow,
                    single_inj, injDir, tags=["split_inspinj", inj_tag])
            for inj_split in inj_splits:
                split_str = [s for s in
                             inj_split.tagged_description.split("_")
                             if ("SPLIT" in s and s[-1].isdigit())]
                if split_str:
                    # Rewrite the first "SPLIT<n>" component as "SPLIT_<n>"
                    new = inj_split.tagged_description.replace(split_str[0],
                            "SPLIT_%s" % split_str[0].replace("SPLIT", ""))
                    inj_split.tagged_description = new
            inj_split_files.extend(inj_splits)
        all_files.extend(inj_split_files)
        injs = inj_split_files

    # Generate injection matched filter workflow
    inj_insp_files = _workflow.setup_matchedfltr_workflow(wflow, sciSegs,
            datafind_veto_files, inj_splitbank_files, injDir, injs,
            tags=[mf_tag + "_injections"])
    for inj_insp_file in inj_insp_files:
        split_str = [s for s in inj_insp_file.name.split("_")
                     if ("SPLIT" in s and s[-1].isdigit())]
        if split_str:
            # "SPLIT<n>" in the file name becomes a "_<n>" suffix
            num = split_str[0].replace("SPLIT", "_")
            inj_insp_file.tagged_description += num

    # Make cache files (needed for post-processing)
    for inj_tag in inj_tags:
        files = _workflow.FileList([inj for inj in injs
                                    if inj_tag in inj.tag_str])
        inj_cache = _workflow.File(ifos, "injections", sciSegs[ifo][0],
                                   extension="lcf", directory=injDir,
                                   tags=[inj_tag])
        inj_cache.PFN(inj_cache.cache_entry.path, site="local")
        inj_caches.extend(_workflow.FileList([inj_cache]))
        inj_cache_entries = files.convert_to_lal_cache()
        # Context manager guarantees the cache file is closed and flushed
        with open(inj_cache.storage_path, "w") as cache_fp:
            inj_cache_entries.tofile(cache_fp)

        files = _workflow.FileList([trig for trig in inj_insp_files
                                    if inj_tag in trig.tag_str])
        inj_insp_cache = _workflow.File(ifos, "inspiral_injections",
                                        sciSegs[ifo][0], extension="lcf",
                                        directory=injDir, tags=[inj_tag])
        inj_insp_cache.PFN(inj_insp_cache.cache_entry.path, site="local")
        inj_insp_caches.extend(_workflow.FileList([inj_insp_cache]))
        inj_insp_cache_entries = files.convert_to_lal_cache()
        with open(inj_insp_cache.storage_path, "w") as cache_fp:
            inj_insp_cache_entries.tofile(cache_fp)

    all_files.extend(inj_caches)
    all_files.extend(inj_insp_files)
    all_files.extend(inj_insp_caches)

# MAIN MATCHED FILTERING
# Set up the matched filter jobs without injections and write a cache file
# listing their output.
inspDir = os.path.join(currDir, "inspiral")
inspiral_files = _workflow.setup_matchedfltr_workflow(wflow, sciSegs,
        datafind_veto_files, splitbank_files, inspDir,
        tags=[mf_tag + "_no_injections"])
all_files.extend(inspiral_files)
inspiral_cache = _workflow.File(ifos, "inspiral", sciSegs[ifo][0],
                                extension="lcf", directory=inspDir)
inspiral_cache.PFN(inspiral_cache.cache_entry.path, site="local")
all_files.extend(_workflow.FileList([inspiral_cache]))
inspiral_cache_entries = inspiral_files.convert_to_lal_cache()
# Context manager guarantees the cache file is closed and flushed
with open(inspiral_cache.storage_path, "w") as cache_fp:
    inspiral_cache_entries.tofile(cache_fp)

# LONG TIME SLIDES
# Enabled by [workflow] do-long-slides. The number of slides is taken from
# [workflow] num-long-slides when present; otherwise it is the analysis
# duration divided by (segment-length * (number of IFOs - 1)), truncated to
# an integer. One extra set of matched filter jobs is set up per slide.
long_slides = bool(wflow.cp.has_option("workflow", "do-long-slides"))
if long_slides:
    tsDir = os.path.join(currDir, "timeslides")
    if wflow.cp.has_option("workflow", "num-long-slides"):
        num_slides = int(wflow.cp.get("workflow", "num-long-slides"))
    else:
        num_slid_ifos = len(ifo_list) - 1
        if num_slid_ifos < 1:
            # The formula below divides by (number of IFOs - 1); fail with
            # a clear message instead of a ZeroDivisionError traceback
            msg = "The number of long slides cannot be derived for a "
            msg += "single IFO search. Provide 'num-long-slides' in the "
            msg += "[workflow] section of your configuration file or "
            msg += "remove 'do-long-slides'."
            logging.error(msg)
            sys.exit(1)
        analysis_len = int(wflow.cp.get("workflow", "end-time")) - \
                       int(wflow.cp.get("workflow", "start-time"))
        segment_len = int(wflow.cp.get("inspiral", "segment-length"))
        num_slides = int(analysis_len / float(segment_len * num_slid_ifos))
    logging.info("Doing {} long slides".format(num_slides))
    inspiral_ts_files = _workflow.FileList([])
    for slide in range(num_slides):
        inspiral_slide_files = _workflow.setup_matchedfltr_workflow(wflow,
                sciSegs, datafind_veto_files, splitbank_files, tsDir,
                tags=[mf_tag + "_no_injections", "slide{}".format(slide)])
        inspiral_ts_files.extend(inspiral_slide_files)
        all_files.extend(inspiral_slide_files)
else:
    inspiral_ts_files = None

# POST-PROCESSING
ppDir = os.path.join(currDir, "post_processing")
post_proc_method = wflow.cp.get_opt_tags("workflow-postproc",
                                         "postproc-method", tags)
# Start from an empty list so that a post-processing method not handled
# below adds nothing, instead of raising NameError when pp_files is used
pp_files = _workflow.FileList([])

if post_proc_method in ["COH_PTF_WORKFLOW", "COH_PTF_OFFLINE",
                        "COH_PTF_ONLINE"]:
    from pylal import pygrb_cohptf_pp
    # Add parsed config file so it can be linked from summary page
    cp_file_name = workflow_name + ".ini"
    cp_file_url = "file://localhost%s/%s" % (runDir, cp_file_name)
    cp_file = _workflow.File(ifos, cp_file_name, sciSegs[ifo][0],
                             file_url=cp_file_url)
    cp_file.PFN(cp_file.cache_entry.path, site="local")

    # Generate post-processing workflow
    html_dir = wflow.cp.get("workflow", "html-dir")
    pp_files = pygrb_cohptf_pp.setup_coh_PTF_post_processing(wflow,
            inspiral_files, inspiral_cache, ppDir, segDir,
            injection_trigger_files=inj_insp_files, injection_files=injs,
            injection_trigger_caches=inj_insp_caches,
            timeslide_trigger_files=inspiral_ts_files,
            injection_caches=inj_caches, config_file=cp_file, web_dir=html_dir,
            segments_plot=segs_plot, ifos=ifos, inj_tags=inj_tags)

    # Retrieve style files for webpage
    summary_files = _workflow.get_coh_PTF_files(wflow.cp, ifos, ppDir,
                                                summary_files=True)

    pp_files.extend(_workflow.FileList([cp_file]))
    pp_files.extend(summary_files)
else:
    logging.warning("Post-processing method %s is not handled by this "
                    "script; no post-processing jobs were added.",
                    post_proc_method)

all_files.extend(pp_files)

# COMPILE WORKFLOW AND WRITE DAX
wflow.save()
logging.info("Written dax.")

