#!/usr/bin/env python

# Copyright (C) 2021 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.


import os
import argparse
import logging
import numpy as np

from matplotlib import use
use('Agg')
from matplotlib import rcParams
from matplotlib import pyplot as plt

from pycbc import init_logging
from pycbc.population import fgmc_functions as utils
from pycbc.population import fgmc_laguerre as fgmcl
from pycbc.population import fgmc_plots


# Same point size for every text element; text.usetex switches matplotlib
# to LaTeX rendering of text
plot_style = dict.fromkeys(
    ('axes.labelsize', 'font.size', 'legend.fontsize', 'xtick.labelsize',
     'ytick.labelsize'),
    12,
)
plot_style['text.usetex'] = True
rcParams.update(plot_style)


# Command-line interface.  The parsed namespace is later handed whole to the
# fgmc_functions / fgmc_plots helpers, so option names are part of their
# contract and must not be renamed.
parser = argparse.ArgumentParser(
    description='Compute the FGMC foreground count posterior and event '
                'p_astro values from search results')
parser.add_argument('-V', '--verbose', action='store_true', default=False)
parser.add_argument('--diagnostic-plots', action='store_true', default=False,
                    help='Make extra diagnostic scatter plots')

# Input files.  These three lists are used unconditionally further down the
# script, so they are marked required to get a clean usage error instead of
# a TypeError traceback when one is missing.
parser.add_argument('--full-data-files', nargs='+', required=True,
                    help='hdf5 file(s) (space-separated) containing zerolag events '
                         'and (decimated/undecimated) timeslide background')
parser.add_argument('--bank-files', nargs='+', required=True,
                    help='hdf5 file(s) (space-separated) containing bank parameters')
parser.add_argument('--injection-files', nargs='+', required=True,
                    help='hdf5 file(s) (space-separated) containing injection event '
                         'data in hdfinjfind format')
parser.add_argument('--rank-injections', nargs='*',
                    help='hdf5 file(s) (space-separated) containing injection events '
                         'in statmap format, to treat as extra zerolag events. '
                         'Optional')
parser.add_argument('--prior-event-info', help='.txt data file with gps, stat, ifar, '
                    'ln bayes factor for prior events. Only ln BF is used in FGMC '
                    'calculation')

# Event selection
parser.add_argument('--network', nargs='*', default=['H1', 'L1', 'V1'],
                    help='All ifos to consider in determining coinc times/types '
                         '(space-separated)')
parser.add_argument('--coinc-times', nargs='+',
                    help='Types of coincident analysis time to read in '
                         '(space-separated). Ex. H1L1 H1L1V1')
parser.add_argument('--stat-threshold', type=float,
                    help='Only analyze events and injections above this '
                         'value')

# Density estimation binning: linear and log widths are mutually exclusive
# (enforced after parsing).
# NOTE(review): the two linear widths have no type=, so they arrive as
# strings -- confirm the downstream binning code converts them.
parser.add_argument('--bg-bin-width', default=None,
                    help='Width of stat bins to estimate background density, eg 0.25')
parser.add_argument('--inj-bin-width', default=None,
                    help='Width of bins to estimate injection density, eg 0.4')
parser.add_argument('--bg-log-bin-width', type=float, default=None,
                    help='Width of log spaced bins to estimate background'
                         ' density, eg 0.02')
parser.add_argument('--inj-log-bin-width', type=float, default=None,
                    help='Width of log spaced bins to estimate injection'
                         ' density, eg 0.04')
parser.add_argument('--min-mchirp', type=float, required=True,
                    help='Minimum template chirp mass to accept')
parser.add_argument('--max-mchirp', type=float, required=True,
                    help='Maximum template chirp mass to accept')

# Count posterior settings
parser.add_argument('--power-law-prior', default=-0.5, type=float,
                    help='Power-law prior on foreground count parameter, '
                         'default -0.5 ("Jeffreys")')
parser.add_argument('--laguerre-degree', type=int, default=20,
                    help='Degree of polynomial used for generalized L-G quadrature')

# Text / json outputs
parser.add_argument('--ntop', type=int, default=15,
                    help='Number of loudest events listed, default 15')
