#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Combine triggers from a splitbank GRB run
"""

from __future__ import (division, print_function)

import argparse
import os
import re
from collections import defaultdict
from math import (pi as PI, radians)

import numpy

import tqdm

import h5py

from astropy.constants import M_sun

from pycbc import __version__

MSUN_SI = M_sun.si.value

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

TQDM_BAR_FORMAT = ("{desc}: |{bar}| "
                   "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
                   "[{elapsed} | ETA {remaining}]")
TQDM_KW = {
    "ascii": " -=#",
    "bar_format": TQDM_BAR_FORMAT,
    "smoothing": 0.05,
}

INJSTRING_REGEX = re.compile(r"_[A-Z]*(\d+)(INJ)_(FOUND|MISSED)")

# -- utilities ----------------------------------


def use_file(filename, inclination):
    """Returns `True` if this file can be used to analyse the given inclination

    The filename is searched with ``INJSTRING_REGEX``; the file is usable
    when the pattern matches and the number captured by its first group
    is greater than or equal to ``inclination``.
    """
    found = INJSTRING_REGEX.search(filename)
    if not found:
        # filename does not carry the expected injection tag
        return False
    return bool(float(found.group(1)) >= inclination)


def theta(h5injgroup):
    """Calculate the theta angle column for an injection group

    The orbital angular momentum is placed in the x-z plane at the
    injection ``inclination`` from the z-axis, the mass-squared-weighted
    spins of the two components are added to it, and the polar angle of
    the resulting vector from the z-axis is returned, folded into the
    range ``[0, pi/2]``.

    Parameters
    ----------
    h5injgroup : `h5py.Group` or similar mapping
        container holding the ``f_lower``, ``inclination``, ``mass1``,
        ``mass2``, ``spin1x``, ``spin1y``, ``spin1z``, ``spin2x``,
        ``spin2y`` and ``spin2z`` columns, each readable via ``[()]``

    Returns
    -------
    theta : `numpy.ndarray`
        one angle, in radians, per injection

    Raises
    ------
    ValueError
        if any computed angle is negative
    """
    def _column(name):
        # read the full contents of one column into memory
        return h5injgroup[name][()]

    f_lower = _column("f_lower")
    inclination = _column("inclination")
    mass1 = _column("mass1")
    mass2 = _column("mass2")
    spin1x = _column("spin1x")
    spin1y = _column("spin1y")
    spin1z = _column("spin1z")
    spin2x = _column("spin2x")
    spin2y = _column("spin2y")
    spin2z = _column("spin2z")
    size = f_lower.size

    mtotal = mass1 + mass2

    # conversion factor for the angular momentum
    # NOTE(review): ``**`` binds tighter than ``/``, so this evaluates to
    # mass1 * mass2 * (mtotal * PI * MSUN_SI * f_lower) ** (1/3.);
    # confirm that this exponent sign, and the units of MSUN_SI, are
    # what is intended here
    angmomfac = (
        mass1 * mass2 / (mtotal * PI * MSUN_SI * f_lower) ** (-1/3.)
    )
    m1sq = mass1 ** 2
    m2sq = mass2 ** 2

    # compute the orbital angular momentum (no y-component)
    L = numpy.zeros((3, size))
    L[0] = angmomfac * numpy.sin(inclination)
    L[2] = angmomfac * numpy.cos(inclination)

    # compute the spins
    S = numpy.zeros((3, size))
    S[0] = m1sq * spin1x + m2sq * spin2x
    S[1] = m1sq * spin1y + m2sq * spin2y
    S[2] = m1sq * spin1z + m2sq * spin2z

    # and finally the total angular momentum
    J = L + S

    # polar angle of J from the z-axis
    angle = numpy.arctan2(numpy.sqrt(J[0]*J[0] + J[1]*J[1]), J[2])

    # fold angles beyond pi/2 back into [0, pi/2]
    outofbounds = angle > PI/2.
    angle[outofbounds] = PI - angle[outofbounds]

    # vectorised check (the builtin any() would loop in Python)
    if numpy.any(angle < 0):
        raise ValueError("negative theta error")

    return angle


def merge_injection_files(inputfiles, outputfile, verbose=False, inclination=0,
                          **compression_kw):
    """Merge several HDF5 injections files into a single file,
    sieving by inclination.

    Every input file is scanned once to record the datasets it contains
    and, for its top-level ``found`` and ``missed`` groups, which
    injections have a `theta` angle strictly below ``inclination``.
    The selected rows of every dataset are then written, in input-file
    order, into a single dataset of the same name in ``outputfile``.

    Parameters
    ----------
    inputfiles : `list` of `str`
        the paths of the input HDF5 files to merge

    outputfile : `str`
        the path of the output HDF5 file to write

    verbose : `bool`, optional
        if `True`, display progress bars and print a summary

    inclination : `float`, optional
        the angle, in degrees, below which injections are kept

    **compression_kw
        accepted but not used by this function; output datasets are
        created with the ``compression``, ``compression_opts`` and
        ``dtype`` recorded from the input datasets

    Returns
    -------
    outputfile : `str`
        the path that was written (the ``outputfile`` argument)

    Notes
    -----
    The root attributes and the per-dataset creation parameters are
    overwritten on each scanned file, so those of the last file that
    provides them are the ones used for the output.
    """
    incrad = radians(inclination)
    attributes = {}
    datasets = {}
    keep = {}
    total = defaultdict(int)

    nfiles = len(inputfiles)

    def _record_dataset(name, obj):
        """Record the creation parameters for a given dataset
        """
        if isinstance(obj, h5py.Dataset):
            datasets[name] = {attr: getattr(obj, attr) for attr in (
                "compression",
                "compression_opts",
                "dtype",
            )}

    # loop once over all files, recording information
    for filename in tqdm.tqdm(inputfiles, desc="Scanning trigger files",
                              disable=not verbose, total=nfiles, unit="files",
                              **TQDM_KW):
        with h5py.File(filename, 'r') as h5f:
            attributes = dict(h5f.attrs)
            h5f.visititems(_record_dataset)

            # for the top-level injection groups, determine which
            # injections to keep; the masks are keyed on the path exactly
            # as given by the caller, which is what the merge loop below
            # uses to look them up
            for name in ("found", "missed"):
                if name not in h5f:
                    continue
                group = h5f[name]
                if isinstance(group, h5py.Group):
                    keep.setdefault(filename, {})[name] = (
                        theta(group) < incrad
                    )
                    total[name] += group["mass1"].shape[0]

    # print summary of what we found
    nfound = sum(keep[fname]["found"].sum() for fname in keep)
    nmissed = sum(keep[fname]["missed"].sum() for fname in keep)
    if verbose:
        print("Found {}/{} found, {}/{} missed with inclination < {}°".format(
            nfound, total["found"], nmissed, total["missed"], inclination,
        ))

    # next write offset for each output dataset
    position = defaultdict(int)
    eventcount = 0

    with h5py.File(outputfile, 'w') as h5out:
        h5out.attrs.update(attributes)

        # create datasets, sized for the number of injections kept
        for dset, params in datasets.items():
            size = nmissed if dset.startswith("missed") else nfound
            h5out.create_dataset(dset, shape=(size,), **params)

        # copy dataset contents
        for filename in tqdm.tqdm(inputfiles, desc="Merging trigger files",
                                  disable=not verbose, total=nfiles,
                                  unit="files", **TQDM_KW):
            masks = keep[filename]
            with h5py.File(filename, 'r') as h5in:
                for dset in datasets:
                    # anything not under "missed" is filtered using the
                    # "found" selection
                    if dset.startswith("missed"):
                        mask = masks["missed"]
                    else:
                        mask = masks["found"]
                    source = h5in[dset]
                    data = source[:][mask]
                    size = data.shape[0]
                    pos = position[dset]
                    if dset == "network/network_event_id":
                        # increment event_id by the _total_ number of events
                        # in the file (not just the ones we are using)
                        data += eventcount
                        eventcount += source.shape[0]
                    h5out[dset][pos:pos+size] = data
                    position[dset] += size

    if verbose:
        print("Merged triggers written to {}".format(outputfile))
    return outputfile


# -- parse command line -------------------------

parser = argparse.ArgumentParser(
    description=__doc__,
)

parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    default=False,
    help="print verbose output (default: %(default)s)",
)
parser.add_argument(
    "-V",
    "--version",
    action="version",
    version=__version__,
    help="show version number and exit",
)

# parameters
parser.add_argument(
    "-I",
    "--max-inclination",
    type=float,
    default=0,
    help="create an injection set with injections "
         "uniformly distributed out to this inc !!in degrees!!",
)

# input/output
parser.add_argument(
    "-f",
    "--input-files",
    nargs="*",
    required=True,
    metavar="INJECTION_TRIGGER_FILE",
    help="read in listed trigger files",
)
# this option is required, so no default is declared: it could never be
# used, and the value is passed on as the path of the file to write
parser.add_argument(
    "-o",
    "--output-file",
    required=True,
    help="path of the output file to write",
)

args = parser.parse_args()

# determine which files to keep: only those accepted by use_file() for
# the requested maximum inclination
usefiles = [
    path for path in args.input_files
    if use_file(path, args.max_inclination)
]

# merge injection files, filtering on-the-fly by inclination
merge_injection_files(
    usefiles,
    args.output_file,
    inclination=args.max_inclination,
    verbose=args.verbose,
)
