#! /usr/bin/env python

# Copyright (C) 2019 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Wrapper around ligo-skymap-from-samples that creates fits files from an
inference or posterior hdf file.

Requires ligo.skymap to be installed, which requires python 3.

To create a fits file, a right ascension (ra), declination (dec), luminosity
distance (distance), and coalescence time (tc) must be provided. By default,
this will try to load those parameters from the file. However, math operations
on multiple parameters in the file, or constants may be provided instead using
the corresponding option (see below). For math operations, the syntax is the
same as what is used for the --parameters option in
pycbc_inference_extract_samples (see that program for details).
"""

import os
import numpy
import argparse
import subprocess

from pycbc.inference import io

# Command-line interface. The module docstring doubles as the program
# description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--input-file', required=True,
    help='The inference or posterior hdf file to load.')
parser.add_argument(
    '--output-file', required=True,
    help='The output file name; it must end in ".fits".')
parser.add_argument(
    '--maxpts', type=int,
    help=('The number of posterior samples to extract for making the fits '
          'file. Default is to load all. Using fewer points will result in '
          'faster runtime, but more error in the result sky map.'))

# The four sky-location options share the same help wording and each one
# defaults to a parameter of the same name, so they are generated from a
# small table of (option name, human-readable description) pairs.
_sky_param_help = ('The name of the {description} parameter in the input '
                   'file. May also provide a constant, or math operations '
                   'on multiple parameters. Default is "{default}".')
_sky_params = (
    ('ra', 'right ascension'),
    ('dec', 'declination'),
    ('distance', 'distance'),
    ('tc', 'coalesence time'),
)
for _name, _description in _sky_params:
    parser.add_argument(
        '--{}'.format(_name), default=_name,
        help=_sky_param_help.format(description=_description, default=_name))

opts = parser.parse_args()

def strip_fits_suffix(filename):
    """Return ``filename`` with a trailing ".fits" extension removed.

    Only the end of the name is examined, so a ".fits" appearing in a
    directory component or in the middle of the file name is left alone.
    A name that does not end in ".fits" is returned unchanged.

    Parameters
    ----------
    filename : str
        The output file name given on the command line.

    Returns
    -------
    str
        The file name without its trailing ".fits".
    """
    suffix = '.fits'
    if filename.endswith(suffix):
        return filename[:-len(suffix)]
    return filename


def skymap_command(samples_file, fits_file, maxpts=None):
    """Build the argument list used to call ligo-skymap-from-samples.

    The command is returned as a list, with each path as its own element, so
    that file names containing whitespace are passed through intact.

    Parameters
    ----------
    samples_file : str
        Path of the text file containing the samples table.
    fits_file : str
        Name to give to ``--fitsoutname``.
    maxpts : int, optional
        If not None, passed on as ``--maxpts``.

    Returns
    -------
    list of str
        The command, suitable for ``subprocess.run``.
    """
    cmd = ['ligo-skymap-from-samples']
    if maxpts is not None:
        cmd += ['--maxpts', str(maxpts)]
    cmd += [samples_file, '--fitsoutname', fits_file]
    return cmd


def main(opts):
    """Create the fits file described by the parsed command-line options.

    The requested parameters are read from ``opts.input_file``, written to a
    temporary four-column text file next to the output file, and handed to
    ligo-skymap-from-samples. The text file is removed afterwards whether or
    not the external program succeeded.

    Parameters
    ----------
    opts : argparse.Namespace
        Must provide ``input_file``, ``output_file``, ``maxpts``, ``ra``,
        ``dec``, ``distance`` and ``tc``.

    Raises
    ------
    subprocess.CalledProcessError
        If ligo-skymap-from-samples exits with a non-zero status.
    """
    params = [opts.ra, opts.dec, opts.distance, opts.tc]
    fp = io.loadfile(opts.input_file, 'r')
    try:
        samples = fp.read_samples(params)
    finally:
        # release the input file even if reading the samples fails
        fp.close()

    # one row per sample; the column order matches the header written below
    out = numpy.zeros((samples.size, len(params)))
    for col, param in enumerate(params):
        out[:, col] = samples[param]
    txtfile = '{}.dat'.format(strip_fits_suffix(opts.output_file))
    numpy.savetxt(txtfile, out, header='ra dec distance time')

    cmd = skymap_command(txtfile, opts.output_file, opts.maxpts)
    try:
        subprocess.run(cmd, check=True)
    finally:
        # never leave the intermediate table behind, including when the
        # executable cannot be started or returns an error
        os.remove(txtfile)
    # Best-effort removal of a skypost.obj left in the working directory;
    # it is not an error if no such file exists.
    try:
        os.remove('skypost.obj')
    except OSError:
        pass


if __name__ == '__main__':
    main(opts)