parser.add_argument('--p-astro-txt', help='Save GPS and p_astro in txt file')
parser.add_argument('--p-astro-inj', help='Save injection GPS / p_astro in txt file.'
                    ' Optional')
parser.add_argument('--json-tag', help='If given, save event json files with tag '
                    'in name')
parser.add_argument('--json-min-ifar', type=float, help='IFAR threshold to save '
                    'json files')

# Optional conversion from counts to a merger rate
parser.add_argument('--vt-mean', type=float,
                    help='Sensitivity estimate conjugate to merger rate: if '
                         'provided, rate estimate will be output in the same units')
parser.add_argument('--vt-err', type=float,
                    help='Standard error on vt-mean value, assumed in same units')
parser.add_argument('--rate-units', default=r'Gpc^-3 yr^-1',
                    help=r'Sensitivity units for rate output, default 1/(Gpc^3 yr)')

# Plotting
parser.add_argument('--plot-max-stat', type=float, default=30.,
                    help='Maximum background stat value to plot')
parser.add_argument('--plot-bg-hist', action='store_true',
                    help='Make diagnostic background PDF plots')
parser.add_argument('--plot-inj-hist', action='store_true',
                    help='Make diagnostic injection PDF plots')
parser.add_argument('--plot-dir', required=True,
                    help='Destination directory for plots, must already exist. '
                         'A final "/" is added if missing')
parser.add_argument('--plot-tag', help='String for plot filenames')

args = parser.parse_args()

# Plot paths are built as args.plot_dir + filename, so the directory string
# has to end with a separator
if args.plot_dir[-1:] != '/':
    args.plot_dir = args.plot_dir + '/'
if not os.path.exists(args.plot_dir):
    raise RuntimeError("Output dir %s doesn't exist!" % args.plot_dir)

# For each density estimate, linear and log binning are mutually exclusive
bg_linear_given = args.bg_bin_width is not None
bg_log_given = args.bg_log_bin_width is not None
if bg_linear_given and bg_log_given:
    raise RuntimeError("Can't have both linear and log background bins!")

inj_linear_given = args.inj_bin_width is not None
inj_log_given = args.inj_log_bin_width is not None
if inj_linear_given and inj_log_given:
    raise RuntimeError("Can't have both linear and log injection bins!")

init_logging(args.verbose)

# prepare to read in zerolag and background data
nchunks = len(args.full_data_files)

# MANY_BANKS records whether one bank per data chunk was supplied; it is
# consulted again when pairing banks with the --rank-injections files
MANY_BANKS = len(args.bank_files) == nchunks
if not MANY_BANKS:
    if len(args.bank_files) != 1:
        raise RuntimeError('Either need the same number of banks as chunks, '
                           '%d, or exactly 1 bank!' % nchunks)
    # a single bank is shared by every chunk: repeat it to line up with the
    # data files
    args.bank_files = args.bank_files * nchunks

logging.info('Setting up zerolag')
fg = utils.ForegroundEvents(args, args.coinc_times,
                            bin_lo=args.min_mchirp, bin_hi=args.max_mchirp)

logging.info('Setting up background')
bg = utils.BackgroundEventRate(args, args.coinc_times,
                               bin_lo=args.min_mchirp, bin_hi=args.max_mchirp)

# Read zerolag/time slide data chunk by chunk; the bank for a chunk is
# loaded and filtered before that chunk's events are added
for data_file, bank_file in zip(args.full_data_files, args.bank_files):
    logging.info('Adding bank info from %s', bank_file)
    fg.add_bank(bank_file)
    fg.filter_templates()
    bg.add_bank(bank_file)
    bg.filter_templates()

    logging.info('Adding event info from %s', data_file)
    fg.add_zerolag(data_file)
    bg.add_background(data_file)

if args.diagnostic_plots:
    # Zerolag template parameter against rank statistic; the same figure is
    # saved twice, first with a linear and then with a log y axis
    mchirp_top = 1.1 * fg.masspars.max()
    mchirp_bottom = 0.9 * fg.masspars.min()

    plt.plot(fg.stat, fg.masspars, '+b', ms=3)
    plt.grid(True)
    plt.ylim(ymax=mchirp_top)
    plt.xlabel('Rank statistic')
    plt.ylabel('Template param (mchirp)')
    plt.savefig('%s%s' % (args.plot_dir, 'fg_stat_vs_mchirp.png'))

    plt.semilogy()
    plt.ylim(mchirp_bottom, mchirp_top)
    plt.savefig('%s%s' % (args.plot_dir, 'fg_stat_vs_logmchirp.png'))
    plt.close()

