#!/bin/env python
""" Create a file containing the time, phase and amplitude
correlations between two or more detectors for signals, estimated
with a simple Monte Carlo simulation.

Output is the relative amplitude, time, and phase as compared to a reference
IFO. A separate calculation is done for each IFO as a possible reference.
The data is stored as two vectors : one vector gives the discrete integer bin
corresponding to a particular location in 3*(Nifo-1)-dimensional 
amplitude/time/phase space, the other gives the weight assigned to that bin.
To get the signal rate this weight should be scaled by the local sensitivity
value and by the SNR of the event in the reference detector. 
"""
import argparse, h5py, numpy.random, pycbc.detector, logging, multiprocessing
from numpy.random import normal, uniform, power
from scipy.stats import norm
from copy import deepcopy

# Command-line interface. Options without a default are marked required
# because the code below uses their values unconditionally.
parser = argparse.ArgumentParser()
parser.add_argument('--ifos', nargs='+', required=True,
                    help="The ifos to generate a histogram for")
parser.add_argument('--sample-size', type=int, required=True,
                    help="Approximate number of independent samples to draw for the distribution")
parser.add_argument('--snr-ratio', type=float, required=True,
                    help="The SNR ratio permitted between reference ifo and all others. "
                         "Ex. giving 4 permits a ratio of 0.25 -> 4")
parser.add_argument('--relative-sensitivities', nargs='+', type=float,
                    required=True,
                    help="Numbers proportional to horizon distance or expected SNR "
                         "at fixed distance, one for each ifo")
parser.add_argument('--seed', type=int, default=124)
parser.add_argument('--output-file', required=True)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--bin-density', type=int, default=1,
                    help="Number of bins per 1 sigma uncertainty in a parameter."
                         " Higher values increase the resolution of the histogram"
                         " at the expense of storage.")
parser.add_argument('--smoothing-sigma', type=int, default=2,
                    help="Width of the smoothing kernel in sigmas")
parser.add_argument('--timing-uncertainty', type=float, default=.001,
                    help="Timing uncertainty to set bin size and smoothing interval"
                         " [default=.001s]")
# NOTE(review): --phase-uncertainty is parsed but not read anywhere in the
# visible code; the phase bin width below is derived from the SNR
# uncertainty instead. Confirm whether this option is still needed.
parser.add_argument('--phase-uncertainty', type=float, default=0.25,
                    help="Phase uncertainty used to set bin size and smoothing")
parser.add_argument('--snr-reference', type=float, default=5,
                    help="Reference SNR to scale SNR uncertainty")
parser.add_argument('--snr-uncertainty', type=float, default=1.0,
                    help="SNR uncertainty to set bin size and smoothing")
parser.add_argument('--weight-threshold', type=float, default=1e-10,
                    help="Minimum histogram weight to store: "
                         "bins with less weight will not be stored.")
parser.add_argument('--batch-size', type=int, default=1000000)
args = parser.parse_args()

# Validate with parser.error rather than assert: asserts are stripped when
# Python runs with -O, and parser.error prints a usage message.
if len(args.ifos) < 2:
    parser.error('--ifos needs at least two detectors')
if len(args.relative_sensitivities) != len(args.ifos):
    parser.error('--relative-sensitivities needs exactly one value per ifo '
                 '(got %d values for %d ifos)'
                 % (len(args.relative_sensitivities), len(args.ifos)))
# Each of these is used as a divisor when deriving the bin widths below.
for opt in ('bin_density', 'snr_ratio', 'snr_reference', 'snr_uncertainty',
            'timing_uncertainty'):
    if getattr(args, opt) <= 0:
        parser.error('--%s must be positive' % opt.replace('_', '-'))

# Approximate timing error at lower SNRs
twidth = args.timing_uncertainty / args.bin_density
# Factor of sqrt(2) as we'll be combining two SNRs with independent uncertainties
serr = args.snr_uncertainty * 2 ** 0.5
# Reference SNR for bin smoothing
sref = args.snr_reference
# Width of one SNR-ratio bin
swidth = serr / sref / args.bin_density
# Approximate phase error at lower SNRs
pwidth = numpy.arctan(serr / sref) / args.bin_density

# Exclusive limits, in units of SNR-ratio bins, on the ratios that are kept
srbmax = int(args.snr_ratio / swidth)
srbmin = int((1.0 / args.snr_ratio) / swidth)

