#!/usr/bin/env python

# Copyright (C) 2016 Ian W. Harry, Y Ddraig Goch
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# in the knowledge that it will probably not be useful, and you'll moan at me,
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Workflow generator to create diagnosis plots and figures of merit for an input
template bank.
"""

#imports
from __future__ import division
import os
import argparse

from ligo import segments

import pycbc.version
import pycbc.workflow as wf
import pycbc.workflow.pegasus_workflow as pwf
from pycbc.results import create_versioning_page, static_table, layout
from pycbc.workflow import LalappsInspinjExecutable, PycbcDarkVsBrightInjectionsExecutable
from pycbc.workflow import setup_splittable_dax_generated

# Boiler-plate stuff: module metadata. __version__ is also the string that
# the --version command-line option prints.
__author__  = "Ian Harry <ian.harry@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__    = pycbc.version.date
__program__ = "pycbc_make_bank_verifier_workflow"

# Some new executable classes. These can be moved into modules if needed
class BanksimExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim``.

    Each node compares one injection file against one template bank file and
    writes the resulting matches to a ``.dat`` file.
    """
    # Retention level of this job's outputs; change it to keep fewer files.
    current_retention_level = wf.Executable.ALL_TRIGGERS

    # Command-line options whose values are input files.
    file_input_options = ['--psd-file', '--asd-file']

    def create_node(self, analysis_time, inj_file, bank_file, extra_tags=None):
        """Return a banksim node for ``inj_file`` against ``bank_file``.

        The ``--match-file`` output is tagged with the executable's tags
        followed by ``extra_tags`` (nothing extra when omitted).
        """
        tag_suffix = [] if extra_tags is None else extra_tags
        banksim_node = wf.Executable.create_node(self)
        banksim_node.add_input_opt('--signal-file', inj_file)
        banksim_node.add_input_opt('--template-file', bank_file)
        banksim_node.new_output_file_opt(analysis_time, '.dat', '--match-file',
                                         tags=self.tags + tag_suffix)
        return banksim_node

class BanksimBankCombineExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim_combine_banks``.

    Each node takes a list of files via ``--input-files`` and writes one
    combined ``.dat`` file.
    """
    # Retention level of this job's outputs; change it to keep fewer files.
    current_retention_level = wf.Executable.ALL_TRIGGERS

    def create_node(self, analysis_time, inp_files, extra_tags=None):
        """Return a node combining ``inp_files`` into a single output file.

        The ``--output-file`` is tagged with the executable's tags followed
        by ``extra_tags`` (nothing extra when omitted).
        """
        tag_suffix = extra_tags if extra_tags is not None else []
        combine_node = wf.Executable.create_node(self)
        combine_node.add_input_list_opt('--input-files', inp_files)
        output_tags = self.tags + tag_suffix
        combine_node.new_output_file_opt(analysis_time, '.dat',
                                         '--output-file', tags=output_tags)
        return combine_node

class BanksimMatchCombineExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim_match_combine``.

    Each node reads a list of match files and writes a single ``.h5`` file.
    """
    # Outputs of this job are kept as final results.
    current_retention_level = wf.Executable.FINAL_RESULT

    # Command-line options whose values are input files.
    file_input_options = ['--filter-func-file']

    def create_node(self, analysis_time, match_files, inj_files, bank_files,
                    extra_tags=None):
        """Return a node combining ``match_files`` into one ``.h5`` file.

        ``inj_files`` and ``bank_files`` are not passed through any option;
        they are only registered as inputs of the node (presumably so that
        the job depends on / has access to them - TODO confirm).
        The ``--output-file`` is tagged with the executable's tags followed
        by ``extra_tags`` (nothing extra when omitted).
        """
        tag_suffix = [] if extra_tags is None else extra_tags
        combine_node = wf.Executable.create_node(self)
        combine_node.add_input_list_opt('--match-files', match_files)
        # Injection files first, then bank files, as plain node inputs.
        for file_group in (inj_files, bank_files):
            for input_file in file_group:
                combine_node._add_input(input_file)
        combine_node.new_output_file_opt(analysis_time, '.h5', '--output-file',
                                         tags=self.tags + tag_suffix)
        return combine_node

class BanksimPlotFittingFactorsExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim_plot_fitting_factors``.

    Each node reads one input file and produces one ``.png`` plot.
    """
    # Outputs of this job are kept as final results.
    current_retention_level = wf.Executable.FINAL_RESULT

    def create_node(self, analysis_time, input_file, extra_tags=None):
        """Return a node plotting the contents of ``input_file``.

        The ``--output-file`` image is tagged with the executable's tags
        followed by ``extra_tags`` (nothing extra when omitted).
        """
        tag_suffix = extra_tags if extra_tags is not None else []
        plot_node = wf.Executable.create_node(self)
        plot_node.add_input_opt('--input-file', input_file)
        plot_node.new_output_file_opt(analysis_time, '.png', '--output-file',
                                      tags=self.tags + tag_suffix)
        return plot_node

class BanksimPlotEffFittingFactorsExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim_plot_eff_fitting_factor``.

    Each node reads several input files and produces one ``.png`` plot.
    """
    # Outputs of this job are kept as final results.
    current_retention_level = wf.Executable.FINAL_RESULT

    def create_node(self, analysis_time, input_files, extra_tags=None):
        """Return a node plotting all of ``input_files`` in a single image.

        The ``--output-file`` image is tagged with the executable's tags
        followed by ``extra_tags`` (nothing extra when omitted).
        """
        tag_suffix = [] if extra_tags is None else extra_tags
        eff_node = wf.Executable.create_node(self)
        eff_node.add_input_list_opt('--input-files', input_files)
        output_tags = self.tags + tag_suffix
        eff_node.new_output_file_opt(analysis_time, '.png', '--output-file',
                                     tags=output_tags)
        return eff_node

class BanksimTablePointInjsExecutable(wf.Executable):
    """Workflow executable wrapping ``pycbc_banksim_table_point_injs``.

    Each node reads several input files and writes one ``.html`` table.
    """
    # Outputs of this job are kept as final results.
    current_retention_level = wf.Executable.FINAL_RESULT

    def create_node(self, analysis_time, input_files, relative_dirs,
                    extra_tags=None):
        """Return a node building an HTML table from ``input_files``.

        ``relative_dirs`` are passed verbatim via ``--directory-links``
        (as plain values, not as input files). The ``--output-file`` is
        tagged with the executable's tags followed by ``extra_tags``
        (nothing extra when omitted).
        """
        tag_suffix = extra_tags if extra_tags is not None else []
        table_node = wf.Executable.create_node(self)
        table_node.add_input_list_opt('--input-files', input_files)
        table_node.add_list_opt('--directory-links', relative_dirs)
        table_node.new_output_file_opt(analysis_time, '.html', '--output-file',
                                       tags=self.tags + tag_suffix)
        return table_node


# Argument parsing and setup of workflow

# Use the standard workflow command-line parsing routines. Things like a
# configuration file are specified within the "workflow command line group"
# so run this with --help to see what options are added.
# Strip the leading newline of the module docstring for the description.
_desc = __doc__[1:]
parser = argparse.ArgumentParser(description=_desc)
parser.add_argument('--version', action='version', version=__version__)
parser.add_argument("--workflow-name", type=str, default='bank_verifier',
                    help="Descriptive name of the analysis.")
# There is no usable default here: the script chdir()s into this directory
# below, and os.chdir(None) would fail with an unhelpful TypeError. Let
# argparse report a missing value cleanly instead.
parser.add_argument("-d", "--output-dir", required=True,
                    help="Path to output directory.")
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

# Create the workflow object.
# NOTE(review): this happens before the chdir below, presumably so that any
# relative paths given on the command line resolve against the launch
# directory - keep this ordering.
workflow = wf.Workflow(args, args.workflow_name)

wf.makedir(args.output_dir)
os.chdir(args.output_dir)
# From here on everything is written relative to the new working directory.
args.output_dir = '.'

# Numbered sections of the results pages; rdir[...] gives the path of a
# (sub)section below rdir.base.
rdir = layout.SectionNumber('results', ['point_injection_sets',
                                        'broad_injection_sets',
                                        'workflow'])
wf.makedir(rdir.base)
wf.makedir(rdir['workflow'])

# Input bank file, labelled as a TEMPLATEBANK for H1, L1 and V1 covering the
# whole analysis time.
inp_bank = wf.File.from_path(workflow.cp.get('workflow', 'input-bank'))
inp_bank.tags = []
inp_bank.description = 'TEMPLATEBANK'
inp_bank.ifo_list = ['H1', 'L1', 'V1']
inp_bank.segment = workflow.analysis_time

# Inspinj job, shared by every injection set; add_banksim_set() updates its
# tags before each use.
inspinj_job = LalappsInspinjExecutable(workflow.cp, 'injection', out_dir='.',
                                       ifos='HL', tags=[])

