import functools
import json
import os
import re
import uuid
from collections import Counter
from typing import List

import pycrfsuite
from sklearn.model_selection import train_test_split
from texta_mlp.mlp import MLP
from texta_tools.embedding import W2VEmbedding

from .config import CRFConfig
from .feature_extraction import sent2features, sent2labels, FEATURE_LAYER_WORD, FEATURE_LAYER_LEMMA, FEATURE_LAYER_POS
from .tagging_report import TaggingReport
from . import exceptions


def check_model_loaded(func):
    """
    Decorator for callables that require a loaded Tagger model.

    The first positional argument of the wrapped callable (``self`` for methods)
    must expose a ``model`` attribute. When that attribute is falsy, the wrapped
    callable is not invoked and ModelNotLoadedError is raised instead.

    :param func: Callable to guard.
    :return: Wrapper that keeps the name and docstring of ``func``.
    """
    # functools.wraps keeps __name__/__doc__ of the wrapped callable intact,
    # so decorated methods stay introspectable
    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        if not args[0].model:
            raise exceptions.ModelNotLoadedError()
        return func(*args, **kwargs)
    return func_wrapper



class CRFExtractor:
    """
    CRF-based sequence tagger that is trained on and applied to MLP output documents.

    The extractor reads the word, lemma and POS tag layers listed in MLP_FIELDS,
    uses entries of the document's "texta_facts" list as training labels and
    emits its predictions as fact dicts.
    """
    MLP_FIELDS = [FEATURE_LAYER_WORD, FEATURE_LAYER_LEMMA, FEATURE_LAYER_POS]

    def __init__(self,
        description: str = "My CRF Extractor",
        embedding: W2VEmbedding = None,
        config: CRFConfig = None
    ):
        """
        :param str description: Human-readable name, returned by ``str()``.
        :param embedding: Optional embedding passed on to feature extraction.
        :param config: Extractor configuration. A fresh CRFConfig is created when omitted.
        """
        self.description = description
        # create the default config per instance; a default built in the signature
        # would be one object shared (and mutated) across all extractors
        self.config = config if config is not None else CRFConfig()
        self.embedding = embedding
        # pycrfsuite tagger, set by train() or load()
        self.model = None
        # TaggingReport of the latest train() call
        self.report = None


    def __str__(self):
        """
        Returns the description of the extractor.
        """
        return self.description


    @staticmethod
    def _validate_and_extract_layers(mlp_document: dict, field_name: str):
        """
        Validates the MLP document structure and splits its layers into sentences and tokens.

        :param dict mlp_document: Document containing the MLP output.
        :param str field_name: Dot-separated path to the subdocument holding the MLP layers.
        :return: Dict mapping each field in MLP_FIELDS to a list of sentences, each a list of tokens.
        :raises InvalidInputError: If the path can not be resolved, the resolved value
            is not a dict or one of the MLP_FIELDS is missing from it.
        """
        # walk the path one level at a time, descending into each component
        mlp_subdoc = mlp_document
        for field_path_component in field_name.split("."):
            if not isinstance(mlp_subdoc, dict) or field_path_component not in mlp_subdoc:
                raise exceptions.InvalidInputError(
                    f"Invalid field_name param for the document. Field component {field_path_component} not found in document!"
                )
            mlp_subdoc = mlp_subdoc[field_path_component]
        # check the resulting subdocument structure
        if not isinstance(mlp_subdoc, dict):
            raise exceptions.InvalidInputError("Document is not a dict!")
        layers = {}
        # presence of the actual content is verified per layer below
        for field in CRFExtractor.MLP_FIELDS:
            if field not in mlp_subdoc:
                raise exceptions.InvalidInputError(f"Field '{field}' not present in the document!")
            # the POS tag layer uses "LBR" as sentence break marker, other layers a newline
            separator = " LBR " if field == FEATURE_LAYER_POS else " \n "
            layers[field] = [sentence.split(" ") for sentence in mlp_subdoc[field].split(separator)]
        return layers


    def _parse_mlp_document(self, mlp_document: dict, add_labels: bool = True, mlp_field: str = "text"):
        """
        Parses MLP output document. Extracts tokens, lemmas, and POS tags.

        This is a generator, validation errors surface when iteration starts.

        :param dict mlp_document: Document containing the MLP output.
        :param bool add_labels: If True, labels are added from the document's "texta_facts"
            and one list of (token, lemma, pos, label) tuples is yielded per sentence.
            If False, (sentence index, list of (token, lemma, pos) tuples) pairs are yielded.
        :param str mlp_field: Dot-separated path to the subdocument holding the MLP layers.
        :raises InvalidInputError: If the document does not pass validation.
        """
        # TODO: add morph/syntactic info if possible (need to improve MLP for that)
        layers = self._validate_and_extract_layers(mlp_document, mlp_field)
        words = layers[FEATURE_LAYER_WORD]
        lemmas = layers[FEATURE_LAYER_LEMMA]
        pos_tags = layers[FEATURE_LAYER_POS]
        if add_labels:
            # documents without any facts are valid and yield only "0" labels
            texta_facts = mlp_document.get("texta_facts") or []
            # group the relevant facts by sentence once instead of rescanning them per sentence
            facts_by_sentence = {}
            for fact in texta_facts:
                if fact["fact"] in self.config.labels:
                    facts_by_sentence.setdefault(fact["sent_index"], []).append(fact)
            for i, sentence in enumerate(words):
                # "0" marks tokens that are not covered by any fact
                sentence_labels = ["0"] * len(sentence)
                sent_str = " ".join(sentence)
                for fact in facts_by_sentence.get(i, []):
                    num_tokens_match = len(fact["str_val"].split(" "))
                    for span in json.loads(fact["spans"]):
                        # span start is a character offset into the space-joined sentence,
                        # translate it into a token index
                        num_tokens_before_match = len([token for token in sent_str[:span[0]].split(" ") if token])
                        # clip to the sentence so a misaligned fact can not raise IndexError
                        match_end = min(num_tokens_before_match + num_tokens_match, len(sentence))
                        for j in range(num_tokens_before_match, match_end):
                            sentence_labels[j] = fact["fact"]
                yield [(token, lemmas[i][j], pos_tags[i][j], sentence_labels[j]) for j, token in enumerate(sentence)]
        else:
            for i, sentence in enumerate(words):
                yield i, [(token, lemmas[i][j], pos_tags[i][j]) for j, token in enumerate(sentence)]


    def train(self, mlp_documents: List[dict], mlp_field: str = "text", save_path: str = "my_crf_model"):
        """
        Trains & saves the model.
        Model has to be saved by the crfsuite package because it contains C++ bindings.

        :param mlp_documents: MLP output documents with "texta_facts" used as labels.
        :param str mlp_field: Dot-separated path to the subdocument holding the MLP layers.
        :param str save_path: Path the trained model file is written to.
        :return: Tuple of (TaggingReport for the held-out sentences, save_path).
        """
        # a list containing labelled, tokenized sentences of all documents
        token_documents = [
            sentence
            for mlp_document in mlp_documents
            for sentence in self._parse_mlp_document(mlp_document, mlp_field=mlp_field)
        ]
        # split dataset
        train_documents, test_documents = train_test_split(token_documents, test_size=self.config.test_size)
        # featurize sequences
        X_train = [sent2features(s, self.config, self.embedding) for s in train_documents]
        y_train = [sent2labels(s) for s in train_documents]
        X_test = [sent2features(s, self.config, self.embedding) for s in test_documents]
        y_test = [sent2labels(s) for s in test_documents]
        # create trainer
        trainer = pycrfsuite.Trainer(verbose=self.config.verbose)
        # feed data to trainer
        for xseq, yseq in zip(X_train, y_train):
            trainer.append(xseq, yseq)
        # set trainer params
        trainer.set_params({
            'c1': self.config.c1, # coefficient for L1 penalty
            'c2': self.config.c2, # coefficient for L2 penalty
            'max_iterations': self.config.num_iter, # stop earlier
            'feature.possible_transitions': True # include transitions that are possible, but not observed
        })
        # train & save the model
        trainer.train(save_path)
        # load tagger model for validation
        tagger = pycrfsuite.Tagger()
        tagger.open(save_path)
        # evaluate model on the held-out sentences
        y_pred = [tagger.tag(xseq) for xseq in X_test]
        report = TaggingReport(y_test, y_pred, self.config.labels)
        # keep model & report on the instance
        self.model = tagger
        self.report = report
        return report, save_path


    @check_model_loaded
    def tag(self, mlp_document: dict, field_name: str = "text"):
        """
        Tags input MLP document.

        :param dict mlp_document: Document containing the MLP output.
        :param str field_name: Dot-separated path to the subdocument holding the MLP layers.
        :return: List of fact dicts with keys "fact", "str_val", "sent_index" and "spans".
        :raises ModelNotLoadedError: If no model has been trained or loaded.
        :raises InvalidInputError: If the document does not pass validation.
        """
        # parse the field the caller asked for, not the default one
        seqs_to_predict = list(self._parse_mlp_document(mlp_document, add_labels=False, mlp_field=field_name))
        # tokens of each sentence, taken from the parsed sequences so that
        # dotted field names are handled the same way as in parsing
        sentences = [[token_info[0] for token_info in seq] for _, seq in seqs_to_predict]
        output = []
        # predict on each sentence
        for i, seq_to_predict in seqs_to_predict:
            features_to_predict = sent2features(seq_to_predict, self.config, self.embedding)
            result = self.model.tag(features_to_predict)
            # translate predicted labels into facts
            output.extend(self._process_tag_output(result, sentences, i))
        return output


    @check_model_loaded
    def get_features(self, n=20):
        """
        Returns the state features of the model with the highest and lowest weights.

        :param int n: Number of features to return for each end of the ranking.
        :return: Dict with "positive" (the n highest-weighted features) and "negative"
            (the n lowest-weighted features), both ordered by descending weight.
        :raises ModelNotLoadedError: If no model has been trained or loaded.
        """
        # a slice with -0 would return the whole ranking, so handle it explicitly
        if n <= 0:
            return {"positive": [], "negative": []}
        info = self.model.info()
        # rank once and slice both ends
        ranked = Counter(info.state_features).most_common()
        return {
            "positive": ranked[:n],
            "negative": ranked[-n:]
        }


    def _process_tag_output(self, tokens: List[str], sentences: List[List[str]], sent_index: int):
        """
        Translates predicted labels of a sentence into facts.

        Consecutive tokens sharing the same label from config.labels form one entity.
        This is a generator yielding one fact dict per entity.

        :param tokens: Predicted labels, one per token of the sentence.
        :param sentences: Tokenized sentences of the document.
        :param int sent_index: Index of the tagged sentence in ``sentences``.
        """
        sentence = sentences[sent_index]
        # character offset of each token in the space-joined sentence
        token_starts = []
        offset = 0
        for token in sentence:
            token_starts.append(offset)
            offset += len(token) + 1
        # collect entities as (label, first token index, end token index (exclusive))
        entities = []
        entity_type = None
        entity_start = None
        for i, label in enumerate(tokens):
            # close the open entity when the labelled run ends or the label changes,
            # so adjacent entities of different types are not merged
            if entity_type is not None and label != entity_type:
                entities.append((entity_type, entity_start, i))
                entity_type = None
            if entity_type is None and label in self.config.labels:
                entity_type = label
                entity_start = i
        if entity_type is not None:
            entities.append((entity_type, entity_start, len(tokens)))
        # transform output to facts
        for fact_name, first, end in entities:
            # span is derived from the token positions, so it points at the tagged
            # occurrence only and never at other matches of the same string
            span_start = token_starts[first]
            span_end = token_starts[end - 1] + len(sentence[end - 1])
            fact = {
                "fact": fact_name,
                "str_val": " ".join(sentence[first:end]),
                "sent_index": sent_index,
                "spans": json.dumps([[span_start, span_end]])
            }
            yield fact


    def load(self, file_path: str):
        """
        Loads CRF model from disk.
        :param str file_path: Path to the model file.
        :return: True after the model has been opened.
        """
        tagger = pycrfsuite.Tagger()
        tagger.open(file_path)
        self.model = tagger
        return True


    def load_django(self, crf_django_object):
        """
        Loads model file using Django model object. This method is used in Django only!
        :param crf_django_object: Django model object of the Extractor.
        :return: True after the model has been opened.
        :raises ModelLoadFailedError: If reading the object or loading the model fails.
        """
        try:
            path = crf_django_object.model.path
            # retrieve tagger info
            self.description = crf_django_object.description
            # load model
            return self.load(path)
        except Exception as e:
            # keep the original error as the cause; KeyboardInterrupt/SystemExit pass through
            raise exceptions.ModelLoadFailedError() from e