# Apply a simple smoothing to help account for measurement errors for weak signals
def smooth_param(data, index, smoothing_sigma=None, bin_density=None,
                 weight_threshold=None):
    """Smooth a sparse histogram along one axis with a truncated Gaussian.

    Parameters
    ----------
    data : dict
        Maps bin-index tuples to weights.
    index : int
        Position within each key tuple of the axis to smooth along.
    smoothing_sigma : int, optional
        Half-width of the kernel in sigmas. Defaults to
        ``args.smoothing_sigma``.
    bin_density : int, optional
        Number of bins per sigma; this is the standard deviation of the
        kernel in bins. Defaults to ``args.bin_density``.
    weight_threshold : float, optional
        Bins whose weight divided by the largest weight is below this
        value are not used as kernel centres. Defaults to
        ``args.weight_threshold``.

    Returns
    -------
    dict
        New mapping holding a smoothed weight for every bin that lies
        within the kernel half-width of an above-threshold input bin.
        The kernel is not renormalised after truncation. An empty input
        yields an empty dict.
    """
    # Nothing to smooth; also avoids max() raising on an empty sequence.
    if not data:
        return {}

    if smoothing_sigma is None:
        smoothing_sigma = args.smoothing_sigma
    if bin_density is None:
        bin_density = args.bin_density
    if weight_threshold is None:
        weight_threshold = args.weight_threshold

    mref = max(data.values())
    half_width = smoothing_sigma * bin_density
    bins = numpy.arange(-half_width, half_width + 1)
    kernel = norm.pdf(bins, scale=bin_density)

    nweights = {}

    # This is massively redundant as many points may be
    # recalculated, but it's straightforward.
    for key, value in data.items():
        if value / mref < weight_threshold:
            continue

        for a in bins:
            # Output bin: the input bin shifted by `a` along the axis
            nkey = list(key)
            nkey[index] += a
            tnkey = tuple(nkey)

            # Already filled in while visiting a neighbouring input bin
            if tnkey in nweights:
                continue

            # Kernel-weighted sum over the input bins around the output
            # bin. Below-threshold input bins still contribute here.
            weight = 0
            for b, w in zip(bins, kernel):
                wkey = list(nkey)
                wkey[index] += b
                wkey = tuple(wkey)

                if wkey in data:
                    weight += data[wkey] * w

            nweights[tnkey] = weight
    return nweights

# One detector object per requested ifo
d = {ifo: pycbc.detector.Detector(ifo) for ifo in args.ifos}

pycbc.init_logging(args.verbose)

numpy.random.seed(args.seed)
size = args.batch_size
# Samples are drawn in batches of `size`; the "+ 1" rounds up so that at
# least --sample-size samples are generated for each reference ifo.
chunks = int(args.sample_size / size) + 1

