#!/usr/bin/env python

"""
Program for plotting the results of the pycbc_faithsim and pycbc_collect_result script
that compare two approximants and compute the match between them.
"""

import argparse
import matplotlib

matplotlib.use("Agg")
import matplotlib.cm
import pylab
import numpy as np
from ligo.lw import utils, table, lsctables
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
from pycbc import pnutils
from pycbc.conversions import (
    mtotal_from_mass1_mass2,
    q_from_mass1_mass2,
    eta_from_mass1_mass2,
    mchirp_from_mass1_mass2,
    chi_eff,
)


def basic_scatter(
    out_name,
    neg_idx,
    xname,
    yname,
    title,
    xval,
    yval,
    cval,
    cname,
    vmin,
    vmax,
    xmin,
    ymin,
    majorL,
    minorL,
    colormap,
):
    """Draw a colour-coded scatter plot and save it to ``out_name``.

    Parameters
    ----------
    out_name :
        Destination handed to ``pylab.savefig``.
    neg_idx :
        Indices into ``xval``/``yval`` of points that are additionally
        over-plotted with red ``x`` markers and labelled
        "Unable to generate waveforms". Nothing extra is drawn when empty.
    xname, yname :
        Axis labels.
    title :
        Plot title.
    xval, yval :
        Point coordinates; must support array indexing with ``neg_idx``.
    cval :
        Per-point colour values. A colour bar is drawn unless this is None.
    cname :
        Label of the colour bar.
    vmin, vmax :
        Limits of the colour scale; None or anything ``float()`` accepts.
    xmin, ymin :
        Lower axis limits; None means "use the smallest data value".
        The upper limits are always the largest data values.
    majorL, minorL :
        Spacing of the major / minor ticks, applied to both axes.
        Falsy values leave the default locators in place.
    colormap :
        Colormap passed to ``scatter``. The value ``"jet"`` selects the jet
        colormap with values below ``vmin`` drawn in gray.
    """
    # Function-scope import: ``copy`` is only needed here.
    import copy

    if colormap != "jet":
        cmap = colormap
    else:
        # Customise a private copy so the shared module-level colormap
        # object is left untouched for any other user of it.
        cmap = copy.copy(matplotlib.cm.jet)
        cmap.set_under(color="gray")

    if vmin is not None:
        vmin = float(vmin)

    if vmax is not None:
        vmax = float(vmax)

    fig = pylab.figure(num=None)
    pylab.scatter(
        xval,
        yval,
        c=cval,
        linewidths=0,
        s=3,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        alpha=0.7,
    )

    if len(neg_idx) > 0:
        pylab.scatter(
            xval[neg_idx],
            yval[neg_idx],
            c="red",
            marker="x",
            label="Unable to generate waveforms",
        )
        # Use the pylab interface like every other call in this function.
        pylab.legend()

    if cval is not None:
        bar = pylab.colorbar()
        bar.set_label(cname)

    pylab.xlabel(xname)
    pylab.ylabel(yname)

    # NaN-aware reductions: a NaN coordinate must not decide the axis range.
    xmin = np.nanmin(xval) if xmin is None else float(xmin)
    ymin = np.nanmin(yval) if ymin is None else float(ymin)

    pylab.xlim(xmin, np.nanmax(xval))
    pylab.ylim(ymin, np.nanmax(yval))

    ax = fig.gca()
    if majorL:
        majorL = float(majorL)
        ax.xaxis.set_major_locator(MultipleLocator(majorL))
        ax.yaxis.set_major_locator(MultipleLocator(majorL))
    if minorL:
        minorL = float(minorL)
        ax.xaxis.set_minor_locator(MultipleLocator(minorL))
        ax.yaxis.set_minor_locator(MultipleLocator(minorL))

    pylab.grid()
    pylab.title(title)

    pylab.savefig(out_name, dpi=500)
    # Release the figure so repeated calls do not accumulate open figures.
    pylab.close(fig)


# ---------------------------------------------------------------------------
# Command line
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--input-file",
    required=True,
    help="resulting .dat file from pycbc_faithsim_collect_full_results",
)
# Each --parameter-* is either a column name of the input file or one of the
# keys of ``derived_func_map`` below.
parser.add_argument("--parameter-x", required=True)
parser.add_argument("--parameter-y", required=True)
parser.add_argument("--parameter-z", required=True)
parser.add_argument("--colormap", default="jet")
parser.add_argument("--vmin", default=None, help="min value of the colorbar")
parser.add_argument("--vmax", default=None, help="max value of the colorbar")
parser.add_argument("--xmin", default=None)
parser.add_argument("--ymin", default=None)
parser.add_argument("--majorL", default=None)
parser.add_argument("--minorL", default=None)
parser.add_argument("--output-plot", required=True, help="name of the output plot")

args = parser.parse_args()

# Quantities that are not stored as columns but can be computed from them.
# Every entry maps a parameter name to a function of the loaded data table.
derived_func_map = {
    "total_mass": lambda d: mtotal_from_mass1_mass2(d["mass1"], d["mass2"]),
    "mass_ratio": lambda d: q_from_mass1_mass2(d["mass1"], d["mass2"]),
    "mchirp": lambda d: mchirp_from_mass1_mass2(d["mass1"], d["mass2"]),
    "spin1_magnitude": lambda d: (
        d["spin1x"] ** 2 + d["spin1y"] ** 2 + d["spin1z"] ** 2
    )
    ** 0.5,
    "spin2_magnitude": lambda d: (
        d["spin2x"] ** 2 + d["spin2y"] ** 2 + d["spin2z"] ** 2
    )
    ** 0.5,
    "eta": lambda d: eta_from_mass1_mass2(d["mass1"], d["mass2"]),
    "chi_eff": lambda d: chi_eff(d["mass1"], d["mass2"], d["spin1z"], d["spin2z"]),
    "horizon_distance_1": lambda d: d["sigma1"] / 8,
    "horizon_distance_2": lambda d: d["sigma2"] / 8,
}

# ``genfromtxt`` collapses a file with a single data row to a 0-d array;
# ``atleast_1d`` keeps every column indexable and iterable in that case too.
data = np.atleast_1d(np.genfromtxt(args.input_file, names=True))

title = f"{args.parameter_x} VS {args.parameter_y} colorbar {args.parameter_z}"

# Rows with a negative match are highlighted separately by basic_scatter.
neg_idx = np.flatnonzero(data["match"] < 0)


def _get_parameter(name):
    """Return the values of ``name``: a stored column, else a derived quantity.

    Raises NotImplementedError when ``name`` is neither a column of ``data``
    nor a key of ``derived_func_map``.
    """
    if name in data.dtype.names:
        return data[name]
    if name in derived_func_map:
        return derived_func_map[name](data)
    raise NotImplementedError(
        "Quantity "
        + name
        + " not calculated in the plotting script, we should add the calculation"
    )


v1d = _get_parameter(args.parameter_x)
v2d = _get_parameter(args.parameter_y)
v3d = _get_parameter(args.parameter_z)

basic_scatter(
    args.output_plot,
    neg_idx,
    args.parameter_x,
    args.parameter_y,
    title,
    v1d,
    v2d,
    v3d,
    args.parameter_z,
    vmin=args.vmin,
    vmax=args.vmax,
    xmin=args.xmin,
    ymin=args.ymin,
    majorL=args.majorL,
    minorL=args.minorL,
    colormap=args.colormap,
)
