#!/usr/bin/env python

# Copyright (C) 2015 Miriam Cabero Mueller
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot the variation of the amplitude spectral density (ASD) over time
using the PSD hdf file generated by the pipeline
"""

import logging
import argparse
import numpy
import h5py
import sys
from six.moves import range
import matplotlib
matplotlib.use('agg')
import pylab
import pycbc.results
from matplotlib.colors import LogNorm

def _bin_count(value):
    """Parse the ``--bins`` option as a positive integer.

    Non-positive counts are rejected with a usage error here instead of
    failing later inside ``numpy.logspace`` or the plotting code.
    """
    count = int(value)
    if count < 1:
        raise argparse.ArgumentTypeError(
            'must be a positive integer, got %d' % count)
    return count


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--psd-file', required=True,
                    help='HDF file containing the PSDs.')
parser.add_argument('--output-file', required=True,
                    help='Name for the output plot.')
parser.add_argument('--normalize', action='store_true',
                    help='Select this option if you want to plot the amplitude '
                         'relative to median.')
parser.add_argument('--bins', type=_bin_count, default=1000,
                    help='Number of (logarithmic-spaced) frequency bins to plot'
                         ' (default %(default)d).')
parser.add_argument('--reduction-method', choices=['min','max','mean','median'],
                    default='mean',
                    help='Method of reducing the ASD data into fewer'
                         ' frequency bins to plot (default %(default)s).')
parser.add_argument('--verbose', action='store_true',
                    help='Enable informational log messages.')

opts = parser.parse_args()
pycbc.init_logging(opts.verbose)

fig = pylab.figure()
ax = fig.gca()

logging.info('Reading %s', opts.psd_file)

# Pull everything needed out of the PSD file inside a context manager so the
# HDF handle is released as soon as the data is in memory.
with h5py.File(opts.psd_file, 'r') as f:
    # Only the first top-level group (the detector name) is used; the file is
    # presumably expected to hold a single detector.
    ifo = tuple(f.keys())[0]
    psd_group = f[ifo + '/psds']
    df = psd_group['0'].attrs['delta_f']
    # PSD datasets are named '0', '1', ...; walk them by integer index so they
    # come out in numeric order (string sorting would put '10' before '2').
    psds = [psd_group[str(key)][:] for key in range(len(psd_group))]
    flow = f.attrs['low_frequency_cutoff']
    start, end = f[ifo + '/start_time'][:], f[ifo + '/end_time'][:]

# First frequency sample at or above the low-frequency cutoff.
kmin = int(flow / df)

fac = 1.0 / pycbc.DYN_RANGE_FAC
freqs = numpy.arange(0, len(psds[0]))[kmin:] * df

logging.info('Starting the plot PSD over time')

reduction_function = {'min': numpy.min, 'max': numpy.max,
                      'mean': numpy.mean, 'median': numpy.median}
reduce_bin = reduction_function[opts.reduction_method]

asd_median = numpy.median(psds, axis=0)[kmin:] ** 0.5 * fac
# Select limits for the colorbar
if opts.normalize:
    asdmin = 0.5
    asdmax = 2
else:
    asdmin = numpy.min(psds[0][kmin:]) ** 0.5 * fac
    # Fixed upper ASD limit (presumably chosen to cover typical detector
    # noise -- TODO confirm).
    asdmax = 1e-20

# The logarithmic frequency binning is identical for every segment, so it is
# computed once. opts.bins + 1 edges give exactly opts.bins bins, which is
# what the --bins help text and the figure caption promise.
freqsegs = numpy.logspace(numpy.log10(flow), numpy.log10(freqs.max()),
                          num=opts.bins + 1)
indices = numpy.searchsorted(freqs, freqsegs)

for psd_segment in range(len(psds)):
    logging.info('Plotting PSD segment number %d', psd_segment)
    time = numpy.array([start[psd_segment], end[psd_segment]])
    asd = psds[psd_segment][kmin:] ** 0.5 * fac
    if opts.normalize:
        asd = asd / asd_median

    # Reduce the ASD samples falling in each log bin to one value. Bins
    # narrower than delta_f can contain no samples; those become NaN (left
    # blank by the colour map) rather than making numpy.min/max raise on an
    # empty array.
    asdless = [reduce_bin(asd[lo:hi]) if hi > lo else numpy.nan
               for lo, hi in zip(indices[:-1], indices[1:])]

    # Flat shading needs cell *edges*: len(freqsegs) frequency edges and two
    # time edges enclose the 1 x len(asdless) grid of values.
    Y, X = numpy.meshgrid(freqsegs, time)
    im = ax.pcolormesh(X, Y, numpy.array([asdless]), shading='flat',
                       norm=LogNorm(vmin=asdmin, vmax=asdmax))

logging.info('Saving results')

# Colour bar taken from the last mesh drawn; every segment is drawn with the
# same LogNorm limits, so it is representative of the whole plot.
cb = fig.colorbar(im)
if not opts.normalize:
    cbLabel = 'Amplitude Spectral Density (Strain / $\\sqrt{\\rm Hz}$)'
else:
    cbLabel = 'Amplitude relative to median'
    cb.set_ticks([0.5, 1, 2])
    cb.ax.set_yticklabels(['0.5', '1', '2'])
cb.set_label(cbLabel, fontsize=20)

# Log-scaled frequency axis over the plotted band, with a 5000 s margin on
# either side of the time span covered by the segments.
ax.set_yscale('log')
ax.set_xlim(start.min() - 5000, end.max() + 5000)
ax.set_ylim(numpy.min(freqs), numpy.max(freqs))

# One x-axis tick per day since the earliest segment start, labelled with the
# day number. Floor division keeps the count integral: with true division the
# result is a float and range() raises TypeError under Python 3.
n_days = int((end.max() - start.min()) // 86400) + 1
ticks = [start.min() + 86400 * days for days in range(n_days)]
ticklabels = numpy.arange(n_days)
ax.set_xticks(ticks)
ax.set_xticklabels(ticklabels)

ax.set_xlabel('Days since GPS %d' % start.min(), fontsize=18)
ax.set_ylabel('Frequency (Hz)', fontsize=18)
fig.set_size_inches(18.5, 10.5)
pylab.tight_layout()

title = 'Evolution of the noise spectral density over time in %s' % ifo
# Adjacent string literals are joined verbatim, so each piece carries its own
# separating space (previously the words ran together, e.g. "originalfrequency").
caption = ('Variation of the amplitude spectral density over time. The '
           'original frequency bins are reduced to %d logarithmic frequency '
           'bins by taking the %s. Each time segment contains %d seconds.') % \
          (opts.bins, opts.reduction_method, end[0] - start[0])
pycbc.results.save_fig_with_metadata(fig, opts.output_file,
                                     title=title,
                                     caption=caption,
                                     cmd=' '.join(sys.argv),
                                     fig_kwds={'dpi': 200})

