#!/usr/bin/env python
""" Plot variation in PSD
"""
import matplotlib; matplotlib.use('Agg');
import h5py, numpy, argparse, pylab, sys
import pycbc.results, pycbc.types, pycbc.version, pycbc.waveform, pycbc.filter

from pycbc.fft.fftw import set_measure_level
set_measure_level(0)

# Command-line interface; the module docstring doubles as the help description.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("--psd-files", nargs='+', help='HDF file of psds')
parser.add_argument("--output-file", help='output file name')

# The waveform options each accept one or more values. The script later zips
# the three lists together, so the n-th entry of each describes one waveform.
# A value type of None leaves the argument as the raw string (argparse default).
for option, value_type, description in (
        ("--mass1", float, "Mass of first component in solar masses"),
        ("--mass2", float, "Mass of second component in solar masses"),
        ("--approximant", None, "approximant to use for range")):
    parser.add_argument(option, nargs="+", help=description, type=value_type)

args = parser.parse_args()

# SNR at which the horizon distance is evaluated when computing ranges below.
canonical_snr = 8.0

# Create figure 0 and decorate its (current) axes: range versus time, with a
# background grid. Later pylab plotting calls draw onto these same axes.
fig = pylab.figure(0)
axes = fig.gca()
axes.set_xlabel('Time (s)')
axes.set_ylabel('Inspiral Range (Mpc)')
axes.grid()

# Factor by which the sky- and orientation-averaged inspiral range is smaller
# than the horizon distance.
horizon_to_range = 2.26

# One (mass1, mass2, approximant) triple per requested waveform. zip() stops at
# the shortest of the three option lists.
waveforms = list(zip(args.mass1, args.mass2, args.approximant))

for psd_file in args.psd_files:
    # Read the per-file metadata. The context manager guarantees the handle is
    # closed even if one of the lookups below raises.
    with h5py.File(psd_file, 'r') as f:
        # Only the first top-level group is used; it names the detector.
        ifo = tuple(f.keys())[0]
        flow = f.attrs['low_frequency_cutoff']
        num_psds = len(f[ifo + '/psds'])
        # [:] copies the datasets into memory so they outlive the file handle.
        start, end = f[ifo + '/start_time'][:], f[ifo + '/end_time'][:]

    # Inspiral range per waveform key, one entry per PSD in index order.
    ranges = {}
    for i in range(num_psds):
        # PSDs are stored under consecutive integer names: <ifo>/psds/<i>.
        name = ifo + '/psds/' + str(i)
        psd = pycbc.types.load_frequencyseries(psd_file, group=name)
        # Time step matching a frequency series of this length and resolution.
        delta_t = 1.0 / ((len(psd) - 1) * 2 * psd.delta_f)
        # Scratch buffer handed to the waveform generator for its output.
        out = pycbc.types.zeros(len(psd), dtype=numpy.complex64)

        for m1, m2, apx in waveforms:
            # NOTE(review): distance = 1/DYN_RANGE_FAC presumably rescales the
            # template so that sigma comes out directly in Mpc -- confirm.
            htilde = pycbc.waveform.get_waveform_filter(
                out,
                mass1=m1, mass2=m2, approximant=apx,
                f_lower=flow, delta_f=psd.delta_f,
                delta_t=delta_t,
                distance=1.0 / pycbc.DYN_RANGE_FAC)
            htilde = htilde.astype(numpy.complex64)
            sigma = pycbc.filter.sigma(htilde, psd=psd,
                                       low_frequency_cutoff=flow)
            horizon_distance = sigma / canonical_snr
            inspiral_range = horizon_distance / horizon_to_range

            ranges.setdefault((m1, m2, apx), []).append(inspiral_range)

    for m1, m2, apx in waveforms:
        # With several waveforms each curve needs its parameters in the label;
        # with a single one the detector name is enough.
        if len(args.approximant) > 1:
            # Raw string: "\o" in "\odot" is not a valid Python escape.
            label = r'%s: $%sM_{\odot}-%sM_{\odot}$ (%s)' % (ifo, m1, m2, apx)
        else:
            label = str(ifo)
        # Horizontal bars span each PSD's [start, end] interval; no markers.
        pylab.errorbar((start + end) / 2, ranges[(m1, m2, apx)],
                       xerr=(end - start) / 2,
                       ecolor=pycbc.results.ifo_color(ifo), label=label,
                       fmt='none')

pylab.legend(loc="best", fontsize='small')

# With a single waveform the legend entries only name the detector, so the
# masses and approximant go in the figure title instead. The values are read
# straight from the parsed arguments rather than from variables left over by
# an earlier loop. Raw string: "\o" in "\odot" is not a valid Python escape.
if len(args.approximant) == 1:
    fig.suptitle(r'$%sM_{\odot}-%sM_{\odot}$ %s'
                 % (args.mass1[0], args.mass2[0], args.approximant[0]))

# Write the figure together with its title, caption and the exact command line
# that produced it.
pycbc.results.save_fig_with_metadata(fig, args.output_file,
    title = "Inspiral Range",
    caption = "The canonical sky- and orientation-averaged inspiral range for a single "
              "detector at SNR 8. This range is comparable to the SenseMon range and a factor of 2.26 smaller than the horizon distance.",
    cmd = ' '.join(sys.argv),
    fig_kwds={'dpi':200}
    )
