#!/bin/env python
""" Create a file containing the phase and amplitude, 
correlations between two detectors by
running a simple Monte Carlo simulation.
"""
import argparse, h5py, numpy.random, pycbc.detector, logging, multiprocessing
from numpy.random import normal, uniform, power
from scipy.ndimage.filters import gaussian_filter

# Command-line interface.  Options that the script uses unconditionally and
# that have no sensible default are marked required, so a missing one is
# reported by argparse instead of failing later with a TypeError on None.
parser = argparse.ArgumentParser()
parser.add_argument('--ifos', nargs=2, required=True,
                    help="The two ifos to generate a histogram for")
parser.add_argument('--sample-size', type=int, required=True,
                    help="Approximate number of independent samples to draw for the distribution")
parser.add_argument('--min-snr', type=float, required=True,
                    help="SNR cutoff to apply to the drawn SNR values.")
parser.add_argument('--max-snr', type=float, default=100,
                    help="Maximum snr to draw signals out to")
parser.add_argument('--timing-error', type=float, required=True,
                    help="Error on timing measurement in seconds")
parser.add_argument('--min-detector-ratio', type=float, default=0.5,
                    help="The minimum ratio between detector sensitivities to bin")
parser.add_argument('--snr-error', type=float, required=True,
                    help="Error on the SNR recovery.")
parser.add_argument('--coinc-threshold', default=0, type=float,
                    help="seconds to add to the TOF coinc window")
parser.add_argument('--seed', type=int, default=124,
                    help="Base random seed; each work item uses a consecutive seed starting here")
parser.add_argument('--detector-ratio-granularity', type=float, default=0.1,
                    help="Granularity used to bin the detector sensitivity ratio")
parser.add_argument('--bin-density', type=int, default=2,
                    help="Density of bins to make as a multiplicity of errors in each bin parameter")
parser.add_argument('--output-file', required=True,
                    help="HDF5 file to write the smoothed histogram and bin edges to")
parser.add_argument('--cores', default=1, type=int,
                    help="Number of worker processes to use")
parser.add_argument('--verbose', action='store_true',
                    help="Passed to pycbc.init_logging to enable verbose logging")
args = parser.parse_args()

# Instantiate both detectors in the order given on the command line.  The
# largest time offset that gets binned is the light travel time between them
# plus the extra allowance from --coinc-threshold.
d1, d2 = (pycbc.detector.Detector(str(ifo)) for ifo in args.ifos)
maxdt = d1.light_travel_time_to_detector(d2) + args.coinc_threshold

# Calculate the edges of the bins

def bcount(left, right, error, bin_density=None):
    """Lay out evenly spaced histogram bins covering ``[left, right]``.

    The number of bins is chosen so that roughly ``bin_density`` bins span one
    ``error``.

    Parameters
    ----------
    left, right : float
        Lower and upper edge of the binned range; ``right`` must exceed
        ``left``.
    error : float
        Measurement error on the binned quantity, in the same units as the
        range. Must be positive.
    bin_density : int, optional
        Bins per ``error``. Defaults to the ``--bin-density`` command-line
        value.

    Returns
    -------
    num : int
        Number of bins (at least one).
    width : float
        Width of each bin.
    bins : numpy.ndarray
        The ``num + 1`` bin edges, from ``left`` to ``right`` inclusive.

    Raises
    ------
    ValueError
        If ``error`` is not positive or the range is empty.
    """
    if bin_density is None:
        bin_density = args.bin_density
    error = float(error)
    if not error > 0:
        raise ValueError("bin error must be positive, got %r" % (error,))
    if not right > left:
        raise ValueError("empty bin range [%r, %r]" % (left, right))

    # Always keep at least one bin, even when the error exceeds the range.
    num = max(int((right - left) / error * bin_density), 1)
    width = (right - left) / float(num)
    # ``num`` bins need ``num + 1`` edges; this keeps the actual edge spacing
    # equal to ``width``, which callers use to convert errors into bin units.
    bins = numpy.linspace(left, right, num=num + 1)
    return num, width, bins

# Initialise logging first so the bin-layout diagnostics below go through the
# logger instead of being printed to stdout unconditionally.
pycbc.init_logging(args.verbose)

# The SNR error used for binning and smoothing is the command-line value
# scaled by sqrt(2).
snr_error = args.snr_error * 2 ** 0.5