# Store results for each ifo as a reference. The reference ifo is the
# ifo which gets the smallest amplitude. This allows us to get the correct
# symmetries handled under ifo switch and apply more consistent treatment
# of error uncertainties.
# The context manager closes the HDF5 file even if an error occurs.
with h5py.File(args.output_file, 'w') as f:
    for ifo0 in args.ifos:
        other_ifos = list(args.ifos)
        other_ifos.remove(ifo0)

        nbins = 0
        nsamples = 0
        weights = {}
        for _ in range(chunks):
            nsamples += size
            logging.info('generating %s samples', size)

            # Choose random sky location and polarizations from
            # an isotropic population. arccos of a uniform variate has a
            # uniform cosine; subtracting pi/2 maps it onto [-pi/2, pi/2].
            ra = uniform(0, 2 * numpy.pi, size=size)
            dec = numpy.arccos(uniform(-1., 1., size=size)) - numpy.pi/2
            inc = numpy.arccos(uniform(-1., 1., size=size))
            pol = uniform(0, 2 * numpy.pi, size=size)
            ip = numpy.cos(inc)
            ic = 0.5 * (1.0 + ip * ip)

            # calculate the toa, poa, and amplitude of each sample
            data = {}
            for rs, ifo in zip(args.relative_sensitivities, args.ifos):
                data[ifo] = {}
                # NOTE(review): the trailing 0 is presumably the reference
                # time passed to the detector methods - confirm.
                fp, fc = d[ifo].antenna_pattern(ra, dec, pol, 0)
                sp, sc = fp * ip, fc * ic
                data[ifo]['s'] = (sp ** 2. + sc ** 2.) ** 0.5 * rs
                data[ifo]['t'] = d[ifo].time_delay_from_earth_center(ra, dec, 0)
                data[ifo]['p'] = numpy.arctan2(sc, sp)

            # Bin the data: three integer bin arrays (time, phase, SNR
            # ratio) per non-reference ifo, in that order.
            bind = []
            keep = None
            for ifo1 in other_ifos:
                dt = data[ifo0]['t'] - data[ifo1]['t']
                dp = (data[ifo0]['p'] - data[ifo1]['p']) % (2. * numpy.pi)
                sr = data[ifo1]['s'] / data[ifo0]['s']
                # The builtin int replaces the numpy.int alias, which was
                # removed from NumPy.
                dtbin = (dt / twidth).astype(int)
                dpbin = (dp / pwidth).astype(int)
                srbin = (sr / swidth).astype(int)

                # We'll only store a limited range of ratios; a sample must
                # be in range for every non-reference ifo to be kept.
                in_range = (srbin < srbmax) & (srbin > srbmin)
                keep = in_range if keep is None else keep & in_range
                bind += [dtbin, dpbin, srbin]

            # Calculate and sum the weights for each bin
            # use first ifo as reference for weights
            bind = [a[keep] for a in bind]

            w = data[ifo0]['s'][keep] ** 3.
            for key, wt in zip(zip(*bind), w):
                weights[key] = weights.get(key, 0) + wt

            # Progress: occupied bins, new bins in this batch, new bins per
            # sample in this batch, and occupied bins per sample overall.
            prev_nbins = nbins
            nbins = len(weights)
            logging.info('%s, %s, %s, %s', nbins, nbins - prev_nbins,
                         (nbins - prev_nbins) / float(size),
                         nbins / float(nsamples))

        logging.info('applying smoothing')
        # apply smoothing iteratively; each non-reference ifo owns three
        # consecutive key columns: time (+0), phase (+1), SNR ratio (+2)
        for i in range(len(other_ifos)):
            logging.info('%s-phase', len(weights))
            weights = smooth_param(weights, i * 3 + 1)
            logging.info('%s-time', len(weights))
            weights = smooth_param(weights, i * 3 + 0)
            logging.info('%s-amp', len(weights))
            weights = smooth_param(weights, i * 3 + 2)
        logging.info('smoothing done: %s', len(weights))

        # save dict to hdf5 file as key + value array. list() is required
        # because numpy cannot build an array from a Python 3 dict view.
        keys = numpy.array(list(weights.keys()))
        values = numpy.array(list(weights.values()))
        values /= values.max()

        # Threshold to keep saved bins to a reasonable amount
        above = values > args.weight_threshold
        keys = keys[above]
        values = values[above].astype(numpy.float32)

        # Bin indices are stored as int8, and the cast wraps silently.
        int8_info = numpy.iinfo(numpy.int8)
        if keys.size and (keys.min() < int8_info.min or
                          keys.max() > int8_info.max):
            logging.warning('%s: bin indices span [%s, %s], outside the int8 '
                            'range; stored indices will wrap', ifo0,
                            keys.min(), keys.max())
        keys = keys.astype(numpy.int8)
        logging.info('Final length: %s', len(keys))

        # presort the keys (by ascending weight), helps with downstream use
        order = values.argsort()
        keys = keys[order]
        values = values[order]

        f.create_dataset('%s/param_bin' % ifo0, data=keys, compression='gzip', compression_opts=7)
        f.create_dataset('%s/weights' % ifo0, data=values, compression='gzip', compression_opts=7)

    # Binning metadata, stored as attributes on the file root
    f.attrs['sensitivity_ratios'] = args.relative_sensitivities
    f.attrs['srbmin'] = srbmin
    f.attrs['srbmax'] = srbmax
    f.attrs['twidth'] = twidth
    f.attrs['pwidth'] = pwidth
    f.attrs['swidth'] = swidth
    f.attrs['ifos'] = args.ifos
    f.attrs['stat'] = 'phasetd_newsnr_%s' % ''.join(args.ifos)