#!/usr/bin/env python
""" Make interactive visualization of segments
"""
import argparse, pycbc.version
from itertools import cycle
import matplotlib
matplotlib.use('Agg')
import numpy, pylab, pycbc.events, mpld3, mpld3.plugins
from matplotlib.patches import Rectangle
from pycbc.results.mpld3_utils import MPLSlide, Tooltip


# Command-line interface.  The segment files and the output path are both
# needed by the code below (they are passed to sorted() and open()), so they
# are declared as required: a missing option is then reported by argparse
# as a usage error instead of surfacing later as a TypeError traceback.
parser = argparse.ArgumentParser()
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--segment-files', nargs='+', required=True,
                    help="List of segment files to plot")
parser.add_argument('--output-file', required=True,
                    help="output html file")
args = parser.parse_args()


def timestr(s):
    """Format a duration given in seconds as a short human-readable string.

    The value is truncated to a whole number of seconds and broken down
    into days, hours, minutes and seconds.  Day, hour and minute fields are
    only included when they are non-zero; the seconds field is always
    present.  Every field is followed by a single space, e.g.
    ``"1d 1h 1m 1s "`` or ``"0s "``.

    Parameters
    ----------
    s : number
        Duration in seconds; anything accepted by ``int()``.

    Returns
    -------
    str
        The formatted duration.
    """
    # divmod performs floor division, so every field stays an integer on
    # both Python 2 and Python 3 (plain "/" yields floats on Python 3).
    days, remainder = divmod(int(s), 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)

    text = ""
    for value, unit in ((days, "d"), (hours, "h"), (minutes, "m")):
        if value:
            text += "%s%s " % (value, unit)
    return text + "%ss " % seconds

def get_name(segment_file):
    """Return an ``ifos:name:version`` label for a segment file.

    The file is loaded as a LIGO_LW document and the label is built from
    the ``ifos``, ``name`` and ``version`` attributes of the first row of
    its ``segment_definer`` table.
    """
    from ligo.lw import table, utils as ligolw_utils
    from pycbc.io.ligolw import LIGOLWContentHandler

    xmldoc = ligolw_utils.load_filename(
        segment_file, False, contenthandler=LIGOLWContentHandler)
    definer_rows = table.Table.get_table(xmldoc, 'segment_definer')
    # Only the first definer row is used for the label.
    first_row = definer_rows[0]
    fields = (first_row.ifos, first_row.name, first_row.version)
    return "%s:%s:%s" % fields

def plot_segs(start, end, color=None, y=0, h=1):
    """Draw one rectangle per segment on the current pylab axes.

    Parameters
    ----------
    start, end : iterables of numbers
        Segment start and end values, paired element by element.
    color : optional
        Face colour for every rectangle drawn by this call.  When None,
        the next colour is taken from a cycle stored on the function
        itself, so successive calls step through the palette.
    y : number
        Lower edge of the rectangles.
    h : number
        Height of the rectangles.

    Returns
    -------
    list
        The Rectangle patches that were added, in input order.
    """
    # The palette lives on the function object and is created on first use.
    try:
        palette = plot_segs.colors
    except AttributeError:
        palette = cycle(['red', 'blue', 'green', 'yellow', 'cyan', 'violet'])
        plot_segs.colors = palette

    face = next(palette) if color is None else color

    drawn = []
    for seg_start, seg_end in zip(start, end):
        axes = pylab.gca()
        rect = Rectangle((seg_start, y), seg_end - seg_start, h,
                         facecolor=face)
        axes.add_patch(rect)
        drawn.append(rect)
    return drawn

# Stylesheet for the HTML tooltips.  Each tooltip label built further down is
# an HTML <table>, and this string is handed to the tooltip plugins through
# their ``css`` argument: collapsed borders, grey header cells, white data
# cells, right-aligned sans-serif text with a thin black border.
css = """
    table
    {
      border-collapse: collapse;
    }
    th
    {
      background-color: #cccccc;
    }
    td
    {
      background-color: #ffffff;
    }
    table, th, td
    {
      font-family:Arial, Helvetica, sans-serif;
      border: 1px solid black;
      text-align: right;
    }
"""

# Clear mpld3's default plugin list; the plugins wanted on the page are
# connected explicitly at the end of the script.
mpld3.plugins.DEFAULT_PLUGINS = []
fig = pylab.figure(figsize=[10, 5])
ax = fig.gca()

# HTML template of a segment tooltip, filled with (start, end, duration).
label = """<table>
             <tr><th>Start</th><td>%.0f</td></tr>
             <tr><th>End</th><td>%.0f</td></tr>
             <tr><th>Duration</th><td>%s</td></tr>
      </table>
    """

# Height of each row of rectangles; rows are spaced one unit apart.
row_height = .7

names = []
# -1 / 1 are placeholders meaning "no segment seen yet".
smin, smax = -1, 1
for row, seg_file in enumerate(sorted(args.segment_files, reverse=True)):
    y = row + .05
    name = get_name(seg_file)
    start, end = pycbc.events.start_end_from_segments(seg_file)

    if len(start) > 0:
        # remove duplicate segments, e.g. inspiral jobs with split banks
        unique_startend = numpy.array(list(set(zip(start, end))))
        start, end = unique_startend[:, 0], unique_startend[:, 1]

    dur = end - start
    total = timestr(
        abs(pycbc.events.start_end_to_segments(start, end).coalesce()))

    # Grow the overall time range to include this file's segments.
    if len(start) and (smin == -1 or start.min() < smin):
        smin = start.min()
    if len(end) and end.max() > smax:
        smax = end.max()

    # Remember where to write this row's caption (just above its rectangles).
    names.append((name, total, y + row_height + .1))

    patches = plot_segs(start, end, y=y, h=row_height)
    segments = zip(patches, start, end, dur)
    for idx, (patch, seg_start, seg_end, seg_dur) in enumerate(segments):
        html = label % (seg_start, seg_end, timestr(seg_dur))
        if idx == 0:
            # The first patch of each file uses mpld3's stock HTML tooltip;
            # the remaining ones use the pycbc Tooltip plugin.
            plugin = mpld3.plugins.PointHTMLTooltip(patch, [html], css=css)
        else:
            plugin = Tooltip(patch, [html], css=css)
        mpld3.plugins.connect(fig, plugin)

for name, total, label_y in names:
    ax.text(smin, label_y, "%s: %s" % (name, total))

# Captions were appended in order of increasing height, so the last one is
# the highest; leave a little room above it.
ax.set_ylim(0, names[-1][2] + 0.2)
ax.set_xlim(smin, smax)
ax.set_yticks([])
ax.set_xlabel('GPS Time (s)')

mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fontsize=14, fmt='10f'))
mpld3.plugins.connect(fig, mpld3.plugins.BoxZoom())
mpld3.plugins.connect(fig, MPLSlide())
mpld3.plugins.connect(fig, mpld3.plugins.Reset())

# Use a context manager so the output file is flushed and closed.
with open(args.output_file, 'w') as html_file:
    mpld3.save_html(fig, html_file)