# optional diagnostic histograms are made before the normalization step
if args.plot_bg_hist:
    bg.plot_bg()

# normalize and count expected bg events
bg.get_norms()
logging.info('Total expected bg count %s', bg.norm)
logging.info('Actual zerolag event count %d', len(fg.stat))

# Read signal (injection) data
sg = utils.SignalEventRate(args, args.coinc_times,
                           bin_lo=args.min_mchirp, bin_hi=args.max_mchirp)

# Each injection file is paired with the bank and full-data file at the same
# position in their respective lists.
# NOTE(review): zip() stops at the shortest list, so a mismatch between the
# number of injection files and data chunks is silently truncated -- confirm
# whether that should be an error.
for inj_file, bank_file, data_file in zip(args.injection_files,
                                          args.bank_files,
                                          args.full_data_files):
    logging.info('Adding inj info from %s', inj_file)
    sg.add_bank(bank_file)
    sg.filter_templates()
    sg.add_injections(inj_file, data_file)
sg.make_all_bins()
if args.plot_inj_hist:
    sg.plot_inj()

sg.get_norms()
logging.info('Total number of inj used %s', sg.norm)

fg.get_bg_pdf(bg)
fg.get_sg_pdf(sg)
# Bayes factor for each zerolag event: difference of the signal and noise
# values, which are plotted below as ln PDFs
allfgbg = fg.sg_pdf - fg.bg_pdf


def _scatter_fg_by_type(draw, yvals, type_colours):
    """Scatter yvals against zerolag rank statistic, coloured by coinc type.

    draw is the pyplot plotting function to call (it fixes the axis
    scaling), yvals an array indexed like fg.stat and type_colours a dict
    mapping coinc type to a matplotlib colour letter.  Grid, legend and the
    x label are added; the caller sets the y label and saves the figure.
    """
    for coinc_type, colour in type_colours.items():
        selected = fg.ctype == coinc_type
        draw(fg.stat[selected], yvals[selected], colour + '+',
             ms=3, alpha=0.6, label=coinc_type)
    plt.grid(True)
    plt.legend()
    plt.xlabel('Rank statistic')


# make diagnostic plots
if args.diagnostic_plots:
    print('ln Bayes factors from', min(allfgbg), max(allfgbg))
    fg_colours = {'H1L1': 'r', 'H1L1V1': 'b', 'H1V1': 'm', 'L1V1': 'g'}

    _scatter_fg_by_type(plt.semilogx, fg.sg_pdf, fg_colours)
    plt.ylabel('Signal ln PDF')
    plt.savefig(args.plot_dir + 'fg_stat_vs_sgpdf.png')
    plt.close()

    _scatter_fg_by_type(plt.semilogx, fg.bg_pdf, fg_colours)
    plt.ylabel('Noise ln PDF')
    plt.savefig(args.plot_dir + 'fg_stat_vs_bgpdf.png')
    plt.close()

    # Bayes factor figure is saved twice: linear x axis zoomed in, then log
    # x axis spanning threshold to loudest event
    _scatter_fg_by_type(plt.plot, allfgbg, fg_colours)
    plt.ylabel('fg/bg Bayes factor')
    plt.xlim(xmax=args.plot_max_stat)  # zoom to see not-super-loud events
    plt.savefig(args.plot_dir + 'fg_stat_vs_fgbg_bf.png')
    plt.semilogx()
    plt.xlim(args.stat_threshold, 1.1 * fg.stat.max())
    plt.savefig(args.plot_dir + 'fg_logstat_vs_fgbg_bf.png')
    plt.close()

    _scatter_fg_by_type(plt.loglog, fg.ifar, fg_colours)
    plt.ylabel('IFAR (yr)')
    plt.savefig(args.plot_dir + 'fg_stat_vs_ifar.png')
    plt.close()

