# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_baseClasses.ipynb.

# %% ../nbs/00_baseClasses.ipynb 5
from __future__ import annotations
from fastcore.docments import *
from fastcore.test import *
from fastcore.utils import *

from sklearn.base import BaseEstimator

from abc import ABC, abstractmethod
import pandas as pd
import numpy as np

# %% auto 0
# Public names exported by `from <this module> import *`.
__all__ = ['BaseWeightsBasedEstimator']

# %% ../nbs/00_baseClasses.ipynb 7
class BaseWeightsBasedEstimator(BaseEstimator):
    """
    Base class that implements the 'prediction'-method for approaches based
    on a reweighting of the empirical distribution.

    Subclasses must define ``getWeights(X, outputType, scalingList)``. When
    called with ``outputType='cumulativeDistribution'`` it must return an
    iterable with one ``(probsDistributionFunction, yDistributionFunction)``
    pair per sample of ``X``. ``predict`` picks the first position whose
    cumulative probability reaches the requested level, so the probabilities
    are expected to be non-decreasing. The ``y`` values are indexed at the
    same position.
    """

    # NOTE: `getWeights` is not declared as an abstractmethod here (the class
    # does not derive from ABC); subclasses are expected to provide it.

    def predict(self,
                X,
                probs: float | list | np.ndarray = (0.1, 0.5, 0.9),
                outputAsDf: bool = True,
                scalingList: list | np.ndarray | None = None,
                ):
        """
        Compute estimated conditional p-quantiles for every sample in `X`.

        Parameters
        ----------
        X : Feature matrix of samples for which conditional quantiles are computed.
        probs : A single probability or a sequence/array of probabilities in
            [0, 1] for which the conditional p-quantiles are computed. A scalar
            is treated as a one-element list. Repeated values are computed
            once; the first occurrence determines the order.
        outputAsDf : If True, return a DataFrame with the probabilities as
            columns. Otherwise return a dict mapping each probability to an
            array of quantiles.
        scalingList : Forwarded unchanged to `self.getWeights`.

        Returns
        -------
        pd.DataFrame or dict
            One quantile per sample for each requested probability.

        Raises
        ------
        ValueError
            If any probability lies outside [0, 1] or is NaN.
        """
        # Checks
        # Any 0-d value (Python or NumPy scalar, int or float) becomes a
        # one-element list; `np.ndim` avoids the ambiguous-truth-value error
        # that `probs == 0` raises for arrays.
        if np.ndim(probs) == 0:
            probs = [probs]

        # Deduplicate while keeping order: a repeated probability would
        # otherwise append several quantiles per sample to the same dict key,
        # producing columns of unequal length.
        probs = list(dict.fromkeys(probs))

        # Written as `not (0 <= p <= 1)` so that NaN is rejected as well.
        if any(not (0 <= prob <= 1) for prob in probs):
            raise ValueError("The values specified via 'probs' must lie between 0 and 1!")

        #---

        distributionDataList = self.getWeights(X = X,
                                               outputType = 'cumulativeDistribution',
                                               scalingList = scalingList)

        # Tolerance subtracted from each prob to absorb floating-point rounding
        # in the accumulated probabilities.
        tolerance = 1e-8

        quantilesDict = {prob: [] for prob in probs}

        for probsDistributionFunction, yDistributionFunction in distributionDataList:
            cumulativeProbs = np.asarray(probsDistributionFunction)

            for prob in probs:
                quantileIndex = np.where(cumulativeProbs >= prob - tolerance)[0][0]
                quantilesDict[prob].append(yDistributionFunction[quantileIndex])

        #---

        if outputAsDf:
            return pd.DataFrame(quantilesDict)

        # The dict variant holds arrays rather than lists of the quantiles.
        return {prob: np.array(quantiles) for prob, quantiles in quantilesDict.items()}
