#
#  This file is part of Sequana software
#
#  Copyright (c) 2020 - Sequana Development Team
#
#  Distributed under the terms of the 3-clause BSD license.
#  The full license is in the LICENSE file, distributed with this software.
#
#  website: https://github.com/sequana/sequana
#  documentation: http://sequana.readthedocs.io
#
##############################################################################
"""Module to write KEGG enrichment report"""
import ast
import os
import sys

from sequana.lazy import pandas as pd
from sequana.lazy import pylab

from sequana.modules_report.base_module import SequanaBaseModule
from sequana.utils.datatables_js import DataTable
from sequana.enrichment.kegg import KEGGPathwayEnrichment

from easydev import Progress

import colorlog

logger = colorlog.getLogger(__name__)

from sequana.utils import config


class ModuleKEGGEnrichment(SequanaBaseModule):
    """Write the HTML report of a KEGG pathway enrichment analysis.

    The report is built from lists of up/down-regulated genes (the output of
    an RNADiff analysis) and is written to ``enrichment.html`` as soon as the
    object is created. It contains three sections: a summary of the filters,
    the KEGG enrichment results per gene category ("down", "up", "all") and
    the command used to produce the report.
    """

    #: Default enrichment settings. Values given by the caller override these
    #: keys; missing keys fall back to the values below.
    _DEFAULT_ENRICHMENT_PARAMS = {
        "padj": 0.05,
        "log2_fc": 3,
        "nmax": 15,
        "max_entries": 3000,
        "kegg_background": None,
        "mapper": None,
        "preload_directory": None,
        "plot_logx": True,
    }

    def __init__(
        self,
        gene_lists,
        kegg_name,
        dataframe,
        enrichment_params=None,
        command="",
    ):
        """.. rubric:: constructor

        :param gene_lists: dict of gene lists; the "up" and "down" entries are
            counted in the summary. The whole dict is forwarded to
            :class:`KEGGPathwayEnrichment`.
        :param kegg_name: KEGG organism name used for the enrichment.
        :param dataframe: differential expression results forwarded to
            ``KEGGPathwayEnrichment.save_pathway`` to colour pathway genes
            (presumably by log2 fold change, as described in the report).
        :param enrichment_params: optional dict of settings (see
            :attr:`_DEFAULT_ENRICHMENT_PARAMS`). Missing keys use the defaults.
        :param command: command line shown in the "Info" section.

        Exits the process (``sys.exit(1)``) if ``preload_directory`` is set but
        does not exist.
        """
        super().__init__()
        self.title = "Enrichment"

        self.command = command
        self.gene_lists = gene_lists

        # Build a fresh dict so neither the class defaults nor the caller's
        # dict is shared or mutated, and partial dicts are accepted.
        params = dict(self._DEFAULT_ENRICHMENT_PARAMS)
        if enrichment_params is not None:
            params.update(enrichment_params)
        self.enrichment_params = params

        self.data = dataframe
        self.organism = kegg_name

        pathname = self.enrichment_params["preload_directory"]
        if pathname and not os.path.exists(pathname):
            logger.error(f"{pathname} does not exist")
            sys.exit(1)

        self.nmax = self.enrichment_params["nmax"]

        self.ke = KEGGPathwayEnrichment(
            self.gene_lists,
            self.organism,
            mapper=self.enrichment_params["mapper"],
            background=self.enrichment_params["kegg_background"],
            preload_directory=self.enrichment_params["preload_directory"],
        )

        self.create_report_content()
        self.create_html("enrichment.html")

    def create_report_content(self):
        """Populate ``self.sections`` with the summary, KEGG and info sections."""
        self.sections = list()
        self.summary()
        self.add_kegg()
        self.sections.append({"name": "3 - Info", "anchor": "command", "content": self.command})

    def summary(self):
        """Add the "1 - Summary" section describing the input filters."""
        total_up = len(self.gene_lists["up"])
        total_down = len(self.gene_lists["down"])
        total = total_up + total_down
        log2fc = self.enrichment_params["log2_fc"]
        padj = self.enrichment_params["padj"]

        self.sections.append(
            {
                "name": "1 - Summary",
                "anchor": "filters_option",
                "content": f"""

<p>In the following sections, you will find the KEGG Pathway enrichment.
The input data for this analysis is the output of the RNADiff
analysis where adjusted p-values above {padj} are excluded. Moreover, we removed
candidates with log2 fold change below {log2fc}. Using these filters, the list of
differentially expressed genes is made of {total_up} up and {total_down} down genes (total {total})</p>
<p> In the following plots you can find the first KEGG Pathways that are enriched, keeping a
maximum of {self.nmax} pathways. </p>

<p>The KEGG name used is {self.organism}.</p>

""",
            }
        )

    def add_kegg(self):
        """Add the "2 - KEGG" section with one numbered sub-section per category.

        For each of the "down", "up" and "all" categories, the sub-section
        embeds a bar plot, a scatter plot, a table of enriched pathways and a
        gallery of pathway images. When nothing is enriched, only the count is
        reported.
        """
        logger.info("Enrichment module: kegg term")
        style = "width:45%"

        logger.info(f"Saving all pathways in kegg_pathways/{self.organism}")
        self.ke.export_pathways_to_json()

        parts = []
        # Number sub-sections 2.1, 2.2, 2.3 (previously all were labelled 2.1).
        for index, category in enumerate(["down", "up", "all"], start=1):
            # barplot() returns the enriched pathways kept for this category
            # (presumably at most nmax rows).
            df = self.ke.barplot(category, nmax=self.nmax)
            n_enriched = len(df)

            if n_enriched:
                img_barplot = self.create_embedded_png(self.plot_barplot, "filename", style=style, category=category)
                img_scatter = self.create_embedded_png(self.plot_scatter, "filename", style=style, category=category)
                js_table, html_table, fotorama = self.get_table(category)
            else:
                img_barplot = img_scatter = js_table = html_table = fotorama = ""

            parts.append(
                f"""
<h3>2.{index} - KEGG pathways enriched in {category} regulated genes</h3>
<p>{n_enriched} KEGG pathways are found enriched in {category} regulated genes</p>
<br>
{img_barplot}
{img_scatter}
<hr>
{js_table} {html_table}
<hr>
<p>Here below are the pathways with gene colored according to their fold change.
Blue colors are for down-regulated genes and Orange are for up-regulated genes. 
(Note that absolute log2 fold change above 4 are clipped to 4; So a gene with a
log2 fold change of 4 or 40 will have the same darkest color.). </p>
{fotorama}

"""
            )
        self.sections.append({"name": "2 - KEGG", "anchor": "kegg", "content": "".join(parts)})

    def plot_barplot(self, filename, category=None):
        """Draw the enrichment bar plot for *category* and save it to *filename*."""
        self.ke.barplot(category, nmax=self.nmax)
        pylab.savefig(filename)

    def plot_scatter(self, filename, category=None):
        """Draw the enrichment scatter plot for *category* and save it to *filename*."""
        self.ke.scatterplot(category, nmax=self.nmax)
        pylab.savefig(filename)

    def get_table(self, category):
        """Build the pathway table and the pathway image gallery for *category*.

        Each enriched pathway image is saved as ``<pathway_id>.png`` in
        ``config.output_dir``.

        :return: tuple ``(js_table, html_table, fotorama)`` of HTML snippets.
            The tuple holds three empty strings when no pathway is enriched, so
            callers can always unpack it.
        """
        df = self.ke.barplot(category, nmax=self.nmax)
        if not len(df):
            return ("", "", "")

        # assign() returns a new frame, so the results held by self.ke are not
        # modified by the extra column (and no SettingWithCopyWarning occurs).
        links = [f"https://www.genome.jp/dbget-bin/www_bget?path:{x}" for x in df["pathway_id"]]
        df = df.assign(links=links)[
            [
                "pathway_id",
                "name",
                "size",
                "Overlap",
                "P-value",
                "Adjusted P-value",
                "Genes",
                "links",
            ]
        ]

        # Save pathway images and build the gallery. This module's logger is
        # quieted while saving, then restored so the change does not leak.
        previous_level = logger.level
        logger.setLevel("WARNING")
        try:
            pb = Progress(len(df))
            files = []
            for count, pathway_id in enumerate(df["pathway_id"], start=1):
                png_name = f"{pathway_id}.png"
                self.ke.save_pathway(pathway_id, self.data, filename=os.path.join(config.output_dir, png_name))
                files.append(png_name)
                pb.animate(count)
        finally:
            logger.setLevel(previous_level)
        fotorama = self.add_fotorama(files, width=800)

        datatable = DataTable(df, f"kegg_{category}")
        datatable.datatable.set_links_to_column("links", "pathway_id")
        datatable.datatable.datatable_options = {
            "scrollX": "true",
            "pageLength": 20,
            "scrollCollapse": "true",
            "dom": "Bfrtip",
            "buttons": ["copy", "csv"],
        }
        js_table = datatable.create_javascript_function()
        html_table = datatable.create_datatable(float_format="%E")

        return (js_table, html_table, fotorama)