# add in prior event info
if args.prior_event_info:
    # the path is passed as a lazy %-style argument: without a placeholder
    # in the message the logging module reports a formatting error instead
    # of emitting the record
    logging.warning('Adding events from %s', args.prior_event_info)
    # names=True: columns are addressed by the names in the file header
    prior_info = np.genfromtxt(args.prior_event_info, names=True)
    # only the ln Bayes factors of prior events enter the count posterior
    allfgbg = np.append(allfgbg, prior_info['ln_bayes_factor'])
else:  # no prior events!
    # empty placeholders under the same field names, so the np.append calls
    # further down contribute no extra events
    prior_info = {'stat': np.array([]),
                  'ifar': np.array([]),
                  'gps': np.array([]),
                  'ln_bayes_factor': np.array([])}

# do the calculation: posterior on the foreground count given the ln Bayes
# factors of all events and the expected background count
foreground_count_dist = fgmcl.count_posterior(
    allfgbg,
    laguerre_n=args.laguerre_degree,
    Lambda0=bg.norm,
    name='foreground count posterior',
    prior=args.power_law_prior,
)

# Roulet et al. argue that information on rate scales with sum (p_astro^2)
# p_astro is taken as 1 - p_bg, as in the injection output further down
p_astro = 1 - foreground_count_dist.p_bg(allfgbg)
logging.info('sum of p_astro^2 %s', (p_astro ** 2.).sum())

# Event properties for the odds summary: this analysis' zerolag events
# followed by any prior events (same ordering as allfgbg)
all_stat = np.append(fg.stat, prior_info['stat'])
all_ifar = np.append(fg.ifar, prior_info['ifar'])
all_pbg = foreground_count_dist.p_bg(allfgbg)
all_gps = np.append(fg.gpst, prior_info['gps'])

fgmc_plots.odds_summary(args, all_stat, all_ifar, all_pbg, args.ntop,
                        times=all_gps, name='foreground events',
                        plot_extensions=['.png'])

# not a command-line option; set on the namespace before it is handed to
# dist_summary (presumably read there -- verify against fgmc_plots)
args.plot_limits = None
count_summary = fgmc_plots.dist_summary(args, foreground_count_dist,
                                        plot_extensions=['.png'],
                                        middle='median',
                                        credible_intervals=[0.90])
# median and the two interval offsets, reused for the rate printout
cmed, cminus, cplus = count_summary

# if VT and error are provided, convert signal PDF to rate PDF
if args.vt_mean:
    if args.vt_err is None:
        # logging.warn was a deprecated alias and no longer exists in
        # Python 3.13; logging.warning is the supported spelling
        logging.warning('NB : no VT error provided, will use 0!')
        args.vt_err = 0.
    logging.info('Rate in units of ' + args.rate_units + ' ...')
    # counts divided by VT give the rate; the interval offsets scale the
    # same way
    print('Rate without VT uncertainty :')
    print('%.1f_{-%.1f}^{+%.1f}' % (
        cmed / args.vt_mean, abs(cminus) / args.vt_mean, cplus / args.vt_mean))
    # imported here so the module is only needed when a rate is requested
    import rate_functions
    fracerr = args.vt_err / args.vt_mean
    # attr to arbitrarily rescale counts in rate calculation : here, don't rescale
    foreground_count_dist.scale = 1.
    # we want a Jeffreys type prior: this is already supplied by the default value
    # of the --power-law-prior option (see above) so leave a uniform prior on rate
    rate_distribution = \
        rate_functions.rate_posterior(foreground_count_dist, args.vt_mean,
                                      fracerr, unit=args.rate_units, texunit='')
    median, minus, plus = fgmc_plots.dist_summary(args, rate_distribution,
          plot_extensions=None, credible_intervals=[0.9])
    print('Rate with VT uncertainty via rate_posterior :')
    print('%.1f_{-%.1f}^{+%.1f}' % (median, abs(minus), plus))
elif args.vt_err:
    # an error without a mean cannot be turned into a fractional error
    raise RuntimeError("Can't supply VT error without a VT mean!")


