#!/usr/bin/env python

import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import scanpy as sc
import SCCAF as sf
import logging
import argparse
from re import sub, match
from sys import exit

matplotlib.use('Agg')


def non_empty_string(x):
    """Argparse ``type`` callable that rejects missing or zero-length values.

    Returns ``x`` unchanged when it is not ``None`` and has a non-zero length.

    Raises:
        argparse.ArgumentTypeError: if ``x`` is ``None`` or has length zero.
    """
    if x is None or len(x) == 0:
        raise argparse.ArgumentTypeError("Value provided cannot be empty")
    return x


# Command-line interface of the script. The parsed options are stored in the
# module-level ``args`` namespace, which the rest of the script reads.
parser = argparse.ArgumentParser()
# The input file is read unconditionally further down, so make argparse
# report a missing value instead of failing later inside the reader.
parser.add_argument("-i", "--input-file", required=True,
                    help="Path to input in AnnData or Loom")
parser.add_argument("-o", "--output-file", default='output.h5',
                    help="Path for output file")
parser.add_argument("-e", "--external-clustering-tsv",
                    help="Path to external clustering in TSV")
parser.add_argument("--optimise",
                    help="Not only run assessment, but also optimise the clustering",
                    action="store_true")
parser.add_argument("--skip-assessment", action="store_true",
                    help="If --optimise given, then this allows to optionally skip the original "
                         "assessment of the clustering")
# ``non_empty_string`` rejects an explicitly empty slot name at parse time.
parser.add_argument("-s", "--slot-for-existing-clustering",
                    help="Use clustering pre-computed in the input object, available in this slot of the object.",
                    type=non_empty_string)
parser.add_argument("-r", "--resolution", default=1.5, type=float,
                    help="Resolution for running clustering when no internal or external clustering is given.")
parser.add_argument("-a", "--min-accuracy", type=float, default=0.955,
                    help="Accuracy threshold for convergence of the optimisation procedure.")
parser.add_argument("-p", "--prefix",
                    help="Prefix for clustering labels", default="L1")
parser.add_argument("-c", "--cores",
                    help="Number of processors to use", type=int, default=1)
parser.add_argument("-u", "--undercluster-boundary",
                    help="Label for the underclustering boundary to use in the optimisation. "
                         "It should be present in the annData object")
parser.add_argument("--produce-rounds-summary", action="store_true",
                    help="Set to produce summary files for each round of optimisation"
                    )
parser.add_argument("--optimisation-plots-output", help="PDF file output path for all optimisation plots.")
# The first help fragment now ends with a space so the rendered help no
# longer reads "calculation.Higher".
parser.add_argument("--conf-sampling-iterations", action="store", type=int, default=3,
                    help="How many samples are taken of cells per clusters prior to the confusion matrix calculation. "
                         "Higher numbers will produce more stable results in terms of rounds, but will take longer. "
                         "Default: 3."
                    )

args = parser.parse_args()

logging.basicConfig(level=logging.INFO)

# The assessment needs cluster labels, so one of the two clustering sources
# must be given unless the assessment is skipped.
if (not args.skip_assessment) and not (args.external_clustering_tsv or args.slot_for_existing_clustering):
    logging.error("Either --external-clustering-tsv or --slot-for-existing-clustering needs to be set "
                  "if the assessment is to be done.")
    exit(1)

ad = sc.read(args.input_file)
logging.info("Read ann data file: DONE")

# ``y`` holds the cluster label of every cell, aligned to ``ad.obs``; it is
# ``None`` when no clustering was supplied (only possible with --skip-assessment).
if args.external_clustering_tsv:
    # First TSV column is used as the cell identifier (index), second as the label.
    cluster = pd.read_table(args.external_clustering_tsv, usecols=[0, 1], index_col=0, header=0)
    cluster.columns = ['CLUSTER']
    # Left-join against an empty-column view of obs: only the index is needed,
    # and this keeps an existing 'CLUSTER' column in obs from being suffixed
    # by pandas (which would make the ['CLUSTER'] lookup fail).
    y = pd.merge(ad.obs[[]], cluster, how='left', left_index=True, right_index=True)['CLUSTER']
    unlabelled = int(y.isna().sum())
    if unlabelled:
        logging.warning("%d cells of the input have no cluster label in %s",
                        unlabelled, args.external_clustering_tsv)
