#! /usr/bin/env python

# Copyright (C) 2017 Christopher M. Biwer, Alexander Harvey Nitz, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Creates a DAX for a parameter estimation injection study.
"""

import argparse
import logging
import os
import shlex
import numpy
import Pegasus.DAX3 as dax
import pycbc.version
import socket
import sys
from pycbc import __version__
from pycbc import results
from pycbc.results import layout
from pycbc.results import metadata
from pycbc.results import versioning
from pycbc.workflow import configuration
from pycbc.workflow import core
from pycbc.workflow.jobsetup import (PycbcCreateInjectionsExecutable,
                                     PycbcInferenceExecutable)
from pycbc.workflow import inference_followups as inffu
from pycbc.workflow import plotting
from pycbc.inject import InjectionSet
from pycbc.io import FieldArray


def config_from_config(cp, section, skip_opts=None):
    """Loads a WorkflowConfigParser from the config files given in a section."""
    # create a dummy command-line parser for getting the config files and
    # options
    cfparser = argparse.ArgumentParser()
    configuration.add_workflow_command_line_group(cfparser)
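    # turn the section's options into an equivalent command line (e.g.
    # "config-files = inference.ini" becomes "--config-files inference.ini";
    # file name illustrative) and parse it as if given to a workflow script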
    cli = cp.section_to_cli(section, skip_opts=skip_opts)
    opts = cfparser.parse_args(shlex.split(cli))
    return configuration.WorkflowConfigParser.from_cli(opts)


def read_inference_settings_from_config(cp, section):
    """Loads the inference config parser and the number of independent
    inference runs to do (defaults to 1 if no nruns option is given)."""
    if cp.has_option(section, 'nruns'):
        nruns = int(cp.get(section, 'nruns'))
    else:
        nruns = 1
    return config_from_config(cp, section, skip_opts=['nruns']), nruns
    

def symlink_path(f, path):
    """ Symlinks a file into the given directory; does nothing if the file
    is None or the link already exists.
    """
    if f is None:
        return
    try:
        os.symlink(f.storage_path, os.path.join(path, f.name))
    except OSError:
        pass


# set up the command-line parser
parser = argparse.ArgumentParser(description=__doc__[1:])

# injection options: either specify a number to create, or use the given file
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--num-injections", type=int,
                   help="The number of injections to create.")
group.add_argument("--injection-file", type=str, nargs="+", 
                   help="Analyze injections in the given file(s) instead of "
                        "creating them.")
# add the configuration-file option group
configuration.add_workflow_command_line_group(parser)
# add the workflow settings group
core.add_workflow_settings_cli(parser, include_subdax_opts=True)
parser.add_argument("--seed", type=int, default=0,
                    help="Starting to seed to use. This will be incremented "
                         "one for each injection analyzed. Default is 0.")
# version option
parser.add_argument("--version", action="version", version=__version__,
                    help="Prints version information.")

# parse the command line
opts = parser.parse_args()

# configuration files
config_file_tmplt = 'inference-{}.ini'
config_file_dir = 'config_files'
# the directory we'll store samples files to
samples_file_dir = 'samples_files'
# the directory we'll store posterior files to
posterior_file_dir = 'posterior_files'
# the directory we'll store the injection files to
injection_file_dir = 'injection_files'

# make data output directory
if opts.output_dir is None:
    opts.output_dir = opts.workflow_name + '_output'
core.makedir(opts.output_dir)
core.makedir('{}/{}'.format(opts.output_dir, config_file_dir))
core.makedir('{}/{}'.format(opts.output_dir, posterior_file_dir))
core.makedir('{}/{}'.format(opts.output_dir, injection_file_dir))

# set the dax file name
dax_file = opts.dax_file
if dax_file is None:
    dax_file = "{}.dax".format(opts.workflow_name)

# log to terminal until we know the path to the log output file
log_format = "%(asctime)s:%(levelname)s : %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)

# create workflow and sub-workflows
workflow = core.Workflow(opts, opts.workflow_name)
finalize_workflow = core.Workflow(opts, opts.workflow_name + "-finalization")

# load the inference config file
inference_cp, nruns = read_inference_settings_from_config(
    workflow.cp, 'workflow-inference')
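# the [workflow-inference] section is expected to point to the inference
# configuration file(s) and, optionally, the number of independent sampler
# runs, e.g. (file name illustrative; option names follow the standard
# workflow command-line group):
#   [workflow-inference]
#   config-files = inference.ini
#   nruns = 2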

# change working directory to the output
origdir = os.path.abspath(os.curdir)
os.chdir(opts.output_dir)

# if an injection file was provided, write each injection as a separate file
# in the injections directory
if opts.injection_file:
    injection_files = core.FileList([])
    n_injections = 0
    for injection_file in opts.injection_file:
        injections = InjectionSet('{}/{}'.format(origdir, injection_file))
        local_numinj = len(injections.table)
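        # split the injection table into arguments shared by all injections
        # (the handler's static args) and per-injection parameters; each
        # injection is then written to its own single-row file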
        st_args = {k: injections.table[k][0] 
                   for k in injections._injhandler.static_args}
        for ii in range(local_numinj):
            samples = {f: injections.table[ii][f] 
                       for f in injections.table.fieldnames
                       if f not in injections._injhandler.static_args}
            outfile = '{}/injection{}.hdf'.format(injection_file_dir,
                                                  n_injections)
            injections.write(
                 outfile, FieldArray.from_kwargs(**samples),
                 write_params=[k for k in injections.table.fieldnames 
                               if k not in injections._injhandler.static_args],
                 static_args=st_args)
            injection_files.append(core.File.from_path(outfile))
            n_injections += 1
else:
    n_injections = opts.num_injections

# figure out what diagnostic jobs there are
diagnostics = inffu.get_diagnostic_plots(workflow)

# figure out if we're doing pp tests
do_pp_test = workflow.cp.has_option('executables', 'pp_table_summary')
if do_pp_test:
    # get the parameters to do the test
    pp_params = ["'"+s+"'" for s in
                 shlex.split(workflow.cp.get('workflow-pp_test', 'pp-params'))]
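    # pp-params is a space-separated list of parameter names; each is wrapped
    # in quotes so it reaches the plotting executables as a single argument
    # (for example, pp-params = mass1 mass2 distance; names illustrative)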
    # see if there is an injection map provided
    if workflow.cp.has_option('workflow-pp_test', 'injection-samples-map'):
        inj_samples_map = workflow.cp.get('workflow-pp_test',
                                          'injection-samples-map')
    else:
        inj_samples_map = None
    pp_section = 'percentile-percentile_test'
    pp_dir = [pp_section]
else:
    pp_dir = []

# sections for output HTML pages
rdir = layout.SectionNumber("results",
                            pp_dir +
                            ["detector_sensitivity", "priors", "posteriors"] +
                            diagnostics +
                            ["config_files", "workflow"])

# make results directories
core.makedir(rdir.base)
core.makedir(rdir["workflow"])
core.makedir(rdir["config_files"])

# create files for workflow log
log_file_txt = core.File(workflow.ifos, "workflow-log",
                         workflow.analysis_time,
                         extension=".txt", directory=rdir["workflow"])
log_file_html = core.File(workflow.ifos, "WORKFLOW-LOG",
                          workflow.analysis_time,
                          extension=".html", directory=rdir["workflow"])

# also save the log to a file by adding a FileHandler to the root logger
# (calling basicConfig again would be a no-op since a handler already exists)
log_file_handler = logging.FileHandler(filename=log_file_txt.storage_path,
                                       mode="w")
log_file_handler.setLevel(logging.INFO)
log_file_handler.setFormatter(logging.Formatter(log_format))
logging.getLogger("").addHandler(log_file_handler)
logging.info("Created log file %s", log_file_txt.storage_path)

# containers for the posterior files and per-injection sub-workflows built
# in the loop below
posterior_files = core.FileList([])
sub_workflows = []
seed = opts.seed
# loop over the injections to be analyzed
for num_inj in range(n_injections):
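    # zero-pad the injection number so that labels and result pages sort in
    # numerical order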
    zpad = int(numpy.ceil(numpy.log10(n_injections)))
    label = 'Injection {}'.format(str(num_inj+1).zfill(zpad))
    event = label.lower().replace(' ', '_')

    # create a sub workflow for this event
    # we need to go back to the original directory to do this for all the file
    # references to work correctly
    os.chdir(origdir)
    sub_workflow = core.Workflow(opts,
                                 "{}-{}".format(opts.workflow_name, event))
    # now go back to the output
    os.chdir(opts.output_dir)

    # if the config file has a fake-strain-seed set, increment to get
    # independent noise realizations
    if inference_cp.has_option('data', 'fake-strain-seed'):
        # get the detectors
        detectors = shlex.split(inference_cp.get('data', 'instruments'))
        fake_strain_seeds = {}
        for det in detectors:
            fake_strain_seeds[det] = seed
            seed = seed + 1
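        # set the option as space-separated detector:seed pairs, e.g.
        # "H1:<seed> L1:<seed+1>", so each detector gets an independent
        # noise realization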
        inference_cp.set('data', 'fake-strain-seed',
                         ' '.join(['{}:{}'.format(d, s)
                                   for d, s in fake_strain_seeds.items()]))

    # write the configuration file to the config files directory
    config_file = sub_workflow.save_config(config_file_tmplt.format(event),
                                           config_file_dir, inference_cp)[0]

    # create sym links to config file for results page
    base = "config_files/{}".format(event)
    layout.single_layout(rdir[base], [config_file])
    symlink_path(config_file, rdir[base])

    # add a node to create the injection if an injection file wasn't provided
    if opts.injection_file:
        injection_file = injection_files[num_inj]
    else:
        # construct Executable for creating injections
        create_injections_exe = PycbcCreateInjectionsExecutable(
            sub_workflow.cp, "create_injections", ifos=sub_workflow.ifos,
            out_dir=injection_file_dir)
        node, injection_file = create_injections_exe.create_node(
            config_file, seed=seed, tags=opts.tags+[event])
        node.add_opt("--ninjections", 1)
        sub_workflow += node
        seed = seed + 1

    # add the injection file to the inference config file
    cp = configuration.WorkflowConfigParser([config_file.storage_path])
    cp.set('data', 'injection-file', injection_file.name)
    with open(config_file.storage_path, "w") as fp:
        cp.write(fp)

    # make node(s) for running sampler
    samples_files = []
    inference_exe = PycbcInferenceExecutable(sub_workflow.cp, "inference",
                                             ifos=sub_workflow.ifos,
                                             out_dir=samples_file_dir)
    for nn in range(nruns):
        tags = opts.tags + [event]
        if nruns > 1:
            tags.append(str(nn))
        node, samples_file = inference_exe.create_node(
            config_file, seed=seed, tags=tags,
            analysis_time=sub_workflow.analysis_time)
        # declare the injection file as a needed input file
        node.add_input(injection_file)
        # add node to workflow
        sub_workflow += node
        samples_files.append(samples_file)
        # increment the seed
        seed = seed + 1

    # create the posterior file and plots
    posterior_file, summary_files, _, _ = inffu.make_posterior_workflow(
        sub_workflow, samples_files, config_file, event, rdir,
        posterior_file_dir=posterior_file_dir, tags=opts.tags)
    posterior_files.append(posterior_file)

    # create the diagnostic plots
    _ = inffu.make_diagnostic_plots(sub_workflow, diagnostics, samples_files,
                                    event, rdir, tags=opts.tags)

    # files for detector_sensitivity summary subsection
    base = "detector_sensitivity"
    psd_plot = plotting.make_spectrum_plot(
        sub_workflow, [samples_files[0]], rdir[base],
        tags=opts.tags+[event],
        hdf_group="data")

    # build the summary page, reusing the zero-padding computed for the label
    # above so the unique page keys are consistent across injections
    layout.two_column_layout(rdir.base, summary_files,
                             unique=str(num_inj).zfill(zpad),
                             title=label, collapse=True)

    # build the psd page
    layout.single_layout(rdir['detector_sensitivity'], [psd_plot],
                         unique=str(num_inj).zfill(zpad),
                         title=label, collapse=True)

    # add the sub workflow to the main workflow
    workflow += sub_workflow
    sub_workflows.append(sub_workflow)

# create the PP and recovery plot
if do_pp_test:
    # create a sub workflow to do the PP test
    # we need to go back to the original directory to do this for all the file
    # references to work correctly
    os.chdir(origdir)
    pp_workflow = core.Workflow(
        opts, "{}-{}".format(opts.workflow_name, 'pp_test'))
    # now go back to the output
    os.chdir(opts.output_dir)
    # create the pp summary table
    pp_table = inffu.make_inference_pp_table(
        pp_workflow, posterior_files, rdir[pp_section],
        parameters=pp_params, injection_samples_map=inj_samples_map,
        analysis_seg=workflow.analysis_time, tags=opts.tags)[0]
    # now the pp plots and injection recovery
    pp_plots = []
    inj_recovery_plots = []
    for pi, param in enumerate(pp_params):
        tags = opts.tags + ['PARAM_{}'.format(pi)]
        # the pp plot
        pp_plot = inffu.make_inference_pp_plot(
            pp_workflow, posterior_files, rdir[pp_section],
            parameters=param, injection_samples_map=inj_samples_map,
            analysis_seg=workflow.analysis_time, tags=tags)[0]
        pp_plots.append(pp_plot)
        # the injection recovery plot
        injrec_plot = inffu.make_inference_inj_recovery_plot(
            pp_workflow, posterior_files, rdir[pp_section], param,
            injection_samples_map=inj_samples_map,
            analysis_seg=workflow.analysis_time, tags=tags)[0]
        inj_recovery_plots.append(injrec_plot)
    # add to the results page
    pp_files = [(pp_table,)] + list(zip(pp_plots, inj_recovery_plots))
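    # the summary table gets a row to itself; each parameter's PP plot and
    # injection-recovery plot are placed side by side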
    layout.two_column_layout(rdir[pp_section], pp_files)
    # add to the main workflow
    workflow += pp_workflow

# create versioning HTML pages
results.create_versioning_page(rdir["workflow/version"], workflow.cp)

# create node for making HTML pages
plotting.make_results_web_page(finalize_workflow,
                               os.path.join(os.getcwd(), rdir.base))

# add finalize workflow to workflow and make it depend on the others
workflow += finalize_workflow
for sub_workflow in sub_workflows:
    dep = dax.Dependency(parent=sub_workflow.as_job,
                         child=finalize_workflow.as_job)
    workflow._adag.addDependency(dep)
if do_pp_test:
    dep = dax.Dependency(parent=pp_workflow.as_job,
                         child=finalize_workflow.as_job)
    workflow._adag.addDependency(dep)

# write dax
workflow.save(filename=dax_file, output_map_path=opts.output_map,
              transformation_catalog_path=opts.transformation_catalog)

# save workflow configuration file
base = rdir["workflow/configuration"]
core.makedir(base)
wf_ini = workflow.save_config("workflow.ini", base, workflow.cp)
layout.single_layout(base, wf_ini)

# close the log and flush to the html file
logging.shutdown()
with open(log_file_txt.storage_path, "r") as log_fp:
    log_data = log_fp.read()
log_str = """
<p>Workflow generation script created workflow in output directory: %s</p>
<p>Workflow name is: %s</p>
<p>Workflow generation script run on host: %s</p>
<pre>%s</pre>
""" % (os.getcwd(), opts.workflow_name, socket.gethostname(), log_data)
kwds = {"title" : "Workflow Generation Log",
        "caption" : "Log of the workflow script %s" % sys.argv[0],
        "cmd" : " ".join(sys.argv)}
results.save_fig_with_metadata(log_str, log_file_html.storage_path, **kwds)
layout.single_layout(rdir["workflow"], ([log_file_html]))
