#!/usr/bin/env python

import argparse
import glob
import pandas as pd
import seaborn as sb
from re import split
from matplotlib import pyplot

# Command-line interface: the directory holding the assess results, the file
# listing the rounds to plot, and the destination of the resulting figure.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i", "--input-directory",
    help="Path to input directory where asses results are",
    required=True,
)
parser.add_argument(
    "-r", "--rounds-file",
    help="File listing rounds ordered",
    required=True,
)
parser.add_argument(
    "-o", "--output-plot",
    help="Path for merged assessment plot. Extension determines format (pdf/png).",
    default='sccaf_assesment_accuracies.pdf',
)

# Parsed at module level: the rest of the script reads its settings from
# `args`.
args = parser.parse_args()


def atoi(text):
    """Return *text* as an int if it is made only of decimal digits.

    Any other string (including the empty string) is returned unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit`` because the latter
    is also true for characters such as superscripts ("²") or circled
    digits ("①") that ``int()`` refuses, which would raise ``ValueError``.
    """
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """
    Build a sort key that orders strings in human ("natural") order.

    The text is cut at every run of digits (the capturing group keeps the
    digit runs in the result) and each piece is passed through atoi, so
    numeric pieces are compared as numbers rather than character by
    character. Use it as alist.sort(key=natural_keys).

    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    pieces = split(r'(\d+)', text)
    return list(map(atoi, pieces))

# We need to read in vectors of assess results, merge them based on each
# vector's title, and then produce the violin plots per round. Each assess
# result can contain iterations for the same round, so they need to be
# merged first.


# We only need to keep rounds that are listed in the rounds file (one round
# label per line, no header). The labels are sorted in natural order so that
# e.g. "Round2" is plotted before "Round10".
rounds_df = pd.read_csv(args.rounds_file, header=None, names=['rounds'])
sorted_rounds = rounds_df['rounds'].tolist()
sorted_rounds.sort(key=natural_keys)
wanted_rounds = set(sorted_rounds)

# Read every assess result and keep the rows whose round is requested.
# Membership is tested exactly: a substring/regex match would let "Round1"
# match "Round10" and would misbehave on labels containing regex
# metacharacters. Filtering row by row also copes with empty result files.
# The file list is sorted so the merge order does not depend on the order in
# which the file system happens to return the matches.
selected_frames = []
assess_pattern = "{}/sccaf_assess_*.txt".format(args.input_directory)
for asses_file in sorted(glob.glob(assess_pattern)):
    df = pd.read_csv(asses_file)
    selected = df[df['Round'].isin(wanted_rounds)]
    if not selected.empty:
        selected_frames.append(selected)

# DataFrame.append no longer exists in pandas >= 2.0, so the frames are
# collected and merged once with concat. concat refuses an empty list, hence
# the explicit empty frame when nothing matched.
if selected_frames:
    df_merged = pd.concat(selected_frames, ignore_index=True)
else:
    df_merged = pd.DataFrame(columns=["Accuracy", "Type", "Round"])


# One violin per round, split by accuracy type ("Test" vs "CV"); cut=0 keeps
# the violins within the range of the observed accuracies.
fig, ax = pyplot.subplots(figsize=(9, 7))
pyplot.xticks(rotation=45)
sb.violinplot(
    x="Round",
    y="Accuracy",
    hue="Type",
    data=df_merged,
    order=sorted_rounds,
    cut=0,
    palette={'Test': '#3cb44b', 'CV': '#e6194b'},
).set_title("Test and cross validation accuracies per SCCAF round.")
pyplot.legend(loc='lower right')
# The output format (pdf/png) is chosen by matplotlib from the extension.
pyplot.savefig(args.output_plot)

