#
#  This file is part of Sequana software
#
#  Copyright (c) 2016-2023 - Sequana Development Team
#
#  File author(s): Sequana team
#
#  Distributed under the terms of the 3-clause BSD license.
#  The full license is in the LICENSE file, distributed with this software.
#
#  website: https://github.com/sequana/sequana
#  documentation: http://sequana.readthedocs.io
#
##############################################################################

from sequana.lazy import pylab
from sequana.utils.pandas import PandasReader

import colorlog


# Module-level logger, named after this module.
logger = colorlog.getLogger(__name__)


# Names exported by ``from <module> import *``: only the IDR reader.
__all__ = ["IDR"]


class IDR:
    """Reader for the output generated by the idr package

    Can read the narrow or broad output transparently: the layout is
    recognised from the number of columns found in the file.

    Note that signalValue = rep1_signal + rep2_signal

    The score column contains the scaled IDR value,
    min(int(-125 * log2(IDR)), 1000).
    This means that peaks with an IDR of 0 have a score of 1000.
    A peak with an IDR of 0.05 has a score of int(-125 * log2(0.05)) = 540.
    Finally, peaks with an IDR of 1.0 have a score of 0.

        IDR, score
        0.0039, 1000
        0.01, 830
        0.05, 540
        0.1, 415
        0.5, 125

    This allows to differentiate the peaks that cross an IDR of 0.05.
    The final IDR is stored in global_idr

    :param filename: tab-separated file (without header) produced by idr
    :param threshold: IDR value below which a peak is counted as
        significant (applied to the *idr* column, see below).

    When the table is not empty, three columns are added to :attr:`df`:
    *rep1_rank* and *rep2_rank* (descending rank of the replicate signals)
    and *idr*, computed as 10 ** -local_idr.
    """

    # Column names assigned to a 20-column (narrow) file.
    _NARROW_COLUMNS = [
        "chrom",
        "start",
        "end",
        "region_name",
        "score",
        "strand",
        "signalValue",
        "pvalue",
        "qvalue",
        "summit",
        "local_idr",
        "global_idr",
        "rep1_chrom_start",
        "rep1_chrom_end",
        "rep1_signal",
        "rep1_summit",
        "rep2_chrom_start",
        "rep2_chrom_end",
        "rep2_signal",
        "rep2_summit",
    ]

    # Column names assigned to a 17-column (broad) file: same as the narrow
    # layout without the three summit columns.
    _BROAD_COLUMNS = [
        "chrom",
        "start",
        "end",
        "region_name",
        "score",
        "strand",
        "signalValue",
        "pvalue",
        "qvalue",
        "local_idr",
        "global_idr",
        "rep1_chrom_start",
        "rep1_chrom_end",
        "rep1_signal",
        "rep2_chrom_start",
        "rep2_chrom_end",
        "rep2_signal",
    ]

    # Scaled score of an IDR of 0.05: int(-125 * log2(0.05)) = 540.
    _SCORE_CUTOFF = 540

    def __init__(self, filename, threshold=0.05):

        self.df = PandasReader(filename, sep="\t", header=None).df

        self.threshold = threshold

        # Stays None when the number of columns matches neither layout
        # (e.g. an empty table) so that the *mode* property is always
        # readable instead of raising AttributeError.
        self._mode = None
        ncols = len(self.df.columns)
        if ncols == len(self._NARROW_COLUMNS):
            self.df.columns = list(self._NARROW_COLUMNS)
            self._mode = "narrow"
        elif ncols == len(self._BROAD_COLUMNS):
            self.df.columns = list(self._BROAD_COLUMNS)
            self._mode = "broad"

        if len(self.df):
            # add ranks for rep1/rep2 (rank 1 = strongest signal)
            self.df["rep1_rank"] = self.df["rep1_signal"].rank(ascending=False)
            self.df["rep2_rank"] = self.df["rep2_signal"].rank(ascending=False)
            # local_idr is treated as a -log10 value; convert it back
            self.df["idr"] = 10 ** -self.df["local_idr"]

    def __len__(self):
        """Return the number of peaks (rows) in the table."""
        return len(self.df)

    def _get_mode(self):
        return self._mode

    #: 'narrow', 'broad', or None when the layout was not recognised
    mode = property(_get_mode)

    def _get_N_significant_peaks(self):
        if len(self.df) == 0:  # pragma: no cover
            return None
        # boolean mask instead of DataFrame.query: same count, no string
        # expression to parse
        return int((self.df["idr"] < self.threshold).sum())

    #: number of peaks whose *idr* is strictly below the threshold
    #: (None when the table is empty)
    N_significant_peaks = property(_get_N_significant_peaks)

    def IDR2score(self, IDR):
        """Convert an IDR value into its scaled score.

        :param IDR: a scalar IDR value
        :return: min(int(-125 * log2(IDR)), 1000); 1000 if IDR is 0
        """
        if IDR == 0:
            # log2(0) is not finite; an IDR of 0 maps to the maximum score
            return 1000
        return min(int(-125 * pylab.log2(IDR)), 1000)

    def score2IDR(self, score):
        """Convert a scaled score back into an IDR: 2 ** (score / -125)."""
        return 2 ** (score / -125)

    def _split_by_score(self, cutoff=None):
        """Split the peaks on the scaled score column.

        :param cutoff: score cutoff; defaults to 540 (IDR of 0.05)
        :return: tuple (above, below_or_equal) of dataframes holding the
            rows with score > cutoff and score <= cutoff. Rows whose score
            equals the cutoff are therefore kept in the second dataframe.
        """
        if cutoff is None:
            cutoff = self._SCORE_CUTOFF
        scores = self.df["score"]
        return self.df[scores > cutoff], self.df[scores <= cutoff]

    def plot_ranks(self, filename=None, savefig=False):
        """Plot the rank of each peak in replicate 1 versus replicate 2.

        Peaks with a score above 540 (IDR < 0.05) are shown in black, the
        others in red.

        :param filename: output file, used only if *savefig* is True
        :param savefig: if True, save the figure into *filename*
        """
        # the *score* column contains the scaled IDR value,
        # min(int(-125 * log2(IDR)), 1000): e.g. peaks with an IDR of 0 have
        # a score of 1000, an IDR of 0.05 gives int(-125 * log2(0.05)) = 540,
        # and an IDR of 1.0 gives a score of 0.
        df1, df2 = self._split_by_score()
        pylab.clf()
        pylab.plot(df1.rep1_rank, df1.rep2_rank, "ko", alpha=0.5, label="<0.05 IDR")
        pylab.plot(df2.rep1_rank, df2.rep2_rank, "ro", alpha=0.5, label=">=0.05 IDR")
        pylab.xlabel("Peak rank - replicate 1")
        pylab.ylabel("Peak rank - replicate 2")
        N = len(self.df)
        # diagonal: identical ranks in both replicates
        pylab.plot([0, N], [0, N], color="blue", alpha=0.5, ls="--")
        pylab.legend(loc="lower right")
        if savefig:
            pylab.savefig(filename)

    def plot_scores(self, filename=None, savefig=False):
        """Plot the log10 signal of replicate 1 versus replicate 2.

        Uses the same black/red split as :meth:`plot_ranks`, so that peaks
        with a score of exactly 540 are shown (in red) rather than dropped.

        :param filename: output file, used only if *savefig* is True
        :param savefig: if True, save the figure into *filename*
        """
        df1, df2 = self._split_by_score()

        pylab.clf()
        pylab.plot(
            pylab.log10(df1["rep1_signal"]),
            pylab.log10(df1["rep2_signal"]),
            "ko",
            alpha=0.5,
            label="<0.05 IDR",
        )
        pylab.plot(
            pylab.log10(df2["rep1_signal"]),
            pylab.log10(df2["rep2_signal"]),
            "ro",
            alpha=0.5,
            label=">=0.05 IDR",
        )
        # diagonal drawn up to the current upper limit of the y-axis
        N = pylab.ylim()[1]
        pylab.plot([0, N], [0, N], color="blue", alpha=0.5, ls="--")
        pylab.xlabel("Rep1 log10 score")
        pylab.ylabel("Rep2 log10 score")
        pylab.legend(loc="lower right")
        if savefig:
            pylab.savefig(filename)

    def plot_rank_vs_idr_score(self, filename=None, savefig=False):
        """Plot the local_idr column against the peaks ordered by rank.

        One panel per replicate; peaks are sorted by decreasing rank value
        and a vertical dashed line is drawn at
        len(df) - N_significant_peaks.

        :param filename: output file, used only if *savefig* is True
        :param savefig: if True, save the figure into *filename*
        """
        _, axes = pylab.subplots(2, 1)
        df = self.df
        positions = range(len(df))
        boundary = len(df) - self.N_significant_peaks

        axes[0].plot(positions, df.sort_values(by="rep1_rank", ascending=False)["local_idr"], "o")
        axes[0].set_ylabel("log10 IDR for replicate 1")
        axes[0].axvline(boundary, color="b", ls="--")

        axes[1].plot(positions, df.sort_values(by="rep2_rank", ascending=False)["local_idr"], "ro")
        axes[1].set_ylabel("log10 IDR for replicate 2")
        axes[1].axvline(boundary, color="b", ls="--")
        if savefig:
            pylab.savefig(filename)

    def plot_idr_vs_peaks(self, filename=None, savefig=False):
        """Plot the IDR as a function of the number of selected peaks.

        The part of the curve below the threshold is drawn in red, the
        part above in black. Dashed lines mark the threshold and the
        number of significant peaks.

        :param filename: output file, used only if *savefig* is True
        :param savefig: if True, save the figure into *filename*
        """
        pylab.clf()
        X1 = pylab.linspace(0, self.threshold, 100)
        X2 = pylab.linspace(self.threshold, 1, 100)

        idr = self.df["idr"]
        below = idr[idr < self.threshold]
        above = idr[idr >= self.threshold]

        # vectorised counts (Series.sum) rather than a Python-level sum()
        pylab.plot([(below < x).sum() for x in X1], X1, "-", color="r", lw=2)
        # the second curve starts where the first one ends
        shift = len(below)

        pylab.plot([shift + (above < x).sum() for x in X2], X2, "-", color="k", lw=2)
        pylab.xlabel("Number of significant peaks")
        pylab.ylabel("IDR")
        # follow the user-defined threshold rather than a fixed 0.05
        pylab.axhline(self.threshold, color="b", ls="--")
        pylab.axvline(self.N_significant_peaks, color="b", ls="--")
        if savefig:
            pylab.savefig(filename)
