#!/usr/bin/env python

# Copyright (C) 2015 Tito Dal Canton
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""This program reads the noise PSDs estimated by pycbc_calculate_psd and
calculates the average PSD over time for each detector, as well as the average
PSD across time and detectors. The currently implemented averaging method is
the harmonic mean."""

import logging
import argparse
import numpy as np
import h5py
import pycbc
from pycbc.version import git_verbose_msg as version
from pycbc.types import MultiDetOptionAction, FrequencySeries


# Command-line interface. The module docstring doubles as the program
# description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--version', action='version', version=version)
parser.add_argument('--verbose', action='store_true')
# One or more input files; the help text asks for one file per detector.
parser.add_argument(
    '--input-files',
    nargs='+',
    required=True,
    metavar='PATH',
    help='HDF5 files from pycbc_calculate_psd (one per detector) '
         'containing the input PSDs to average.',
)
# Optional per-detector outputs, given as DETECTOR:PATH pairs and collected
# by MultiDetOptionAction.
parser.add_argument(
    '--time-avg-file',
    nargs='+',
    action=MultiDetOptionAction,
    metavar='DETECTOR:PATH',
    help='Output file names for single-detector PSDs '
         'averaged over time.',
)
# Optional single output holding the average over time and detectors.
parser.add_argument(
    '--detector-avg-file',
    metavar='PATH',
    help='Output file name for the average PSD '
         'over time and detectors.',
)
args = parser.parse_args()

# Configure logging according to the --verbose flag.
pycbc.init_logging(args.verbose)

# Factor applied to every averaged PSD right before it is written to disk.
# NOTE(review): presumably this undoes a DYN_RANGE_FAC scaling of the PSDs
# stored by pycbc_calculate_psd -- confirm against that program.
dynamic_range_factor = pycbc.DYN_RANGE_FAC ** (-2.)

# Time-averaged PSD of each detector, keyed by detector name.
time_avg_psds = {}

# Frequency resolution shared by all inputs; taken from the first file read
# and checked against every following file.
delta_f = None

for input_file in args.input_files:
    logging.info('Reading %s', input_file)
    # Use a context manager so the HDF5 handle is released as soon as the
    # PSDs have been read, including when an error is raised.
    with h5py.File(input_file, 'r') as f:
        # The detector name is the first top-level group of the file.
        ifo = tuple(f.keys())[0]
        psd_group = f[ifo + '/psds']
        df = psd_group['0'].attrs['delta_f']
        if delta_f is None:
            delta_f = df
        elif delta_f != df:
            raise ValueError('Inconsistent frequency resolution in input PSDs '
                             '(%f vs %f)' % (df, delta_f))

        logging.info('Averaging %s over time', ifo)
        # Harmonic mean over time: accumulate the inverse PSDs and count
        # them, then divide the count by the accumulated sum.
        sum_inv_psds = None
        count = 0
        for key in psd_group:
            # [:] reads the whole dataset into memory as an array.
            psd = psd_group[key][:]
            if sum_inv_psds is None:
                sum_inv_psds = 1. / psd
            else:
                sum_inv_psds += 1. / psd
            count += 1

    avg_psd = count / sum_inv_psds
    time_avg_psds[ifo] = avg_psd

    # Write the single-detector average only if a path was given for this
    # detector on the command line.
    if ifo in args.time_avg_file and args.time_avg_file[ifo]:
        logging.info('Writing %s average over time', ifo)
        fs = FrequencySeries(
                avg_psd.astype(np.float64) * dynamic_range_factor,
                delta_f=delta_f)
        fs.save(args.time_avg_file[ifo], ifo=ifo)

# Optionally combine the single-detector averages into one network PSD.
if args.detector_avg_file:
    logging.info('Averaging over detectors')
    # Harmonic mean across detectors: the first inverse PSD seeds the
    # accumulator and the remaining ones are added to it in place.
    inverse_psds = (1. / single_psd for single_psd in time_avg_psds.values())
    inverse_total = next(inverse_psds, None)
    for inverse_psd in inverse_psds:
        inverse_total += inverse_psd
    n_detectors = len(time_avg_psds)
    network_psd = n_detectors / inverse_total

    logging.info('Writing average over detectors')
    scaled_psd = network_psd.astype(np.float64) * dynamic_range_factor
    fs = FrequencySeries(scaled_psd, delta_f=delta_f)
    # The output is labelled with the sorted detector names joined together.
    ifo_str = ''.join(sorted(time_avg_psds))
    fs.save(args.detector_avg_file, ifo=ifo_str)

logging.info('Done')

