#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Cluster triggers from a pyGRB run
"""

from __future__ import print_function

import argparse
import os
import shutil
import sys

import tqdm

import numpy

import h5py

from gwdatafind.utils import filename_metadata

from pycbc import __version__

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

# Layout of a progress bar line: description, the bar itself, a counter
# with unit and percentage, then elapsed/remaining time and any postfix
TQDM_BAR_FORMAT = "".join((
    "{desc}: |{bar}| ",
    "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) ",
    "[{elapsed} | ETA {remaining}]{postfix}",
))

# Keyword arguments passed to every tqdm progress bar in this script
TQDM_KW = dict(
    ascii=" -=#",
    bar_format=TQDM_BAR_FORMAT,
    smoothing=0.05,
)


# -- utilities ----------------------------------

def slice_hdf5(inputfile, outfile, include, verbose=False):
    """Create a new HDF5 file containing a slice of the network events

    Every dataset in the ``network`` group is indexed with ``include``.
    Every dataset in each of the other top-level groups (treated here as
    one group per interferometer) is reduced to the rows whose
    ``event_id`` is referenced by one of the selected network events,
    via the ``network/<ifo>_event_id`` dataset.

    Parameters
    ----------
    inputfile
        the HDF5 file to read, passed to `h5py.File` in ``"r"`` mode

    outfile
        the HDF5 file to create, passed to `h5py.File` in ``"w"`` mode

    include : array-like
        index array (integer or boolean) selecting the network events
        to keep

    verbose : `bool`, optional
        if `True`, display a progress bar and print the name of the
        output file when done
    """
    # ``asarray`` only copies when it has to; ``array(..., copy=False)``
    # raises under numpy>=2 whenever a copy cannot be avoided (e.g. when
    # ``include`` is given as a list)
    include = numpy.asarray(include)
    if include.dtype == numpy.bool_:
        nevents = include.sum()
    else:
        nevents = include.size
        if not nevents:
            # an empty sequence is converted to float64, which numpy
            # refuses to use as an index, so force an integer type
            include = include.astype(numpy.intp)

    with h5py.File(inputfile, "r") as h5in:
        ifos = [k for k in h5in.keys() if k != "network"]

        # find which single-ifo events to keep
        ifo_index = {
            ifo: numpy.unique(
                h5in["network/{}_event_id".format(ifo)][:][include],
            ) for ifo in ifos
        }

        # total number of datasets to copy, used to size the progress bar
        nsets = sum(isinstance(item, h5py.Dataset) for
                    group in h5in.values() for
                    item in group.values())
        msg = "Slicing {} network events into new file".format(nevents)
        bar = tqdm.tqdm(total=nsets, desc=msg, disable=not verbose,
                        unit="datasets", **TQDM_KW)
        try:
            with h5py.File(outfile, "w") as h5out:
                # network datasets are sliced directly with ``include``
                for old in h5in["network"].values():
                    if isinstance(old, h5py.Dataset):
                        h5out.create_dataset(
                            old.name,
                            data=old[:][include],
                            compression=old.compression,
                            compression_opts=old.compression_opts,
                        )
                        bar.update()
                # single-ifo datasets keep only the referenced event ids
                for ifo in ifos:
                    idx = numpy.isin(
                        h5in[ifo]["event_id"][()],
                        ifo_index[ifo],
                    )
                    for old in h5in[ifo].values():
                        if isinstance(old, h5py.Dataset):
                            h5out.create_dataset(
                                old.name,
                                data=old[:][idx],
                                compression=old.compression,
                                compression_opts=old.compression_opts,
                            )
                            bar.update()
        finally:
            # always release the progress bar, even if writing failed
            bar.close()
        if verbose:
            print("Slice written to {}".format(outfile))


# -- parse command line -------------------------

parser = argparse.ArgumentParser(description=__doc__)

# -- general options
parser.add_argument(
    "-v", "--verbose",
    action="store_true", default=False,
    help="verbose output with microsecond timer (default: %(default)s)",
)
parser.add_argument(
    "-V", "--version",
    action="version", version=__version__,
    help="show version number and exit",
)

# -- clustering options
parser.add_argument(
    "-W", "--time-window",
    type=float, default=0.1,
    help="the cluster time window (default %(default)s)",
)
parser.add_argument(
    "-c", "--rank-column",
    default="coherent_snr",
    help="column over which to rank events (default: %(default)s)",
)

# -- input/output options
parser.add_argument(
    "-t", "--trig-file",
    required=True,
    help="path to input trigger file",
)
parser.add_argument(
    "-o", "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)

args = parser.parse_args()

# the cluster window is used throughout the rest of the script
win = args.time_window

# name the output after the metadata parsed from the input file name,
# tagging it as clustered and recording the start and duration
ifotag, filetag, segment = filename_metadata(args.trig_file)
start, end = segment
outname = "{}-{}_CLUSTERED-{}-{}.h5".format(
    ifotag, filetag, start, end - start)
outfile = os.path.join(args.output_dir, outname)

# -- generate clustering bins -------------------

# one bin per whole window between ``start`` and ``end``, plus one more
nbins = int((end - start) // win + 1)

# indices of the triggers that fall in each bin
bins = [list() for _ in range(nbins)]

# rank value and time of the loudest trigger in each bin (zero if none)
loudsnr = numpy.zeros(nbins)
loudtime = numpy.zeros_like(loudsnr)

# indices of the triggers that survive clustering
clusters = []

# -- cluster ------------------------------------

# read the time and ranking value of every network trigger into memory
with h5py.File(args.trig_file, "r") as h5f:
    network = h5f["network"]
    time = network["end_time_gc"][()]
    snr = network[args.rank_column][()]

# nothing to cluster: the output is simply a copy of the input
if time.size == 0:
    shutil.copyfile(args.trig_file, outfile)
    if args.verbose:
        print("trigger file is empty")
        print("copied input file to {}".format(outfile))
    sys.exit(0)

# assign every trigger to a bin, and remember the rank value and time of
# the loudest trigger seen in each bin
progress = tqdm.tqdm(range(time.size), desc="Initialising bins",
                     disable=not args.verbose, total=time.size,
                     unit='triggers', **TQDM_KW)
for i in progress:
    t = time[i]
    s = snr[i]
    # bin number is the count of whole windows between ``start`` and ``t``
    idx = int(float(t - start) // win)
    bins[idx].append(i)
    if s > loudsnr[idx]:
        loudsnr[idx], loudtime[idx] = s, t

add_cluster = clusters.append
nclusters = 0

# cluster
#
# A trigger is kept only if it is the loudest in its own bin and no
# louder trigger lies within ``win`` of it in either neighbouring bin.
# Each bin is ``win`` wide, so only the adjacent bins need checking.
bar = tqdm.tqdm(bins, desc="Clustering bins",
                disable=not args.verbose, total=nbins, unit='bins',
                postfix=dict(nclusters=0), **TQDM_KW)
for i, bin_ in enumerate(bar):
    if not bin_:  # empty
        continue

    # Derive the neighbours from the bin index itself. Tracking them with
    # counters that are only advanced for non-empty bins lets them drift
    # out of step as soon as an empty bin is skipped, and flags the last
    # bin too late to stop a read past the end of ``loudtime``.
    first = i == 0
    last = i == nbins - 1
    prev = i - 1
    nxt_ = i + 1

    for idx in bin_:
        t, s = time[idx], snr[idx]

        if s < loudsnr[i]:  # not loudest in own bin
            continue

        # check loudest event in previous bin
        # (a ``loudtime`` of zero means that bin holds no loudest event)
        if not first:
            prevt = loudtime[prev]
            if prevt and abs(prevt - t) < win and s < loudsnr[prev]:
                continue

        # check loudest event in next bin
        if not last:
            nextt = loudtime[nxt_]
            if nextt and abs(nextt - t) < win and s < loudsnr[nxt_]:
                continue

        loudest = True

        # check all events in previous bin
        if not first and prevt and abs(prevt - t) < win:
            for id2 in bins[prev]:
                if abs(time[id2] - t) < win and s < snr[id2]:
                    loudest = False
                    break

        # check all events in next bin
        if loudest and not last and nextt and abs(nextt - t) < win:
            for id2 in bins[nxt_]:
                if abs(time[id2] - t) < win and s < snr[id2]:
                    loudest = False
                    break

        # this is loudest in its vicinity, keep it
        if loudest:
            add_cluster(idx)
            nclusters += 1
            bar.set_postfix(nclusters=nclusters)

    # no manual ``bar.update()`` here: iterating over the tqdm object
    # already advances the bar once per bin

# write the surviving triggers to the output file; the explicit integer
# dtype means an empty cluster list still gives a valid index array
# (``numpy.asarray([])`` would be float64, which cannot be used to index)
slice_hdf5(
    args.trig_file,
    outfile,
    numpy.asarray(clusters, dtype=numpy.intp),
    verbose=args.verbose,
)