# Phase error is approximated for SNR ~ snr threshold
phase_error = numpy.arctan(snr_error / args.min_snr)
detector_error = args.detector_ratio_granularity * args.bin_density / 2.0

# Bin edges for each histogram axis: time delay, phase difference, SNR and
# detector sensitivity ratio.
tnum, twidth, tbins = bcount(-maxdt, maxdt, args.timing_error)
pnum, pwidth, pbins = bcount(0, 2.0 * numpy.pi, phase_error)
snum, swidth, sbins = bcount(args.min_snr, args.max_snr, snr_error)
rnum, rwidth, rbins = bcount(args.min_detector_ratio, 1, detector_error)

# Axis order matches the tuple returned by generate_samples; both SNR axes
# share the same edges.
hist_bins = (tbins, pbins, sbins, sbins, rbins)
logging.info('time delay bin edges: %s', tbins)
logging.info('phase difference bin edges: %s', pbins)
logging.info('snr bin edges: %s', sbins)
logging.info('detector ratio bin edges: %s', rbins)

def generate_hist(val):
    """Histogram Monte Carlo samples, drawing them in bounded chunks.

    Samples are generated at most 20000 at a time and each chunk is
    histogrammed immediately, so only the running histogram stays in memory.

    Parameters
    ----------
    val : tuple
        ``(size, seed)``: the number of samples to request from
        ``generate_samples`` in total, and the random seed for this work item.

    Returns
    -------
    numpy.ndarray or None
        Sum of the per-chunk ``numpy.histogramdd`` counts over ``hist_bins``,
        or None when ``size`` is not positive (no chunk is drawn).
    """
    size, seed = val
    size = int(size)
    if size <= 0:
        return None

    max_chunk = 20000
    n_chunks = -(-size // max_chunk)  # ceiling division

    # generate_samples reseeds the global RNG with the seed it is handed, so
    # reusing one seed would make every chunk an exact copy of the first.
    # The first chunk keeps the caller's seed (single-chunk runs are
    # unchanged); later chunks get seeds derived deterministically from it.
    derived = numpy.random.RandomState(seed).randint(
        0, 2 ** 31 - 1, size=n_chunks - 1)
    chunk_seeds = [seed] + [int(s) for s in derived]

    data = None
    remaining = size
    for chunk_seed in chunk_seeds:
        # The final chunk is trimmed so the total requested equals ``size``.
        chunksize = min(max_chunk, remaining)
        remaining -= chunksize
        samples = generate_samples((chunksize, chunk_seed))
        chunk_hist, _ = numpy.histogramdd(samples, bins=hist_bins)
        if data is None:
            data = chunk_hist
        else:
            data += chunk_hist
    return data

def generate_samples(val):
    """Draw one batch of simulated two-detector signals above threshold.

    Reseeds the global numpy RNG with ``seed``, draws ``size`` random sky
    positions, inclinations and polarization angles, and evaluates the
    antenna patterns and arrival-time difference of detectors ``d1`` and
    ``d2`` for each. Every sky draw is then reused 1000 times with
    independent distance and detector-ratio draws, and only the draws whose
    amplitude exceeds ``--min-snr`` in both detectors are kept.

    Parameters
    ----------
    val : tuple
        ``(size, seed)``: number of sky draws and the RNG seed.

    Returns
    -------
    tuple of numpy.ndarray
        ``(td, phase_diff, s1, s2, r)`` for the surviving draws: the
        time-delay difference ``d1 - d2``, the phase difference wrapped to
        ``[0, 2*pi)``, the amplitudes in detector 1 and detector 2, and the
        ratio factor that was applied to detector 1.
    """
    size, seed = val
    size = int(size)
    numpy.random.seed(seed)
    logging.info('generating %s samples', size)

    # Choose random sky location and polarizations
    ra = uniform(0, 2 * numpy.pi, size=size)
    dec = numpy.arccos(uniform(-1., 1., size=size)) - numpy.pi / 2
    inc = numpy.arccos(uniform(-1., 1., size=size))
    pol = uniform(0, 2 * numpy.pi, size=size)
    ip = numpy.cos(inc)
    ic = 0.5 * (1.0 + ip * ip)

    # Calculate the expected time offset, and fp,fc for both detectors
    fp1, fc1, fp2, fc2, td = [], [], [], [], []
    for sky_ra, sky_dec, sky_pol in zip(ra, dec, pol):
        plus, cross = d1.antenna_pattern(sky_ra, sky_dec, sky_pol, 0)
        fp1.append(plus)
        fc1.append(cross)
        plus, cross = d2.antenna_pattern(sky_ra, sky_dec, sky_pol, 0)
        fp2.append(plus)
        fc2.append(cross)

        delay1 = d1.time_delay_from_earth_center(sky_ra, sky_dec, 0)
        delay2 = d2.time_delay_from_earth_center(sky_ra, sky_dec, 0)
        td.append(delay1 - delay2)

    # Scale fp fc to a volumetric distribution of SNRs.  Each sky draw is
    # reused ``reuse`` times (numpy.resize repeats the array cyclically) with
    # its own distance and detector-ratio draw.
    reuse = 1000
    fsize = reuse * size
    dist = power(3, fsize) / args.min_snr

    # The ratio factor scales detector 1 only.
    ratio = uniform(args.min_detector_ratio, 1.0, size=fsize)
    sp1 = numpy.resize(numpy.asarray(fp1) * ip, fsize) / dist * ratio
    sc1 = numpy.resize(numpy.asarray(fc1) * ic, fsize) / dist * ratio
    sp2 = numpy.resize(numpy.asarray(fp2) * ip, fsize) / dist
    sc2 = numpy.resize(numpy.asarray(fc2) * ic, fsize) / dist
    td = numpy.resize(td, fsize)

    # Remove points below the SNR threshold in either detector
    threshold_sq = args.min_snr ** 2.0
    loud1 = sp1 ** 2.0 + sc1 ** 2.0 > threshold_sq
    loud2 = sp2 ** 2.0 + sc2 ** 2.0 > threshold_sq
    keep = numpy.logical_and(loud1, loud2)
    sp1 = sp1[keep]
    sp2 = sp2[keep]
    sc1 = sc1[keep]
    sc2 = sc2[keep]
    ratio = ratio[keep]
    td = td[keep]

    s1 = (sp1 ** 2.0 + sc1 ** 2.0) ** 0.5
    s2 = (sp2 ** 2.0 + sc2 ** 2.0) ** 0.5

    phase1 = numpy.arctan2(sc1, sp1)
    phase2 = numpy.arctan2(sc2, sp2)
    phase_diff = (phase1 - phase2) % (numpy.pi * 2)

    logging.info('keeping %s values', len(s1))
    return td, phase_diff, s1, s2, ratio

# Only the launching process runs the driver below.  Worker processes that
# re-import this script (multiprocessing "spawn" start method) need just the
# definitions above and must not create pools of their own.
if __name__ == '__main__':
    # Split the requested sample count evenly into work items.  ``fiddle`` is
    # the number of work items submitted per process (currently one).
    fiddle = 1
    # Integer division keeps the per-item size an int on Python 2 and 3.
    core_size = args.sample_size // args.cores // fiddle
    if core_size < 1:
        raise ValueError('--sample-size is too small to give every work '
                         'item at least one sample')
    chunk_data = [core_size] * args.cores * fiddle

    # Each work item gets its own consecutive seed.
    seeds = numpy.arange(args.seed, args.seed + len(chunk_data), 1)
    chunk_data = tuple(zip(chunk_data, seeds))
    logging.info('work items (size, seed): %s', chunk_data)

    if args.cores == 1:
        # Materialise the results: map() is lazy on Python 3 and numpy.sum
        # needs a real sequence of arrays.
        h = list(map(generate_hist, chunk_data))
    else:
        pool = multiprocessing.Pool(args.cores)
        try:
            h = pool.map(generate_hist, chunk_data)
        finally:
            pool.close()
            pool.join()

    # Combine the per-item histograms into one.
    h = numpy.sum(h, axis=0)

    # Convert the errors to units of number of histogram bins and apply using
    # a gaussian filter.  The detector-ratio axis gets a negligible sigma.
    errors = (args.timing_error / twidth,
              phase_error / pwidth,
              snr_error / swidth,
              snr_error / swidth,
              .0000001)
    smoothed = gaussian_filter(h, errors, mode='nearest', truncate=8)

    # The context manager closes the output file even if a write fails.
    with h5py.File(args.output_file, 'w') as f:
        f['map'] = smoothed.astype(numpy.float32)
        f['tbins'] = tbins
        f['pbins'] = pbins
        f['sbins'] = sbins
        f['rbins'] = rbins
        f.attrs['ifo0'] = args.ifos[0]
        f.attrs['ifo1'] = args.ifos[1]
        f.attrs['stat'] = "phasetd_newsnr"

