# coding: utf-8
#
#  This file is part of Sequana software
#
#  Copyright (c) 2020 - Sequana Development Team
#
#  Distributed under the terms of the 3-clause BSD license.
#  The full license is in the LICENSE file, distributed with this software.
#
#  website: https://github.com/sequana/sequana
#  documentation: http://sequana.readthedocs.io
#
##############################################################################
"""Module to write the RNAdiff (differential gene expression) report"""
import ast

import pandas as pd

from sequana.modules_report.base_module import SequanaBaseModule
from sequana.utils.datatables_js import DataTable


class RNAdiffModule(SequanaBaseModule):
    """Write the HTML report of an RNAdiff (differential gene expression)
    analysis.

    The report contains a textual summary, diagnostic plots, clustering plots,
    normalisation and statistical plots, and a table with all the results.
    Every plot is produced by the :class:`sequana.rnadiff.RNADiffResults`
    instance stored in :attr:`rnadiff`.
    """

    def __init__(self, data):
        """.. rubric:: constructor

        :param data: input handed over to
            :class:`sequana.rnadiff.RNADiffResults`. If it exposes a ``df``
            attribute, that dataframe is used for the summary and the final
            table; otherwise the dataframe of the results object is used.

        The file ``rnadiff.html`` is created as a side effect.
        """
        super().__init__()
        self.title = "RNAdiff"

        from sequana.rnadiff import RNADiffResults

        self.rnadiff = RNADiffResults(data)

        # Historical behaviour: use the dataframe carried by *data*. Fall back
        # on the results object so that an input without a ``df`` attribute
        # does not fail with an AttributeError.
        source = data if hasattr(data, "df") else self.rnadiff
        self.df = source.df.copy()
        # NOTE(review): dots are stripped from the column names, presumably
        # because they are not handled by the javascript table -- confirm.
        self.df.columns = [x.replace(".", "") for x in self.df.columns]

        import matplotlib
        import seaborn

        # nice layout for the report
        seaborn.set()
        try:
            self.create_report_content()
            self.create_html("rnadiff.html")
        finally:
            # seaborn.set() changes the global matplotlib settings: restore
            # them even if the report creation failed.
            matplotlib.rc_file_defaults()

    def create_report_content(self):
        """Reset :attr:`sections` and fill it with all the report sections."""
        self.sections = list()

        self.summary()
        self.add_plot_count_per_sample()
        self.add_cluster()
        self.add_dge()
        self.add_rnadiff_table()

    def _embed_plot(self, plot_method, style, *args, **kwargs):
        """Return the HTML code of an embedded PNG created by *plot_method*.

        :param plot_method: callable that draws on the current figure.
        :param style: CSS style given to the embedded image.
        :param args: positional arguments passed to *plot_method*.
        :param kwargs: keyword arguments passed to *plot_method*.
        """
        import pylab

        def plotter(filename):
            # Non-interactive mode and a cleared figure so that nothing is
            # inherited from a previous plot.
            pylab.ioff()
            pylab.clf()
            plot_method(*args, **kwargs)
            pylab.savefig(filename)
            pylab.close()

        return self.create_embedded_png(plotter, "filename", style=style)

    def summary(self):
        """Add the summary section (number of up and down regulated genes)."""
        S = self.rnadiff.summary()

        # Significant genes (adjusted p-value at most 0.05) whose absolute
        # log2 fold change is above 1.
        # NOTE(review): 0.05 is hard-coded here; check that it agrees with the
        # threshold used by RNADiffResults.summary().
        A = len(self.df.query("padj<=0.05 and log2FoldChange>1"))
        B = len(self.df.query("padj<=0.05 and log2FoldChange<-1"))

        content = """
<p>The final Differential Gene Expression (DGE) analysis
led to {} up and {} down genes (total {}). Filtering out the log2 fold change
below 1 (or -1) gives {} up and {} down (total of {})</p>""".format(
            S.loc['up'][0], S.loc['down'][0], S.loc['all'][0], A, B, A + B)

        self.sections.append({
            'name': "Summary",
            'anchor': 'filters_option',
            'content': content
        })

    def add_cluster(self):
        """Add the clustering section (dendogram and PCA)."""
        style = "width:65%"

        html_dendogram = """<p>The following image shows a hierarchical
clustering of the whole sample set. The data was log-transformed first.
</p>{}<hr>""".format(
            self._embed_plot(self.rnadiff.plot_dendogram, style))

        html_pca = """<p>The experiment variability is also represented by a
principal component analysis as shown here below. The two main components are
represented </p>{}<hr>""".format(
            self._embed_plot(self.rnadiff.plot_pca, style, 2))

        self.sections.append({
            "name": "Clusterisation",
            "anchor": "clusterisation",
            "content": html_dendogram + html_pca
        })

    def add_plot_count_per_sample(self):
        """Add the diagnostic plots computed from the read counts."""
        style = "width:65%"

        html1 = """<p>The following image shows the total number of counted reads
for each sample. We expect counts to be similar within conditions. They may be
different across conditions. Variation may happen (different rRNA contamination
levels, library concentrations, etc).</p>{}<hr>""".format(
            self._embed_plot(self.rnadiff.plot_count_per_sample, style))

        html_null = """<p>The next image shows the percentage of features with no
read count in each sample. Features with null read counts in all samples are not
taken into account in the analysis (black dashed line). fold-change and p-values
will be set to NA in the final results</p> {}<hr>""".format(
            self._embed_plot(
                self.rnadiff.plot_percentage_null_read_counts, style))

        html_density = """<p>In the following figure, we show the distribution
of read counts for each sample (log10 scale). We expect replicates to behave in
a similar fashion. The mode depends on the biological conditions and organism
considered.</p> {}<hr>""".format(
            self._embed_plot(self.rnadiff.plot_density, style))

        html_feature = """<p>In the following figure, we show for each sample the feature that
captures the highest proportion of the reads considered. This should not impact
the DESeq2 normalization. We expect consistency across samples within a single
condition</p> {}<hr>""".format(
            self._embed_plot(self.rnadiff.plot_feature_most_present, style))

        self.sections.append({
            "name": "Diagnostic plots",
            "anchor": "diagnostic_plots",
            "content": html1 + html_null + html_density + html_feature
        })

    def add_dge(self):
        """Add the normalisation section and the statistical plots section
        (p-value histograms and volcano plots).
        """
        style = "width:45%"

        # ---- normalisation: raw versus normalised counts
        html_boxplot = """<p>The following boxplots show the raw counts (first
image) and the normalised counts (second image).</p>"""
        img1 = self._embed_plot(self.rnadiff.boxplot_rawdata, style)
        img2 = self._embed_plot(self.rnadiff.boxplot_normeddata, style)

        self.sections.append({
            "name": "Normalisation",
            "anchor": "normalisation",
            "content": html_boxplot + img1 + img2 + "<hr>"
        })

        # ---- p-values and volcano plots
        img1 = self._embed_plot(self.rnadiff.plot_pvalue_hist, style)
        img2 = self._embed_plot(self.rnadiff.plot_padj_hist, style)

        html_volcano = """<p>The volcano plot here below shows the differentially
expressed features in red. A volcano plot represents the log of the adjusted P
value as a function of the log ratio of differential expression. </p>"""
        img3 = self._embed_plot(self.rnadiff.plot_volcano, style)

        from pylab import log10

        # Largest -log10(adjusted p-value); 0 if no adjusted p-value is
        # available so that an empty selection does not raise an error.
        padj = self.rnadiff.df.padj.dropna()
        M = (-log10(padj)).max() if len(padj) else 0

        # NOTE(review): the version with broken axes is only added when very
        # small adjusted p-values are present, presumably because they
        # stretch the y-axis of the standard volcano plot -- confirm.
        if M > 20:
            img4 = self._embed_plot(
                self.rnadiff.plot_volcano, style, add_broken_axes=True)
        else:
            img4 = ""

        description = """<p>The distribution of raw p-values computed by the statistical test 
is expected to be a mixture of a uniform distribution on [0, 1] and a peak
around 0 corresponding to the differentially expressed features. </p>"""

        # interactive version of the volcano plot
        fig = self.rnadiff.plot_volcano(plotly=True)
        plotly = fig.to_html(include_plotlyjs='cdn')

        self.sections.append({
            "name": "Diagnostic plots",
            "anchor": "dge_plots",
            "content": (description + img1 + img2 + "<hr>" + html_volcano
                        + img3 + img4 + "<hr>" + plotly)
        })

    def add_rnadiff_table(self):
        """Add the section with the table of all the DGE results."""
        datatable = DataTable(self.df, 'rnadiff')
        # set options
        datatable.datatable.datatable_options = {
            'scrollX': 'true',
            'pageLength': 20,
            'scrollCollapse': 'true',
            'dom': 'Bfrtip',
            'buttons': ['copy', 'csv']
        }
        js = datatable.create_javascript_function()
        html_tab = datatable.create_datatable(float_format='%.3f')
        self.sections.append({
            'name': "Tables",
            'anchor': 'stats',
            'content':
                "<p>This table gives all DGE results</p>{}{}"
                .format(js, html_tab)
        })