def add_banksim_set(workflow, file_tag, num_injs, curr_tags, split_banks):
    """Add a group of jobs that does a complete banksim.

    The added jobs are:

    1. Generate injections with the module-level ``inspinj_job``.
    2. Optionally filter them to EM-bright injections, if
       ``[workflow-injections] em-bright-only`` is set.
    3. Split the injection file.
    4. Run pycbc_banksim for every (injection split, bank split) pair.
    5. Combine, per injection split, the results over all bank splits.
    6. Combine everything into one final file.

    Parameters
    ----------
    workflow : pycbc.workflow.Workflow
        Workflow to which all new nodes are added.
    file_tag : str
        Name of this injection set. It is used as the tag of the banksim
        jobs and to build their output directory name ``<file_tag>match``.
    num_injs : str or int
        Number of injections, converted with ``int()``. It sets the length
        of the dummy time segment handed to inspinj, starting at 1000000000
        (presumably GPS seconds). Whether this gives exactly ``num_injs``
        injections depends on the inspinj configuration - TODO confirm.
    curr_tags : list of str
        Tags for the inspinj, em-bright filter and injection-split jobs.
    split_banks : FileList
        The template bank, already split into sub-banks.

    Returns
    -------
    File
        The output file of the banksim_match_combine job.

    Notes
    -----
    Side effect: the current tags of the module-level ``inspinj_job`` are
    replaced by ``curr_tags``.
    """
    inspinj_job.update_current_tags(curr_tags)
    # Dummy time segment whose length is set by the number of injections.
    t_seg = segments.segment([1000000000, 1000000000 + int(num_injs)])
    node = inspinj_job.create_node(t_seg)
    workflow += node
    inj_file = node.output_file

    # Here we apply the em-bright criterion
    if workflow.cp.has_option('workflow-injections', 'em-bright-only'):
        # Job to carry on with em-bright injections only
        em_filter_job = PycbcDarkVsBrightInjectionsExecutable(
            workflow.cp, 'em_bright_filter', out_dir='.', ifos='HL',
            tags=curr_tags)
        node = em_filter_job.create_node(inj_file, t_seg, curr_tags)
        workflow += node
        inj_file = node.output_files[0]

    split_injs = setup_splittable_dax_generated(workflow, [inj_file],
                                                'splitinjfiles', curr_tags)

    # All banksim-related jobs of this injection set share one directory.
    match_dir = file_tag + 'match'
    banksim_job = BanksimExecutable(workflow.cp, 'banksim',
                                    out_dir=match_dir,
                                    ifos='HL', tags=[file_tag])
    bscombine_job = BanksimBankCombineExecutable(
        workflow.cp, 'banksim_bank_combine', out_dir=match_dir, ifos='HL',
        tags=[file_tag])
    mcombine_job = BanksimMatchCombineExecutable(
        workflow.cp, 'banksim_match_combine', out_dir=match_dir, ifos='HL',
        tags=[file_tag])

    # One bank-combined file per injection split.
    banksim_files = wf.FileList([])
    for inj_idx, split_inj in enumerate(split_injs):
        # Depends only on the outer loop, so build it once per injection
        # split (it was previously rebuilt for every bank split as well).
        inj_tag = 'INJ{}'.format(inj_idx)
        currinj_banksim_files = wf.FileList([])
        for bank_idx, split_bank in enumerate(split_banks):
            bank_tag = 'BANK{}'.format(bank_idx)
            node = banksim_job.create_node(workflow.analysis_time, split_inj,
                                           split_bank,
                                           extra_tags=[bank_tag, inj_tag])
            workflow += node
            currinj_banksim_files.append(node.output_file)
        # Combine this injection split's results over all bank splits.
        curr_node = bscombine_job.create_node(workflow.analysis_time,
                                              currinj_banksim_files,
                                              extra_tags=[inj_tag])
        workflow += curr_node
        banksim_files.append(curr_node.output_file)

    # Final combination; the split injection and bank files are also handed
    # to this job as inputs.
    curr_node = mcombine_job.create_node(workflow.analysis_time, banksim_files,
                                         split_injs, split_banks)
    workflow += curr_node
    return curr_node.output_file

# Set up the actual banksims. The bank is split once for the point injection
# sets and once, with different tags, for the broad injection sets.
curr_tags = ['shortinjbanksplit']
split_banks = setup_splittable_dax_generated(workflow, [inp_bank],
                                             'splitbankfiles', curr_tags)

# Each option of [workflow-pointinjs] is one injection set: the option name
# is its tag and the value its number of injections.
output_pointinjs = {}
for file_tag, num_injs in workflow.cp.items('workflow-pointinjs'):
    curr_tags = ['shortinjs', file_tag]
    curr_file = add_banksim_set(workflow, file_tag, num_injs, curr_tags,
                                split_banks)
    output_pointinjs[file_tag] = curr_file

curr_tags = ['broadinjbanksplit']
split_banks = setup_splittable_dax_generated(workflow, [inp_bank],
                                             'splitbankfiles', curr_tags)