def _scatter_inj_by_type(draw, events, yvals):
    """Scatter yvals against rank statistic, one colour per coinc type.

    draw is the pyplot plotting function to call (it fixes the axis
    scaling), events the object providing .stat and .ctype arrays and yvals
    an array indexed like events.stat.  Grid, legend and the x label are
    added; the caller sets the y label and saves the figure.
    """
    col = {'H1L1': 'r', 'H1L1V1': 'b', 'H1V1': 'm', 'L1V1': 'g'}
    for cty, colour in col.items():
        in_type = events.ctype == cty
        draw(events.stat[in_type], yvals[in_type], colour + '+',
             ms=3, alpha=0.6, label=cty)
    plt.grid(True)
    plt.legend()
    plt.xlabel('Rank statistic')


# if injections are to be evaluated 'like zerolag', read them in
if args.rank_injections:
    # explicit errors rather than asserts: asserts vanish under python -O,
    # which would defer the failure to the final savetxt call
    if args.p_astro_inj is None:
        raise RuntimeError('--rank-injections requires --p-astro-inj to '
                           'name the output file!')
    inj = utils.ForegroundEvents(args, args.coinc_times,
                                 bin_lo=args.min_mchirp, bin_hi=args.max_mchirp)

    if MANY_BANKS:
        # one bank per chunk was given, so banks must match up with injs
        if len(args.rank_injections) != len(args.bank_files):
            raise RuntimeError('Need the same number of rank-injections '
                               'files as banks, %d!' % len(args.bank_files))
        banks = args.bank_files
    else:
        # the same single bank is used for all results
        banks = [args.bank_files[0]] * len(args.rank_injections)

    for inj_file, bank_file in zip(args.rank_injections, banks):
        logging.info('Adding bank info from %s', bank_file)
        inj.add_bank(bank_file)
        inj.filter_templates()

        logging.info('Adding event info from %s', inj_file)
        inj.add_zerolag(inj_file)

    # get densities for injections as if zerolag
    inj.get_bg_pdf(bg)
    inj.get_sg_pdf(sg)
    # Bayes factor for each injection event
    injfgbg = inj.sg_pdf - inj.bg_pdf

    if args.diagnostic_plots:
        _scatter_inj_by_type(plt.semilogx, inj, inj.sg_pdf)
        plt.ylabel('Signal ln PDF')
        plt.savefig(args.plot_dir + 'inj_stat_vs_sgpdf.png')
        plt.close()

        _scatter_inj_by_type(plt.semilogx, inj, inj.bg_pdf)
        plt.ylabel('Noise ln PDF')
        plt.savefig(args.plot_dir + 'inj_stat_vs_bgpdf.png')
        plt.close()

        # saved twice: linear x axis zoomed in, then log x axis out to the
        # loudest injection
        _scatter_inj_by_type(plt.plot, inj, injfgbg)
        plt.ylabel('fg/bg Bayes factor')
        # zoom to see not-super-loud events
        plt.xlim(args.stat_threshold, args.plot_max_stat)
        plt.savefig(args.plot_dir + 'inj_stat_vs_fgbg_bf.png')
        plt.semilogx()
        plt.xlim(args.stat_threshold, 1.1 * inj.stat.max())
        plt.savefig(args.plot_dir + 'inj_logstat_vs_fgbg_bf.png')
        plt.close()

    # evaluate p_astro using zerolag signal rate
    inj_pbg = foreground_count_dist.p_bg(injfgbg)

    # make a special txt format output file: one comma-separated row per
    # injection event, columns as named in the header
    np.savetxt(args.p_astro_inj,
               np.column_stack((inj.gpst,
                                1. / inj.ifar,
                                inj.ifar,
                                inj.stat,
                                1. - inj_pbg)),
               fmt=['%.3f', '%e', '%e', '%6f', '%.6f'],
               delimiter=',',
               header='geocent_end_time,FAR,IFAR,detection_statistic,p_astro')

    if args.diagnostic_plots:
        # odds of signal vs noise origin, (1 - p_bg) / p_bg, per injection
        plt.loglog(inj.stat, (1. - inj_pbg) / inj_pbg, 'k.', ms=3)
        plt.grid(True)
        plt.xlabel('Ranking statistic')
        plt.ylabel(r'$P_1/P_0$')
        plt.xlim(0.99 * inj.stat.min(), 2. * args.plot_max_stat)
        plt.savefig(args.plot_dir + 'injection_events_odds.png')
        plt.close()

logging.info('Done!')
