#!/usr/bin/env python3
import argparse
import logging

from sklearn.datasets import load_iris
from sklearn.datasets import load_boston


from interpretability_engine.interpretability_engine import InterpretabilityEngine
from interpretability_engine.routes import Routes
from interpretability_engine.utils import load_samples


def get_cli_args(argv=None):
    """Build the command line parser and parse the arguments.

    Exactly one of ``--deployment-url`` / ``--serving-url`` and exactly one of
    ``--swift-container`` / ``--samples-path`` must be supplied; argparse
    rejects the command line otherwise.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed options, with the attributes
        ``token``, ``url_deployment``, ``url_serving``,
        ``disable_ssl_verification``, ``model_id``, ``features``, ``method``,
        ``main_effects``, ``log_level`` (always upper case), ``output_file``,
        ``feature_names``, ``label_names``, ``swift_container``,
        ``swift_object`` and ``samples_path``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser(description='Interpretability Engine')

    parser.add_argument(
        '--token',
        action='store',
        dest='token',
        type=str,
        required=True
    )

    # The model location is given either directly or through the serving url.
    group_url = parser.add_mutually_exclusive_group(required=True)

    group_url.add_argument(
        '--deployment-url',
        action='store',
        dest='url_deployment',
        default=None,
        type=str,
        help="URL where your model is deployed."
    )

    group_url.add_argument(
        '--serving-url',
        action='store',
        dest='url_serving',
        default=None,
        type=str,
        help="If no deployment url is provided, you can provide the serving url with the model id."
    )

    parser.add_argument(
        '--disable-ssl-verification',
        action='store_true',
        dest='disable_ssl_verification',
        default=False
    )

    parser.add_argument(
        '--model-id',
        action='store',
        dest='model_id',
        type=str,
        help="If no deployment url is provided, you can provide the serving url with the model id."
    )

    parser.add_argument(
        '--features',
        action='store',
        nargs='+',
        type=int,
        dest='features',
        help="The indexes (or names, depend on how you export the model) for features used for interpretation."
    )

    parser.add_argument(
        '--method',
        action='store',
        dest='method',
        default='pdp',
        type=str,
        help="For now only pdp method is available."
    )

    # `--main-effects` used to be a store_true flag whose default was already
    # True, so the option could never be turned off. The flag is kept for
    # backward compatibility and `--no-main-effects` writes False to the same
    # destination.
    parser.add_argument(
        '--main-effects',
        action='store_true',
        dest='main_effects',
        default=True,
        help="Set main_effects to true (this is the default)."
    )

    parser.add_argument(
        '--no-main-effects',
        action='store_false',
        dest='main_effects',
        default=True,
        help="Set main_effects to false."
    )

    # The value is later resolved with getattr(logging, <NAME>): normalise the
    # case here and reject unknown names with a regular argparse error rather
    # than letting the lookup fail with a traceback.
    parser.add_argument(
        '--log-level',
        action='store',
        dest='log_level',
        default='INFO',
        type=str.upper,
        choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG', 'NOTSET'],
        help="Name of the logging level, case insensitive (default: INFO)."
    )

    parser.add_argument(
        '--output-file',
        action='store',
        dest='output_file',
        default=None,
        type=str,
        help="If specified the interpretation will be saved in a pdf file, else it will be directly displayed."
    )

    parser.add_argument(
        '--feature-names',
        action='store',
        nargs='+',
        dest='feature_names',
        default=None,
        type=str,
        help="If specified it will map the provided names with the indexes features. Useful when the name of the inputs/features are just number, it gives to have a clearer interpretation when displayed."
    )

    parser.add_argument(
        '--label-names',
        action='store',
        nargs='+',
        dest='label_names',
        default=None,
        type=str,
        help="If specified it will map the provided names with the indexes labels. Useful when the name of the outputs/labels are just number, it gives to have a clearer interpretation when displayed."
    )

    # The samples come either from a swift container or from a local file.
    group_dataset = parser.add_mutually_exclusive_group(required=True)

    group_dataset.add_argument(
        '--swift-container',
        action='store',
        dest='swift_container',
        default=None,
        type=str
    )

    # Registered on the parser itself, not in the exclusive group, so it can
    # be combined with `--swift-container`.
    parser.add_argument(
        '--swift-object',
        action='store',
        dest='swift_object',
        default=None,
        type=str
    )

    group_dataset.add_argument(
        '--samples-path',
        action='store',
        dest='samples_path',
        default=None,
        type=str,
        help="Path to csv file containing samples related to the model."
    )
    # TODO argument to list most important features

    return parser.parse_args(argv)

if __name__ == '__main__':
    # The command line is parsed first because it carries the log level
    # needed to configure logging.
    cli = get_cli_args()
    log_level = getattr(logging, cli.log_level.upper())
    logging.basicConfig(
        format='%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s',
        level=log_level,
    )

    verify_ssl = not cli.disable_ssl_verification

    # Use the deployment url when it was given; otherwise ask Routes for it
    # using the serving url and the model id.
    if cli.url_deployment:
        deployment_url = cli.url_deployment
    else:
        logging.debug("Try to find url_deployment from %s, %s", cli.url_serving, cli.model_id)
        deployment_url = Routes.get_url_deployment_model(
            token=cli.token,
            url_serving=cli.url_serving,
            model_id=cli.model_id,
            verify=verify_ssl,
        )

    logging.info("Url deployment found at %s", deployment_url)

    interpreter = InterpretabilityEngine(
        token=cli.token,
        url_deployment=deployment_url,
        verify=verify_ssl,
    )

    logging.info("Retrieve samples used for the interpretation")
    samples = load_samples(cli.samples_path, cli.swift_container, cli.swift_object)

    # Plotting is always requested; output_file, feature_names and
    # label_names are forwarded unchanged from the command line.
    interpreter.interpret(
        method=cli.method,
        X=samples,
        features=cli.features,
        main_effects=cli.main_effects,
        plot=True,
        output_file=cli.output_file,
        feature_names=cli.feature_names,
        label_names=cli.label_names,
    )