elif args.slot_for_existing_clustering is not None:
    if args.slot_for_existing_clustering not in ad.obs:
        logging.error("Provided slot name '{}' is not in the provided AnnData object, "
                      "make sure that you provide a valid clustering label inside the "
                      "AnnData object provided.".format(args.slot_for_existing_clustering))
        exit(1)
    y = ad.obs[args.slot_for_existing_clustering]
else:
    # No clustering supplied. The up-front check guarantees --skip-assessment
    # here, so labels are only needed if the optimisation is asked to start
    # from them, which the script does when --resolution is zero.
    if args.optimise and not args.resolution:
        logging.error("Either a clustering (--external-clustering-tsv or --slot-for-existing-clustering) "
                      "or a non-zero --resolution is needed as starting point for --optimise.")
        exit(1)
    y = None

# Use the matrix stored in ``ad.raw`` when the object has one, else ``ad.X``.
raw = ad.raw
X = raw.X if raw is not None else ad.X

if not args.skip_assessment:
    y_prob, y_pred, y_test, clf, cvsm, acc = sf.SCCAF_assessment(X, y, n_jobs=args.cores)
    logging.info("First assesment: DONE")
    aucs = sf.plot_roc(y_prob, y_test, clf, cvsm=cvsm, acc=acc)
    plt.savefig('roc-curve.png')
    plt.close()


def atoi(text):
    """Return ``int(text)`` when ``text.isdigit()`` is true, else ``text`` unchanged."""
    if text.isdigit():
        return int(text)
    return text


def extract_round_number(text):
    """
    Sort key for round labels.

    A trailing ``<anything>_Round<digits>`` in ``text`` is replaced by just
    the digits. If the resulting string consists only of digits it is
    returned as an ``int``; otherwise the substituted string is returned
    (which is ``text`` itself when the pattern does not match).
    """
    stripped = sub(r'.*_Round(\d+)$', r'\1', text)
    if stripped.isdigit():
        return int(stripped)
    return stripped


# Optional optimisation stage: build a starting clustering, run the SCCAF
# optimisation, re-assess the resulting clustering and write the outputs.
if args.optimise:
    round0_key = '{}_Round0'.format(args.prefix)
    # NOTE(review): --resolution defaults to 1.5, so this branch is taken
    # unless the user passes 0; a clustering supplied via -e/-s is then only
    # used for the first assessment, not as starting point. The --resolution
    # help text suggests otherwise -- confirm the intended behaviour.
    if args.resolution:
        sc.tl.louvain(ad, resolution=args.resolution, key_added=round0_key)
        logging.info("Run louvain for starting point: DONE")
    else:
        # We use the predefined clustering (either internal or external).
        ad.obs[round0_key] = y

    # Plots are only produced when a PDF output path was requested.
    backend = None
    if args.optimisation_plots_output:
        from matplotlib.backends.backend_pdf import PdfPages
        backend = PdfPages(args.optimisation_plots_output)

    try:
        optimise_kwargs = dict(min_acc=args.min_accuracy, ad=ad, plot=backend is not None,
                               n_jobs=args.cores, mplotlib_backend=backend,
                               c_iter=args.conf_sampling_iterations)
        # Only pass low_res when given, so the library default applies otherwise.
        if args.undercluster_boundary:
            optimise_kwargs['low_res'] = args.undercluster_boundary
        sf.SCCAF_optimize_all(**optimise_kwargs)
        logging.info("Run optimise: DONE")

        y_prob, y_pred, y_test, clf, cvsm, acc = sf.SCCAF_assessment(X, ad.obs['{}_result'.format(args.prefix)],
                                                                     n_jobs=args.cores)
        logging.info("Posterior assessment: DONE")
        aucs = sf.plot_roc(y_prob, y_test, clf, cvsm=cvsm, acc=acc)
        plt.savefig('optim.png')
        # Release the figure, as is done after the first assessment plot.
        plt.close()
        ad.write(args.output_file)

        if args.produce_rounds_summary:
            # Require the '_Round<digits>' suffix that extract_round_number
            # parses, so every sort key is an int and the sort cannot mix
            # ints with strings.
            rounds = sorted(
                (round_key for round_key in ad.obs_keys()
                 if round_key.startswith(args.prefix) and match(r'.*_Round\d+$', round_key)),
                key=extract_round_number,
            )

            with open("rounds.txt", 'w') as f:
                f.writelines("{}\n".format(item) for item in rounds)
    finally:
        # Close the PDF even if a step above failed, so the file handle is
        # not left open.
        if backend is not None:
            backend.close()

    logging.info("Write output: DONE")

