#!/usr/bin/env python
""" Plot variation in PSD
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import h5py, numpy, argparse, sys, math
import pycbc.results, pycbc.types, pycbc.version, pycbc.waveform, pycbc.filter

from pycbc.fft.fftw import set_measure_level
set_measure_level(0)

# Command-line interface.  --min_mtot/--max_mtot/--d_mtot/--approximant are
# parallel lists: the i-th entry of each one describes the i-th curve.
# All options the script dereferences unconditionally are marked required so
# that omitting one yields a usage message instead of a TypeError on None.
# (--version and --help still exit before the required check runs.)
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument("--psd-files", nargs='+', required=True, help='HDF file of psds')
parser.add_argument("--output-file", required=True, help='output file name')
parser.add_argument("--min_mtot", nargs="+", required=True, help="Minimum total mass for range", type=float)
parser.add_argument("--max_mtot", nargs="+", required=True, help="Maximum total mass for range", type=float)
parser.add_argument("--d_mtot", nargs="+", required=True, help="Delta total mass for range ", type=float)
parser.add_argument("--approximant", nargs="+", required=True, help="approximant to use for range")
args = parser.parse_args()

# The four per-curve lists are zipped together further down; zip() silently
# drops unmatched trailing entries, so reject mismatched lengths up front.
if len({len(args.min_mtot), len(args.max_mtot),
        len(args.d_mtot), len(args.approximant)}) != 1:
    parser.error("--min_mtot, --max_mtot, --d_mtot and --approximant "
                 "must all be given the same number of values")

# SNR by which sigma is divided to obtain the horizon distance.
canonical_snr = 8.0

# One shared figure: every PSD file adds its curves to the same axes.
fig = plt.figure(0)
# Raw string: "\o" is not a recognised escape sequence, so the non-raw form
# triggers an invalid-escape warning on recent Python versions.  The
# resulting string value is unchanged.
plt.xlabel(r'Total Mass (M$_{\odot}$)')
plt.ylabel('Inspiral Range (Mpc)')
plt.grid()

# For every PSD file: compute the inspiral range of equal-mass binaries over
# each requested total-mass grid, once per stored PSD, then plot the
# segment-duration-weighted mean and standard deviation of that range.
for psd_file in args.psd_files:
    # Pull everything needed out of the file while it is open.  In
    # particular the number of PSDs must be read here: an h5py keys view
    # cannot be queried once the file has been closed.
    with h5py.File(psd_file, 'r') as f:
        ifo = tuple(f.keys())[0]
        flow = f.attrs['low_frequency_cutoff']
        num_psds = len(f[ifo + '/psds'])
        start, end = f[ifo + '/start_time'][:], f[ifo + '/end_time'][:]

    # Duration of the segment each PSD was estimated on; used as the weight
    # of that PSD in the averages below.
    # NOTE(review): assumes start_time/end_time hold one entry per PSD.
    seglen = numpy.subtract(end, start)

    # One configuration per requested curve.  Results are stored per
    # configuration so that different approximants (or overlapping mass
    # ranges) never get mixed into the same bin.
    configs = list(zip(args.min_mtot, args.max_mtot, args.d_mtot, args.approximant))
    mass_grids = [numpy.arange(mi, mf, dm) for mi, mf, dm, _ in configs]
    # ranges[c][m] collects one inspiral range per PSD for mass_grids[c][m].
    ranges = [[[] for _ in grid] for grid in mass_grids]

    for i in range(num_psds):
        name = ifo + '/psds/' + str(i)
        psd = pycbc.types.load_frequencyseries(psd_file, group=name)
        # Time step consistent with the PSD's frequency resolution and length.
        delta_t = 1.0 / ((len(psd) - 1) * 2 * psd.delta_f)
        # Output buffer handed to the waveform generator.
        out = pycbc.types.zeros(len(psd), dtype=numpy.complex64)

        for (_, _, _, apx), grid, cfg_ranges in zip(configs, mass_grids, ranges):
            for M, mass_ranges in zip(grid, cfg_ranges):
                # Equal-mass system with total mass M.
                # NOTE(review): the distance is presumably chosen so that
                # sigma comes out directly in Mpc -- confirm against
                # pycbc.DYN_RANGE_FAC conventions.
                htilde = pycbc.waveform.get_waveform_filter(out,
                                     mass1=M/2., mass2=M/2., approximant=apx,
                                     f_lower=flow, delta_f=psd.delta_f,
                                     delta_t=delta_t,
                                     distance=1.0/pycbc.DYN_RANGE_FAC)
                htilde = htilde.astype(numpy.complex64)
                sigma = pycbc.filter.sigma(htilde, psd=psd,
                                           low_frequency_cutoff=flow)
                horizon_distance = sigma / canonical_snr
                # 2.26 converts the horizon distance into the inspiral range
                # that is plotted (described as sky-averaged in the caption).
                mass_ranges.append(horizon_distance / 2.26)

    color = pycbc.results.ifo_color(ifo)
    for (_, _, _, apx), grid, cfg_ranges in zip(configs, mass_grids, ranges):
        avg_range, rangerr = [], []
        for mass_ranges in cfg_ranges:
            values = numpy.asarray(mass_ranges)
            # Weighted mean and weighted standard deviation over the PSDs.
            mean = numpy.average(values, weights=seglen)
            variance = numpy.average((values - mean) ** 2, weights=seglen)
            avg_range.append(mean)
            rangerr.append(math.sqrt(variance))

        # Each configuration gets its own curve, labelled with its own
        # approximant.
        label = '%s-%s' % (ifo, apx)
        plt.errorbar(grid, avg_range, yerr=rangerr, ecolor=color, label=label, fmt='none')
        plt.plot(grid, avg_range, color=color)

# Legend entries come from the labelled error-bar artists added above.
plt.legend(loc="upper left")

# Save the figure together with its title, caption and the exact command
# line that produced it.
caption_text = ("The canonical sky-averaged inspiral range for a single "
                "detector at SNR 8 vs total mass:equal mass binary")
command_line = ' '.join(sys.argv)
pycbc.results.save_fig_with_metadata(
    fig,
    args.output_file,
    title="Inspiral Range",
    caption=caption_text,
    cmd=command_line,
    fig_kwds={'dpi': 200},
)
