#!/usr/bin/env python

# Copyright (C) 2014 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Find multi-detector gravitational wave triggers and calculate the
coherent SNRs and related statistics.
"""

import logging
import time
from collections import defaultdict
import argparse
import numpy as np
from pycbc import (
    detector, fft, init_logging, inject, opt, psd, scheme, strain, vetoes,
    waveform, DYN_RANGE_FAC
)
from pycbc.events import ranking, coherent as coh, EventManagerCoherent
from pycbc.filter import MatchedFilterControl
from pycbc.types import TimeSeries, zeros, float32, complex64
from pycbc.types import MultiDetOptionAction
from pycbc.vetoes import sgchisq
# Record the start time so the total run time can be logged at the end.
time_init = time.time()

# Command line interface. Options specific to this executable are declared
# first; the option groups shared with other PyCBC executables follow.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="print extra debugging information")
parser.add_argument("--output", type=str,
                    help="Path of the file the events are written to.")
parser.add_argument("--instruments", nargs="+", type=str, required=True,
                    help="List of instruments to analyze.")
parser.add_argument("--bank-file", type=str,
                    help="File containing the template bank to filter with.")
parser.add_argument("--snr-threshold", type=float,
                    help="SNR threshold for trigger generation")
parser.add_argument("--newsnr-threshold", type=float, metavar='THRESHOLD',
                    help="Cut triggers with NewSNR less than THRESHOLD")
parser.add_argument("--low-frequency-cutoff", type=float,
                    help="The low frequency cutoff to use for filtering "
                         "(Hz)")
# add approximant arg
waveform.bank.add_approximant_arg(parser)
parser.add_argument("--order", type=str,
                    help="The integer half-PN order at which to generate the "
                         "approximant.")
parser.add_argument("--taper-template",
                    choices=["start", "end", "startend"],
                    help="For time-domain approximants, taper the start "
                         "and/or end of the waveform before FFTing.")
parser.add_argument("--cluster-method", choices=["template", "window"],
                    default="window",
                    help="Method to use when clustering triggers. 'window' - "
                         "cluster within a fixed time window defined by the "
                         "cluster-window option (default); or 'template' - "
                         "cluster within windows defined by each template's "
                         "chirp length.")
parser.add_argument("--cluster-window", type=float, default=0,
                    help="Length of clustering window in seconds.")
parser.add_argument("--bank-veto-bank-file", type=str,
                    help="Template bank file handed to the bank veto.")
parser.add_argument("--chisq-bins", default=0,
                    help="Bin specification handed to the power chi-squared "
                         "veto.")
# Commenting out options which are not yet implemented
# parser.add_argument("--chisq-threshold", type=float, default=0)
# parser.add_argument("--chisq-delta", type=float, default=0)
parser.add_argument("--autochi-number-points", type=int, default=0,
                    help="Number of points handed to the auto chi-squared "
                         "veto.")
parser.add_argument("--autochi-stride", type=int, default=0,
                    help="Stride handed to the auto chi-squared veto.")
parser.add_argument("--autochi-onesided", action='store_true',
                    help="Request the one-sided auto chi-squared veto.")
parser.add_argument("--downsample-factor", type=int, default=1,
                    help="Factor that determines the interval between the "
                         "initial SNR sampling. If not set (or 1) no sparse "
                         "sample is created, and the standard full SNR is "
                         "calculated.")
parser.add_argument("--upsample-threshold", type=float,
                    help="The fraction of the SNR threshold to check the "
                         "sparse SNR sample.")
parser.add_argument("--upsample-method", choices=["pruned_fft"],
                    default='pruned_fft',
                    help="The method to find the SNR points between the "
                         "sparse SNR sample.")
parser.add_argument("--user-tag", type=str, metavar="TAG",
                    help="This is used to identify FULL_DATA jobs for "
                         "compatibility with pipedown post-processing. Option "
                         "will be removed when no longer needed.")
# Arguments added for the coherent stuff
parser.add_argument("--ra", type=float, help="Right ascension, in radians")
parser.add_argument("--dec", type=float, help="Declination, in radians")
parser.add_argument("--coinc-threshold", type=float, default=0.0,
                    help="Triggers with coincident/coherent snr below this "
                         "value will be discarded.")
parser.add_argument("--do-null-cut", action='store_true',
                    help="Apply a cut based on null SNR.")
parser.add_argument("--null-min", type=float, default=5.25,
                    help="Triggers with null_snr above this value will be"
                         " discarded.")
parser.add_argument("--null-grad", type=float, default=0.2,
                    help="The gradient of the line defining the null cut "
                         "after the null step.")
parser.add_argument("--null-step", type=float, default=20.,
                    help="Triggers with coherent snr above null_step will "
                         "be cut according to the null_grad and null_min.")
parser.add_argument("--trigger-time", type=int,
                    help="Time of the GRB, used to set the antenna patterns.")
parser.add_argument("--projection", default="standard",
                    choices=["standard", "left", "right", "left+right"],
                    help="Choice of projection matrix. 'left' and 'right' "
                         "correspond to face-away and face-on")
# Add options groups
strain.insert_strain_option_group_multi_ifo(parser)
strain.StrainSegments.insert_segment_option_group_multi_ifo(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
fft.insert_fft_option_group(parser)
opt.insert_optimization_option_group(parser)
inject.insert_injfilterrejector_option_group_multi_ifo(parser)
sgchisq.SingleDetSGChisq.insert_option_group(parser)
args = parser.parse_args()
init_logging(args.verbose)
# Shorthand for the GRB time, which is needed in several places below.
t_gps = args.trigger_time
# Keep the ifos in alphabetical order so that every per-ifo loop visits
# them in the same sequence.
args.instruments = sorted(args.instruments)
# Validate the parsed options: first the groups that are checked per ifo,
# then the ones that only need the parsed arguments and the parser.
for verify_multi_ifo in (
        strain.verify_strain_options_multi_ifo,
        strain.StrainSegments.verify_segment_options_multi_ifo,
        psd.verify_psd_options_multi_ifo):
    verify_multi_ifo(args, parser, args.instruments)
for verify in (scheme.verify_processing_options, fft.verify_fft_options):
    verify(args, parser)
# Build the injection filter rejectors, then the strain data and its
# segmentation for every ifo.
inj_filter_rejector = inject.InjFilterRejector.from_cli_multi_ifos(
    args, args.instruments)
strain_dict = strain.from_cli_multi_ifos(
    args,
    args.instruments,
    inj_filter_rejector,
    dyn_range_fac=DYN_RANGE_FAC,
)
strain_segments_dict = strain.StrainSegments.from_cli_multi_ifos(
    args, strain_dict, args.instruments)
# Processing scheme within which the rest of the analysis runs.
ctx = scheme.from_cli(args)
with ctx:
    fft.from_cli(args)
    # Set some often used variables for easy access
    flow = args.low_frequency_cutoff
    flow_dict = defaultdict(lambda : flow)
    # The first ifo provides the reference values; every other ifo has to
    # agree with them.
    reference_ifo = args.instruments[0]
    sample_rate = strain_dict[reference_ifo].sample_rate
    sample_rate_dict = defaultdict(lambda : sample_rate)
    flen = strain_segments_dict[reference_ifo].freq_len
    flen_dict = defaultdict(lambda : flen)
    tlen = strain_segments_dict[reference_ifo].time_len
    tlen_dict = defaultdict(lambda : tlen)
    delta_f = strain_segments_dict[reference_ifo].delta_f
    delta_f_dict = defaultdict(lambda : delta_f)
    # Explicit comparisons instead of assert statements: asserts are removed
    # when Python runs with -O, which would silently disable this check.
    for ifo in args.instruments[1:]:
        consistent = (
            sample_rate == strain_dict[ifo].sample_rate
            and flen == strain_segments_dict[ifo].freq_len
            and tlen == strain_segments_dict[ifo].time_len
            and delta_f == strain_segments_dict[ifo].delta_f
        )
        if not consistent:
            err_msg = "Sample rate, frequency length and time length "
            err_msg += "must all be consistent across ifos."
            raise ValueError(err_msg)
    logging.info("Making frequency-domain data segments")
    segments = {}
    for ifo in args.instruments:
        segments[ifo] = strain_segments_dict[ifo].fourier_segments()
    # The segmentation objects are not needed once the Fourier segments
    # have been produced.
    del strain_segments_dict
    psd.associate_psds_to_multi_ifo_segments(
        args,
        segments,
        strain_dict,
        flen,
        delta_f,
        flow,
        args.instruments,
        dyn_range_factor=DYN_RANGE_FAC,
        precision='single',
    )
    # One template buffer is shared by the matched filters of all ifos, as
    # the same matched-filter parameters are used for every ifo.
    template_mem = zeros(tlen, dtype=complex64)
    # Time delay from the Earth centre to each detector for the requested
    # sky position and time, converted to a whole number of samples.
    time_delay_idx = {
        ifo: int(round(
            detector.Detector(ifo).time_delay_from_earth_center(
                args.ra, args.dec, t_gps)
            * sample_rate))
        for ifo in args.instruments}
    # One matched filter per ifo. Clustering is switched off here because
    # the coherent search clusters at the end of the template loop.
    # FIXME: The single detector SNR threshold should not necessarily be
    #        applied to every IFO (usually only 2 most sensitive in
    #        network)
    filter_options = dict(
        use_cluster=False,
        downsample_factor=args.downsample_factor,
        upsample_threshold=args.upsample_threshold,
        upsample_method=args.upsample_method,
        cluster_function='symmetric',
    )
    matched_filter = {}
    for ifo in args.instruments:
        matched_filter[ifo] = MatchedFilterControl(
            args.low_frequency_cutoff, None, args.snr_threshold, tlen,
            delta_f, complex64, segments[ifo], template_mem,
            **filter_options)
    logging.info("Initializing signal-based vetoes.")
    # The single-detector veto classes are used directly; each is later
    # evaluated once per ifo.
    power_chisq = vetoes.SingleDetPowerChisq(args.chisq_bins)
    bank_chisq = vetoes.SingleDetBankVeto(
        args.bank_veto_bank_file,
        flen,
        delta_f,
        flow,
        complex64,
        phase_order=args.order,
        approximant=args.approximant,
    )
    autochisq = vetoes.SingleDetAutoChisq(
        args.autochi_stride,
        args.autochi_number_points,
        onesided=args.autochi_onesided,
    )
    logging.info("Overwhitening frequency-domain data segments")
    # Divide every segment by its own PSD, in place.
    for ifo_segments in segments.values():
        for seg in ifo_segments:
            seg /= seg.psd
    # Columns stored for every single-detector event, with their types.
    ifo_out_types = dict(
        time_index=int,
        ifo=int,  # IFO is stored as an int internally!
        snr=complex64,
        chisq=float32,
        chisq_dof=int,
        bank_chisq=float32,
        bank_chisq_dof=int,
        cont_chisq=float32,
    )
    # Value slots for the same columns, filled in inside the template loop.
    ifo_out_vals = dict.fromkeys(ifo_out_types)
    ifo_names = sorted(ifo_out_vals)
    # Columns stored for every network event, with their types.
    network_out_types = dict(
        latitude=float32,
        longitude=float32,
        time_index=int,
        coherent_snr=float32,
        null_snr=float32,
        nifo=int,
        my_network_chisq=float32,
        reweighted_snr=float32,
    )
    # Value slots for the network columns, filled in inside the template
    # loop.
    network_out_vals = dict.fromkeys(network_out_types)
    network_names = sorted(network_out_vals)
    # The event manager receives the column names together with the types
    # listed in the same (sorted) order.
    ifo_types = [ifo_out_types[name] for name in ifo_names]
    network_types = [network_out_types[name] for name in network_names]
    event_mgr = EventManagerCoherent(
        args, args.instruments, ifo_names, ifo_types, network_names,
        network_types)
    logging.info("Read in template bank")
    bank = waveform.FilterBank(
        args.bank_file,
        flen,
        delta_f,
        complex64,
        low_frequency_cutoff=flow,
        phase_order=args.order,
        taper=args.taper_template,
        approximant=args.approximant,
        out=template_mem,
    )
    n_bank = len(bank)
    nfilters = 0
    logging.info("Full template bank size: %d", n_bank)
    # Let the injection filter rejector of every ifo thin the bank down to
    # the templates that might actually find something.
    for ifo in args.instruments:
        bank.template_thinning(inj_filter_rejector[ifo])
    thinned_size = len(bank)
    if thinned_size != n_bank:
        n_bank = thinned_size
        logging.info("Template bank size after thinning: %d", n_bank)
    # Antenna patterns of every ifo for the requested sky position and
    # time, split into the two returned components.
    antenna = {
        ifo: detector.Detector(ifo).antenna_pattern(
            args.ra, args.dec, polarization=0, t_gps=t_gps)
        for ifo in args.instruments}
    fp = {ifo: pattern[0] for ifo, pattern in antenna.items()}
    fc = {ifo: pattern[1] for ifo, pattern in antenna.items()}
    # Quantities that are logged or stored once per network event (segment
    # count, cumulative sample index) are read from a single detector. The
    # last instrument is used, i.e. the one the per-ifo loops finish on.
    last_ifo = args.instruments[-1]
    # Number of ifos
    nifo = len(args.instruments)
    # Loop over templates
    for t_num, template in enumerate(bank):
        # Loop over segments. The segment index is shared by all ifos.
        for s_num, _ in enumerate(segments[args.instruments[0]]):
            stilde = {ifo: segments[ifo][s_num] for ifo in args.instruments}
            # Filter check checks the 'inj_filter_rejector' options to
            # determine whether to filter this template/segment
            # if injections are present.
            analyse_segment = True
            for ifo in args.instruments:
                if not inj_filter_rejector[ifo].template_segment_checker(
                        bank, t_num, stilde[ifo]):
                    logging.info(
                        "Skipping segment %d/%d with template %d/%d as no "
                        "detectable injection is present",
                        s_num + 1, len(segments[ifo]), t_num + 1, n_bank)
                    analyse_segment = False
            # Find detector sensitivities (sigma) for this template against
            # the PSD of the current segment in each ifo.
            sigmasq = {
                ifo: template.sigmasq(stilde[ifo].psd)
                for ifo in args.instruments}
            sigma = {ifo: np.sqrt(sigmasq[ifo]) for ifo in args.instruments}
            # Register the template with the event manager once, on its
            # first segment, even when that segment is skipped below.
            if s_num == 0:
                event_mgr.new_template(tmplt=template.params, sigmasq=sigmasq)
            if not analyse_segment:
                continue
            logging.info(
                "Analyzing segment %d/%d", s_num + 1, len(segments[last_ifo]))
            snr_dict = dict.fromkeys(args.instruments)
            norm_dict = dict.fromkeys(args.instruments)
            corr_dict = dict.fromkeys(args.instruments)
            idx = dict.fromkeys(args.instruments)
            snr = dict.fromkeys(args.instruments)
            for ifo in args.instruments:
                logging.info(
                    "Filtering template %d/%d, ifo %s", t_num + 1, n_bank, ifo)
                # No clustering in the coherent search until the end.
                # The correlation vector is the FFT of the snr (so inverse
                # FFT it to get the snr).
                snr_ts, norm, corr, ind, snrv = \
                    matched_filter[ifo].matched_filter_and_cluster(
                        s_num, sigmasq[ifo], window=0)
                # Normalised SNR time series restricted to the analysed
                # part of the segment.
                snr_dict[ifo] = (
                    snr_ts[matched_filter[ifo].segments[s_num].analyze]
                    * norm)
                norm_dict[ifo] = norm
                # Copies are stored for the correlation vector and the
                # trigger indexes.
                corr_dict[ifo] = corr.copy()
                idx[ifo] = ind.copy()
                snr[ifo] = snrv * norm
            # Save the indexes of triggers (if we have any)
            # Even if we have none, need to keep an empty dictionary.
            # Only keep an index if it doesn't get time shifted out of the
            # time we are looking at.
            idx_dict = {
                ifo: idx[ifo][np.logical_and(
                    idx[ifo] > time_delay_idx[ifo],
                    idx[ifo] - time_delay_idx[ifo] < len(snr_dict[ifo])
                    )
                ]
                for ifo in args.instruments}
            # Find triggers that are coincident (in geocent time) in
            # multiple ifos. If a single ifo analysis then just use the
            # indexes from that ifo.
            if nifo > 1:
                coinc_idx = coh.get_coinc_indexes(idx_dict, time_delay_idx)
            else:
                coinc_idx = (
                    idx_dict[args.instruments[0]]
                    - time_delay_idx[args.instruments[0]]
                    )
            logging.info("Found %d coincident triggers", len(coinc_idx))
            for ifo in args.instruments:
                # An empty SNR time series means there was no data
                if len(snr_dict[ifo]) == 0:
                    raise RuntimeError(
                        'The SNR triggers dictionary is empty. This '
                        'should not be possible.')
            # Time delay is applied to indices
            coinc_idx_det_frame = {
                ifo: coinc_idx + time_delay_idx[ifo]
                for ifo in args.instruments}
            # Calculate the coincident and coherent snr
            # Check we have data before we try to compute the coherent snr
            if len(coinc_idx) != 0 and nifo > 1:
                # Find coinc snr at trigger times and apply coinc snr
                # threshold
                rho_coinc, coinc_idx, coinc_triggers = coh.coincident_snr(
                    snr_dict, coinc_idx, args.coinc_threshold, time_delay_idx)
                logging.info(
                    "%d points above coincident SNR threshold", len(coinc_idx))
                if len(coinc_idx) != 0:
                    logging.info("Max coincident SNR = %.2f", max(rho_coinc))
            # If there is only one ifo, then coinc_triggers is just the
            # triggers from ifo
            elif len(coinc_idx) != 0 and nifo == 1:
                coinc_triggers = {
                    args.instruments[0]: snr[args.instruments[0]][
                        coinc_idx_det_frame[args.instruments[0]]
                        ]
                    }
            else:
                coinc_triggers = {}
                logging.info("No triggers above coincident SNR threshold")
            # If we have triggers above coinc threshold and more than 2
            # ifos then calculate the coherent statistics
            if len(coinc_idx) != 0 and nifo > 2:
                if args.projection == 'left+right':
                    # Evaluate both projections and keep, point by point,
                    # the results of whichever has the larger coherent SNR.
                    project_l = coh.get_projection_matrix(
                        fp, fc, sigma, projection='left')
                    rho_coh_l, coinc_idx_l, coinc_triggers_l, rho_coinc_l = \
                        coh.coherent_snr(
                            coinc_triggers, coinc_idx, args.coinc_threshold,
                            project_l, rho_coinc)
                    project_r = coh.get_projection_matrix(
                        fp, fc, sigma, projection='right')
                    rho_coh_r, coinc_idx_r, coinc_triggers_r, rho_coinc_r = \
                        coh.coherent_snr(
                            coinc_triggers, coinc_idx, args.coinc_threshold,
                            project_r, rho_coinc)
                    max_idx = np.argmax([rho_coh_l, rho_coh_r], axis=0)
                    rho_coh = np.where(max_idx == 0, rho_coh_l, rho_coh_r)
                    coinc_idx = np.where(
                        max_idx == 0, coinc_idx_l, coinc_idx_r)
                    coinc_triggers = {
                        ifo: np.where(
                            max_idx == 0, coinc_triggers_l[ifo],
                            coinc_triggers_r[ifo])
                        for ifo in coinc_triggers_l}
                    rho_coinc = np.where(
                        max_idx == 0, rho_coinc_l, rho_coinc_r)
                else:
                    project = coh.get_projection_matrix(
                        fp, fc, sigma, projection=args.projection)
                    rho_coh, coinc_idx, coinc_triggers, rho_coinc = \
                        coh.coherent_snr(
                            coinc_triggers, coinc_idx, args.coinc_threshold,
                            project, rho_coinc)
                logging.info(
                    "%d points above coherent threshold", len(rho_coh))
                if len(coinc_idx) != 0:
                    logging.info("Max coherent SNR = %.2f", max(rho_coh))
                    # Find the null snr
                    null, rho_coh, rho_coinc, coinc_idx, coinc_triggers = \
                        coh.null_snr(
                            rho_coh, rho_coinc, snrv=coinc_triggers,
                            index=coinc_idx, apply_cut=args.do_null_cut)
                    if len(coinc_idx) != 0:
                        logging.info("Max null SNR = %.2f", max(null))
                    logging.info("%d points above null threshold", len(null))
            # We are now going to find the individual detector chi2 values.
            # To do this it is useful to find the indexes of coinc triggers
            # in the detector frame.
            if len(coinc_idx) != 0:
                # coinc_idx_det_frame is redefined to account for the cuts
                # to coinc_idx above
                coinc_idx_det_frame = {
                    ifo: coinc_idx + time_delay_idx[ifo]
                    for ifo in args.instruments}
                coherent_ifo_trigs = {
                    ifo: snr_dict[ifo][coinc_idx_det_frame[ifo]]
                    for ifo in args.instruments}
                # Calculate the power and autochi2 values for the coinc
                # indexes. This uses the snr timeseries before the time
                # delay, so we need to undo it. Same for normalisation.
                chisq = {}
                chisq_dof = {}
                for ifo in args.instruments:
                    chisq[ifo], chisq_dof[ifo] = power_chisq.values(
                        corr_dict[ifo],
                        coherent_ifo_trigs[ifo] / norm_dict[ifo],
                        norm_dict[ifo], stilde[ifo].psd,
                        coinc_idx_det_frame[ifo] + stilde[ifo].analyze.start,
                        template)
                # Calculate network chisq value
                network_chisq_dict = coh.network_chisq(
                    chisq, chisq_dof, coherent_ifo_trigs)
                # Calculate chisq reweighted SNR
                if nifo > 2:
                    reweighted_snr = ranking.newsnr(
                        rho_coh, network_chisq_dict)
                    # Calculate null reweighted SNR
                    reweighted_snr = coh.reweight_snr_by_null(
                        reweighted_snr, null, rho_coh)
                elif nifo == 2:
                    reweighted_snr = ranking.newsnr(
                        rho_coinc, network_chisq_dict)
                else:
                    rho_sngl = abs(
                        snr[args.instruments[0]][
                            coinc_idx_det_frame[args.instruments[0]]
                            ]
                        )
                    reweighted_snr = ranking.newsnr(
                        rho_sngl, network_chisq_dict)
                # Cut based on reweighted snr and the given
                # newsnr_threshold parameter
                reweighted_snr = coh.reweightedsnr_cut(
                    reweighted_snr, args.newsnr_threshold)
                # Need all out vals to be the same length. This means the
                # entries that are single values need to be repeated once
                # per event.
                num_events = len(reweighted_snr)
                # the output will only be possible if
                # len(networkchi2) == num_events
                for ifo in args.instruments:
                    (ifo_out_vals['bank_chisq'],
                     ifo_out_vals['bank_chisq_dof']) = bank_chisq.values(
                        template, stilde[ifo].psd, stilde[ifo],
                        coherent_ifo_trigs[ifo] / norm_dict[ifo],
                        norm_dict[ifo],
                        coinc_idx_det_frame[ifo] + stilde[ifo].analyze.start)
                    ifo_out_vals['cont_chisq'] = autochisq.values(
                        snr[ifo] / norm_dict[ifo], coinc_idx_det_frame[ifo],
                        template, stilde[ifo].psd, norm_dict[ifo],
                        stilde=stilde[ifo], low_frequency_cutoff=flow)
                    ifo_out_vals['chisq'] = chisq[ifo]
                    ifo_out_vals['chisq_dof'] = chisq_dof[ifo]
                    ifo_out_vals['time_index'] = (
                        coinc_idx_det_frame[ifo] + stilde[ifo].cumulative_index
                        )
                    ifo_out_vals['snr'] = coherent_ifo_trigs[ifo]
                    # IFO is stored as an int
                    ifo_out_vals['ifo'] = (
                        [event_mgr.ifo_dict[ifo]] * num_events)
                    event_mgr.add_template_events_to_ifo(
                        ifo, ifo_names, [ifo_out_vals[n] for n in ifo_names])
                if nifo > 2:
                    network_out_vals['coherent_snr'] = np.real(rho_coh)
                    network_out_vals['null_snr'] = np.real(null)
                elif nifo == 2:
                    network_out_vals['coherent_snr'] = np.real(rho_coinc)
                else:
                    network_out_vals['coherent_snr'] = (
                        abs(snr[args.instruments[0]][
                            coinc_idx_det_frame[args.instruments[0]]
                            ])
                        )
                network_out_vals['reweighted_snr'] = reweighted_snr
                network_out_vals['my_network_chisq'] = (
                    np.real(network_chisq_dict))
                network_out_vals['time_index'] = (
                    coinc_idx + stilde[last_ifo].cumulative_index)
                network_out_vals['nifo'] = [nifo] * num_events
                # The sky position has to be stored under the declared
                # network columns ('latitude'/'longitude'); values placed
                # under any other key are never handed to the event manager
                # because only the names in network_names are forwarded.
                # NOTE(review): maps dec -> latitude and ra -> longitude;
                # confirm this is what readers of the output expect.
                network_out_vals['latitude'] = [args.dec] * num_events
                network_out_vals['longitude'] = [args.ra] * num_events
                event_mgr.add_template_events_to_network(
                    network_names,
                    [network_out_vals[n] for n in network_names])
        # Now cluster the network events of this template, using a window
        # expressed in samples.
        if args.cluster_method == "window":
            cluster_window = int(args.cluster_window * sample_rate)
        elif args.cluster_method == "template":
            cluster_window = int(template.chirp_length * sample_rate)
        event_mgr.cluster_template_network_events(
            "time_index", "coherent_snr", cluster_window)
        event_mgr.finalize_template_events()
# Write all events to the requested output file and report the run time.
event_mgr.write_events(args.output)
logging.info("Finished")
elapsed = time.time() - time_init
logging.info("Time to complete analysis: %.0f s", elapsed)