# Same as above, for the [workflow-broadinjs] section.
output_broadinjs = {}
for file_tag, num_injs in workflow.cp.items('workflow-broadinjs'):
    curr_tags = ['broadinjs', file_tag]
    curr_file = add_banksim_set(workflow, file_tag, num_injs, curr_tags,
                                split_banks)
    output_broadinjs[file_tag] = curr_file

# Every plotting node; the results page is made to depend on all of them.
plotting_nodes = []

point_injs_table_exe = BanksimTablePointInjsExecutable(
    workflow.cp, 'banksim_table_point_injs',
    out_dir=rdir['point_injection_sets'], ifos='HL')
eff_fitting_facs_exe = BanksimPlotEffFittingFactorsExecutable(
    workflow.cp, 'banksim_plot_eff_fitting_fac', out_dir=rdir.base,
    ifos='HL')
plot_fitting_facs_exe = BanksimPlotFittingFactorsExecutable(
    workflow.cp, 'banksim_plot_fitting_factors',
    out_dir=rdir['point_injection_sets'], ifos='HL')

summary_page_files = []
# Entries of point_inj_summ_files are tuples of files for
# layout.two_column_layout: a 1-tuple spans both columns, while a 2-tuple is
# shown as a side-by-side pair.
point_inj_summ_files = []
# Nothing in this yet, so not sure how to set this.
broad_inj_summ_files = []

# Point injection set tags in a fixed (sorted) order. This is used for every
# iteration below so that the generated workflow does not depend on dict
# iteration order.
point_inj_tags = sorted(output_pointinjs)

# Register a numbered results subsection for every point injection set.
# rdir[...] is evaluated only for this side effect: rdir.name must know each
# subsection before the table below links to it.
for f in point_inj_tags:
    rdir['point_injection_sets/{}'.format(f)]

curr_node = point_injs_table_exe.create_node(
    workflow.analysis_time,
    [output_pointinjs[f] for f in point_inj_tags],
    ['../' + rdir.name['point_injection_sets/{}'.format(f)]
     for f in point_inj_tags])
workflow += curr_node
plotting_nodes.append(curr_node)
point_inj_summ_files.append((curr_node.output_file,))

# One plot over all point injection sets per configuration subsection of
# [banksim_plot_eff_fitting_fac]. Inputs are given in the same sorted order
# as the table above.
secs = workflow.cp.get_subsections('banksim_plot_eff_fitting_fac')
for tag in secs:
    eff_fitting_facs_exe.update_current_tags([tag])
    curr_node = eff_fitting_facs_exe.create_node(
        workflow.analysis_time,
        [output_pointinjs[f] for f in point_inj_tags])
    workflow += curr_node
    plotting_nodes.append(curr_node)
    summary_page_files.append(curr_node.output_file)

# Set up layouts
layout.group_layout(rdir.base, summary_page_files)
layout.two_column_layout(rdir['point_injection_sets'], point_inj_summ_files)
# Also add broad_injs when ready


def _add_fitting_factor_plots(workflow, plot_exe, results, section,
                              subsections, nodes):
    """Add fitting-factor plots for every injection set in ``results``.

    For each tag of ``results`` (in sorted order), ``plot_exe`` is pointed
    at the results subsection ``rdir['<section>/<tag>']``. The module-level
    ``rdir`` is used for this. One plotting node is then added per entry of
    ``subsections``, and the plots are laid out, one per row, in that
    subsection.

    ``workflow`` is passed explicitly because ``workflow += node`` inside a
    function would otherwise make it a local name. Every node created is
    appended to ``nodes``.
    """
    for tag in sorted(results):
        curr_outs = []
        curr_file = results[tag]
        plot_dir = rdir['{}/{}'.format(section, tag)]
        plot_exe.update_output_directory(plot_dir)
        for tag2 in subsections:
            plot_exe.update_current_tags([tag, tag2])
            curr_node = plot_exe.create_node(workflow.analysis_time,
                                             curr_file)
            workflow += curr_node
            nodes.append(curr_node)
            curr_outs.append((curr_node.output_file,))
        # Other outputs could go here, before running the layout
        layout.two_column_layout(plot_dir, curr_outs)


# Per-injection-set fitting-factor plots: broad sets first, then point sets.
secs = workflow.cp.get_subsections('banksim_plot_fitting_factors')
_add_fitting_factor_plots(workflow, plot_fitting_facs_exe, output_broadinjs,
                          'broad_injection_sets', secs, plotting_nodes)
_add_fitting_factor_plots(workflow, plot_fitting_facs_exe, output_pointinjs,
                          'point_injection_sets', secs, plotting_nodes)

# Create versioning information
create_versioning_page(rdir['workflow/version'], workflow.cp)

wf.make_results_web_page(workflow, os.path.join(os.getcwd(), rdir.base),
                         explicit_dependencies=plotting_nodes)

workflow.save()
