from typing import Union, Optional, List, Tuple, Dict

import pandas as pd
import numpy as np
import scipy.stats
from matplotlib import pyplot as plt

from .macro import Inflation
from .helpers import Float, Frame, Rebalance, Date, Index
from .settings import default_ticker, PeriodLength, _MONTHS_PER_YEAR
from .data import QueryData, get_assets_namespaces


class Asset:
    """
    An asset, that could be used in a list of assets or in portfolio.
    Works with monthly end of day historical rate of return data.

    The symbol description, rate of return, live price, dividends and NAV are requested from QueryData.
    """

    def __init__(self, symbol: str = default_ticker):
        """
        Load the symbol description and the rate of return time series for the symbol.

        Raises an Exception if the namespace of the symbol is not in the allowed namespaces.
        """
        self.symbol: str = symbol
        self._check_namespace()
        self._get_symbol_data(symbol)
        self.ror: pd.Series = QueryData.get_ror(symbol)
        # The first and the last elements of the ror index define the available date range.
        self.first_date: pd.Timestamp = self.ror.index[0].to_timestamp()
        self.last_date: pd.Timestamp = self.ror.index[-1].to_timestamp()
        # Length of the period in years (365 days each) rounded to one decimal.
        self.period_length: float = round((self.last_date - self.first_date) / np.timedelta64(365, 'D'), ndigits=1)

    def __repr__(self):
        """
        Return a text table (repr of pd.Series) with the main parameters of the asset.
        """
        dic = {
            'symbol': self.symbol,
            'name': self.name,
            'country': self.country,
            'exchange': self.exchange,
            'currency': self.currency,
            'type': self.type,
            'first date': self.first_date.strftime("%Y-%m"),
            'last date': self.last_date.strftime("%Y-%m"),
            'period length': "{:.2f}".format(self.period_length)
        }
        return repr(pd.Series(dic))

    def _check_namespace(self):
        """
        Verify that the namespace of the symbol is allowed.

        The namespace is the part of the symbol after the first dot
        (the whole symbol if there is no dot).
        """
        namespace = self.symbol.split('.', 1)[-1]
        allowed_namespaces = get_assets_namespaces()
        if namespace not in allowed_namespaces:
            raise Exception(f'{namespace} is not in allowed assets namespaces: {allowed_namespaces}')

    def _get_symbol_data(self, symbol) -> None:
        """
        Request the symbol description and store its fields as attributes.
        """
        x = QueryData.get_symbol_info(symbol)
        self.ticker: str = x['code']
        self.name: str = x['name']
        self.country: str = x['country']
        self.exchange: str = x['exchange']
        self.currency: str = x['currency']
        self.type: str = x['type']
        # Symbol of the inflation for the currency of the asset.
        self.inflation: str = f'{self.currency}.INFL'

    @property
    def price(self) -> float:
        """
        Live price of an asset.
        """
        return QueryData.get_live_price(self.symbol)

    @property
    def dividends(self) -> pd.Series:
        """
        Dividends time series daily data.
        Not defined for namespaces: 'PIF', 'INFL', 'INDX', 'FX', 'COMM'

        When no dividends are returned, a zero time series is built instead
        (one value for the first day of each month between first_date and last_date).
        """
        div = QueryData.get_dividends(self.symbol)
        if div.empty:
            # Zero time series for assets where dividend yield is not defined.
            # The `closed` argument is not passed: None was its default and the keyword
            # is not accepted by pandas 2.0 and later.
            index = pd.date_range(start=self.first_date, end=self.last_date, freq='MS')
            period = index.to_period('D')
            div = pd.Series(data=0, index=period)
            div.rename(self.symbol, inplace=True)
        return div

    @property
    def nav_ts(self) -> Union[pd.Series, float]:
        """
        NAV time series (monthly) for mutual funds when available in data.

        NAV is requested only if the exchange is 'PIF'; otherwise np.nan is returned.
        """
        if self.exchange == 'PIF':
            return QueryData.get_nav(self.symbol)
        return np.nan


class AssetList:
    """
    The list of assets implementation.
    Works with monthly end of day historical rate of return data.
    """
    def __init__(self,
                 symbols: Optional[List[str]] = None, *,
                 first_date: Optional[str] = None,
                 last_date: Optional[str] = None,
                 ccy: str = 'USD',
                 inflation: bool = True):
        """
        Build the list of assets with the common table of monthly rate of return.

        symbols - list of symbols (the default ticker is used when it is empty or None).
        first_date, last_date - optional limits of the period, parsed with pd.to_datetime.
        ccy - currency of the asset list. The '<ccy>.FX' asset is loaded for it.
        inflation - if True the '<ccy>.INFL' inflation is loaded and the period is limited
        by the dates where inflation data is available.
        """
        self.__symbols = symbols
        self.__tickers: List[str] = [symbol.split(".", 1)[0] for symbol in self.symbols]
        self.__currency: Asset = Asset(symbol=f'{ccy}.FX')
        # Sets ror, first_date, last_date and the dictionaries with names, currencies and dates.
        self.__make_asset_list(self.symbols)
        if inflation:
            self.inflation: str = f'{ccy}.INFL'
            infl = Inflation(self.inflation, self.first_date, self.last_date)
            self._inflation_instance: Inflation = infl
            self.inflation_ts: pd.Series = infl.values_ts
            self.inflation_first_date: pd.Timestamp = infl.first_date
            self.inflation_last_date: pd.Timestamp = infl.last_date
            # The common period can not be wider than the inflation period.
            self.first_date: pd.Timestamp = max(self.first_date, self.inflation_first_date)
            self.last_date: pd.Timestamp = min(self.last_date, self.inflation_last_date)
            # Add inflation to the date range dict
            self.assets_first_dates.update({self.inflation: self.inflation_first_date})
            self.assets_last_dates.update({self.inflation: self.inflation_last_date})
        # The limits set by the user can only narrow the period.
        if first_date:
            self.first_date = max(self.first_date, pd.to_datetime(first_date))
        if last_date:
            self.last_date = min(self.last_date, pd.to_datetime(last_date))
        self.ror: pd.DataFrame = self.ror[self.first_date: self.last_date]
        span = self.last_date - self.first_date
        self.period_length: float = round(span / np.timedelta64(365, 'D'), ndigits=1)
        # Full years and remaining months are counted from the number of rows in ror.
        full_years, rest_months = divmod(self.ror.shape[0], _MONTHS_PER_YEAR)
        self.pl = PeriodLength(full_years, rest_months)
        self._pl_txt = f'{self.pl.years} years, {self.pl.months} months'
        # Empty frames are placeholders for the values calculated on the first request.
        self._dividend_yield: pd.DataFrame = pd.DataFrame(dtype=float)
        self._dividends_ts: pd.DataFrame = pd.DataFrame(dtype=float)

    def __repr__(self):
        dic = {
            'symbols': self.symbols,
            'currency': self.currency.ticker,
            'first date': self.first_date.strftime("%Y-%m"),
            'last_date': self.last_date.strftime("%Y-%m"),
            'period length': self._pl_txt,
            'inflation': self.inflation if hasattr(self, 'inflation') else 'None',
        }
        return repr(pd.Series(dic))

    def __len__(self):
        return len(self.symbols)

    def __make_asset_list(self, ls: list) -> None:
        """
        Make an asset list from a list of symbols.

        Nothing is returned. The results are stored as attributes:
        - ror: DataFrame of returns (monthly), limited to the dates available for all assets
        - first_date, last_date: the latest first date and the earliest last date
          among the assets and the currency of the list
        - newest_asset, eldest_asset: names of the entries with the latest and the earliest first date
        - names, currencies, assets_first_dates, assets_last_dates: dictionaries by symbol
        """
        first_dates: Dict[str, pd.Timestamp] = {}
        last_dates: Dict[str, pd.Timestamp] = {}
        names: Dict[str, str] = {}
        currencies: Dict[str, str] = {}
        df = None
        for symbol in ls:
            asset = Asset(symbol)
            if asset.currency == self.currency.name:
                ror = asset.ror
            else:
                # Returns are converted to the currency of the asset list.
                ror = self._set_currency(returns=asset.ror, asset_currency=asset.currency)
            # Inner join keeps only the dates available for all the assets.
            # The `copy` argument is not passed: the former value 'false' was a truthy string
            # and was equal to the default behavior.
            df = ror if df is None else pd.concat([df, ror], axis=1, join='inner')
            currencies[asset.symbol] = asset.currency
            names[asset.symbol] = asset.name
            first_dates[asset.symbol] = asset.first_date
            last_dates[asset.symbol] = asset.last_date
        # Add currency to the date range dict
        first_dates[self.currency.name] = self.currency.first_date
        last_dates[self.currency.name] = self.currency.last_date

        first_dates_sorted = sorted(first_dates.items(), key=lambda x: x[1])
        last_dates_sorted = sorted(last_dates.items(), key=lambda x: x[1])
        # The common period starts at the latest first date and ends at the earliest last date.
        self.first_date: pd.Timestamp = first_dates_sorted[-1][1]
        self.last_date: pd.Timestamp = last_dates_sorted[0][1]
        self.newest_asset: str = first_dates_sorted[-1][0]
        self.eldest_asset: str = first_dates_sorted[0][0]
        self.names: Dict[str, str] = names
        currencies['asset list'] = self.currency.currency
        self.currencies: Dict[str, str] = currencies
        self.assets_first_dates: Dict[str, pd.Timestamp] = dict(first_dates_sorted)
        self.assets_last_dates: Dict[str, pd.Timestamp] = dict(last_dates_sorted)
        if isinstance(df, pd.Series):  # required to convert Series to DataFrame for single asset list
            df = df.to_frame()
        self.ror: pd.DataFrame = df

    def _set_currency(self, returns: pd.Series, asset_currency: str) -> pd.Series:
        """
        Convert a rate of return time series to the currency of the asset list.

        returns - rate of return time series of the asset in its own currency.
        asset_currency - currency symbol of the asset.

        The rate of return of the '<asset_currency><list currency>.FX' asset is compounded with
        the asset returns. Only the dates available in both time series are kept.
        The result has the same name as the input time series.
        """
        currency = Asset(symbol=f'{asset_currency}{self.currency.name}.FX')
        # Join the multipliers to have the same Time Series Index.
        # The `copy` argument is not passed: the former value 'false' was a truthy string
        # and was equal to the default behavior.
        df = pd.concat([returns + 1., currency.ror + 1.], axis=1, join='inner')
        asset_mult = df.iloc[:, 0]
        currency_mult = df.iloc[:, -1]
        x = asset_mult * currency_mult - 1.
        x.rename(returns.name, inplace=True)
        return x

    @property
    def symbols(self):
        symbols = [default_ticker] if not self.__symbols else self.__symbols
        if not isinstance(symbols, list):
            raise ValueError('Symbols should be a list of string values.')
        return symbols

    @property
    def tickers(self):
        return self.__tickers

    @property
    def currency(self):
        return self.__currency

    @property
    def wealth_indexes(self) -> pd.DataFrame:
        """
        Wealth index time series for the assets and accumulated inflation.
        Wealth index is obtained from the accumulated return multiplied by the initial investments (1000).

        Inflation is included only if the asset list was created with inflation.
        """
        if hasattr(self, 'inflation'):
            # The `copy` argument is not passed: the former value 'false' was a truthy string
            # and was equal to the default behavior.
            df = pd.concat([self.ror, self.inflation_ts], axis=1, join='inner')
        else:
            df = self.ror
        return Frame.get_wealth_indexes(df)

    @property
    def risk_monthly(self) -> pd.Series:
        """
        Takes assets returns DataFrame and calculates monthly risks (std) for each asset.
        """
        return self.ror.std()

    @property
    def risk_annual(self) -> pd.Series:
        """
        Annualized risk of each asset.

        The monthly standard deviation and the monthly mean return are passed to Float.annualize_risk.
        """
        monthly_returns = self.ror
        return Float.annualize_risk(monthly_returns.std(), monthly_returns.mean())

    @property
    def semideviation_monthly(self) -> pd.Series:
        """
        Monthly semideviation of each asset for the full period.
        """
        monthly_returns = self.ror
        return Frame.get_semideviation(monthly_returns)

    @property
    def semideviation_annual(self) -> pd.Series:
        """
        Returns semideviation annual values for each asset (full period).

        The monthly semideviation is multiplied by the square root of 12.
        """
        # AssetList keeps the returns in `ror`; it has no `returns_ts` attribute.
        return Frame.get_semideviation(self.ror) * 12 ** 0.5

    def get_var_historic(self, level: int = 5) -> pd.Series:
        """
        Historic VAR of each asset for the full period.

        level - VAR level (integer) passed to Frame.get_var_historic.
        """
        monthly_returns = self.ror
        return Frame.get_var_historic(monthly_returns, level)

    def get_cvar_historic(self, level: int = 5) -> pd.Series:
        """
        Historic CVAR of each asset for the full period.

        level - CVAR level (integer) passed to Frame.get_cvar_historic.
        """
        monthly_returns = self.ror
        return Frame.get_cvar_historic(monthly_returns, level)

    @property
    def drawdowns(self) -> pd.DataFrame:
        """
        Drawdowns time series for each asset.
        """
        monthly_returns = self.ror
        return Frame.get_drawdowns(monthly_returns)

    def get_cagr(self, period: Union[str, int, None] = None) -> pd.Series:
        """
        Calculates Compound Annual Growth Rate (CAGR) for a given period:
        None: full time
        'YTD': Year To Date compound rate of return (formally not a CAGR)
        Integer: several years
        """
        if hasattr(self, 'inflation'):
            df: pd.DataFrame = pd.concat([self.ror, self.inflation_ts], axis=1, join='inner', copy='false')
        else:
            df = self.ror
        dt0 = self.last_date

        if not period:
            cagr = Frame.get_cagr(df)
        elif period == 'YTD':
            year = dt0.year
            cagr = (df[str(year):] + 1.).prod() - 1.
        elif isinstance(period, int):
            dt = Date.subtract_years(dt0, period)
            if dt >= self.first_date:
                cagr = Frame.get_cagr(df[dt:])
            else:
                row = {x: None for x in df.columns}
                cagr = pd.Series(row)
        else:
            raise ValueError(f'{period} is not a valid value for period')
        return cagr

    @property
    def annual_return_ts(self) -> pd.DataFrame:
        """
        Annual rate of return time series for each asset, calculated from the monthly data.
        """
        monthly_returns = self.ror
        return Frame.get_annual_return_ts_from_monthly(monthly_returns)

    def describe(self, years: tuple = (1, 5, 10), tickers: bool = True) -> pd.DataFrame:
        """
        Generate descriptive statistics for a given list of tickers.
        Statistics includes:
        - YTD compound return
        - CAGR for a given list of periods
        - Dividend yield - yield for last 12 months (LTM)
        - risk (std) for a full period
        - CVAR for a full period
        - max drawdowns (and dates) for a full period
        - inception date - first date available for each asset
        - last asset date - available for each asset date
        - last data date - common for all assets data (may be set by last_date manually)

        years - periods (in years) to calculate CAGR for.
        tickers - if False the columns are named by the asset names instead of the symbols.

        Each statistic is one row of the table. The 'property' and 'period' columns go first.
        The inflation column (if available) goes last.
        """
        # Rows are collected in a list and the table is built once:
        # DataFrame.append is not available in pandas 2.0 and later.
        rows = []

        def add_row(values: dict, period, prop: str) -> None:
            # Keep the values by column together with the 'period' and 'property' labels.
            row = dict(values)
            row.update({'period': period, 'property': prop})
            rows.append(row)

        dt0 = self.last_date
        has_inflation = hasattr(self, 'inflation')
        if has_inflation:
            # The `copy` argument is not passed: the former value 'false' was a truthy string
            # and was equal to the default behavior.
            df = pd.concat([self.ror, self.inflation_ts], axis=1, join='inner')
        else:
            df = self.ror
        # YTD return
        add_row(self.get_cagr(period='YTD').to_dict(), 'YTD', 'Compound return')
        # CAGR for a list of periods
        for i in years:
            dt = Date.subtract_years(dt0, i)
            if dt >= self.first_date:
                values = self.get_cagr(period=i).to_dict()
            else:
                # Not enough history for the period.
                values = {x: None for x in df.columns}
            add_row(values, f'{i} years', 'CAGR')
        # CAGR for full period
        add_row(self.get_cagr(period=None).to_dict(), self._pl_txt, 'CAGR')
        # Dividend Yield (the last available LTM value)
        add_row(self.dividend_yield.iloc[-1].to_dict(), 'LTM', 'Dividend yield')
        # risk for full period
        add_row(self.risk_annual.to_dict(), self._pl_txt, 'Risk')
        # CVAR
        add_row(self.get_cvar_historic().to_dict(), self._pl_txt, 'CVAR')
        # max drawdowns and their dates (drawdowns are calculated once for both rows)
        drawdowns = self.drawdowns
        add_row(drawdowns.min().to_dict(), self._pl_txt, 'Max drawdowns')
        add_row(drawdowns.idxmin().to_dict(), self._pl_txt, 'Max drawdowns dates')
        # inception dates
        values = {ti: self.assets_first_dates[ti].strftime("%Y-%m") for ti in self.symbols}
        if has_inflation:
            values[self.inflation] = self.inflation_first_date.strftime("%Y-%m")
        add_row(values, None, 'Inception date')
        # last asset date
        values = {ti: self.assets_last_dates[ti].strftime("%Y-%m") for ti in self.symbols}
        if has_inflation:
            values[self.inflation] = self.inflation_last_date.strftime("%Y-%m")
        add_row(values, None, 'Last asset date')
        # last data date
        values = {x: self.last_date.strftime("%Y-%m") for x in df.columns}
        add_row(values, None, 'Common last data date')

        description = pd.DataFrame(rows)
        # rename columns
        if has_inflation:
            description.rename(columns={self.inflation: 'inflation'}, inplace=True)
            description = Frame.change_columns_order(description, ['inflation'], position='last')
        description = Frame.change_columns_order(description, ['property', 'period'], position='first')
        if not tickers:
            description.rename(columns={ti: self.names[ti] for ti in self.symbols}, inplace=True)
        return description

    @property
    def mean_return(self) -> pd.Series:
        """
        Calculates mean return (arithmetic mean) for the assets.

        The monthly mean is annualized with Float.annualize_return.
        Inflation is included if the asset list was created with inflation.
        """
        if hasattr(self, 'inflation'):
            # The `copy` argument is not passed: the former value 'false' was a truthy string
            # and was equal to the default behavior.
            df = pd.concat([self.ror, self.inflation_ts], axis=1, join='inner')
        else:
            df = self.ror
        mean: pd.Series = df.mean()
        return Float.annualize_return(mean)

    @property
    def real_mean_return(self) -> pd.Series:
        """
        Calculates real mean return (arithmetic mean) for the assets.
        """
        if hasattr(self, 'inflation'):
            df = pd.concat([self.ror, self.inflation_ts], axis=1, join='inner', copy='false')
        else:
            raise Exception('Real Return is not defined. Set inflation=True to calculate.')
        infl_mean = Float.annualize_return(self.inflation_ts.values.mean())
        ror_mean = Float.annualize_return(df.loc[:, self.symbols].mean())
        return (1. + ror_mean) / (1. + infl_mean) - 1.

    def _get_asset_dividends(self, tick, remove_forecast=True) -> pd.Series:
        """
        Daily dividends time series of one asset for the period of the asset list.

        The dividends are limited by the start of the first_date month and the end of the last_date month.
        If remove_forecast is True the values after the current day are dropped.
        The days without dividends are filled with zeros.
        """
        period_start = pd.Period(self.first_date, freq='M').to_timestamp(how='Start')
        period_end = pd.Period(self.last_date, freq='M').to_timestamp(how='End')
        dividends = Asset(tick).dividends[period_start: period_end]
        if remove_forecast:
            today = pd.Period.now(freq='D')
            dividends = dividends[:today]
        # Time series of zeros for each day of the period to pad the dividends time series.
        all_days = pd.date_range(start=period_start, end=period_end, freq='D').to_period('D')
        zeros = pd.Series(data=0, index=all_days)
        return dividends.add(zeros, fill_value=0)

    def _get_dividends(self, remove_forecast=True) -> pd.DataFrame:
        if self._dividends_ts.empty:
            dic = {}
            for tick in self.symbols:
                s = self._get_asset_dividends(tick, remove_forecast=remove_forecast)
                dic.update({tick: s})
            self._dividends_ts = pd.DataFrame(dic)
        return self._dividends_ts

    @property
    def dividend_yield(self) -> pd.DataFrame:
        """
        Dividend yield (LTM) time series monthly.
        Calculates yield assuming original asset currency (not adjusting to AssetList currency).
        Forecast dividends are removed.

        The yield for a month is the sum of dividends for the last 12 months divided by the last
        available close price. The table is calculated on the first request and stored
        in _dividend_yield.
        """
        if self._dividend_yield.empty:
            # Yield time series by symbol.
            frame = {}
            df = self._get_dividends(remove_forecast=True)
            for tick in self.symbols:
                # Get dividends time series
                div = df[tick]
                # Get close (not adjusted) values time series.
                # If the last_date month is current month live price of assets is used.
                if div.sum() != 0:
                    div_monthly = div.resample('M').sum()
                    price = QueryData.get_close(tick, period='M').loc[self.first_date: self.last_date]
                else:
                    # skipping prices if no dividends
                    # (the zero dividends time series converted to monthly frequency is used as the yield)
                    div_yield = div.asfreq(freq='M')
                    frame.update({tick: div_yield})
                    continue
                # Replace the price of the current month with the live price.
                if price.index[-1] == pd.Period(pd.Timestamp.today(), freq='M'):
                    price.loc[f'{pd.Timestamp.today().year}-{pd.Timestamp.today().month}'] = Asset(tick).price
                # Get dividend yield time series
                div_yield = pd.Series(dtype=float)
                # Timestamps are required to select the last 12 months with .last('12M').
                div_monthly.index = div_monthly.index.to_timestamp()
                for date in price.index.to_timestamp(how='End'):
                    # Dividends for the last 12 months up to the end of the month.
                    ltm_div = div_monthly[:date].last('12M').sum()
                    last_price = price.loc[:date].iloc[-1]
                    value = ltm_div / last_price
                    div_yield.at[date] = value
                # Back to the monthly period index.
                div_yield.index = div_yield.index.to_period('M')
                # Currency adjusted yield
                # if self.currencies[tick] != self.currency.name:
                #     div_yield = self._set_currency(returns=div_yield, asset_currency=self.currencies[tick])
                frame.update({tick: div_yield})
            self._dividend_yield = pd.DataFrame(frame)
        return self._dividend_yield

    @property
    def dividends_annual(self) -> pd.DataFrame:
        """
        Time series of dividends for a calendar year.
        """
        return self._get_dividends().resample('Y').sum()

    @property
    def dividend_growing_years(self) -> pd.DataFrame:
        """
        Returns the number of growing dividend years for each asset.
        """
        div_growth = self.dividends_annual.pct_change()[1:]
        df = pd.DataFrame()
        for name in div_growth:
            s = div_growth[name]
            s1 = s.where(s > 0).notnull().astype(int)
            s1_1 = s.where(s > 0).isnull().astype(int).cumsum()
            s2 = s1.groupby(s1_1).cumsum()
            df = pd.concat([df, s2], axis=1, copy='false')
        return df

    @property
    def dividend_paying_years(self) -> pd.DataFrame:
        """
        Returns the number of years of consecutive dividend payments.
        """
        div_annual = self.dividends_annual
        frame = pd.DataFrame()
        df = frame
        for name in div_annual:
            s = div_annual[name]
            s1 = s.where(s != 0).notnull().astype(int)
            s1_1 = s.where(s != 0).isnull().astype(int).cumsum()
            s2 = s1.groupby(s1_1).cumsum()
            df = pd.concat([df, s2], axis=1, copy='false')
        return df

    def get_dividend_mean_growth_rate(self, period=5) -> pd.Series:
        """
        Calculates geometric mean of dividends growth rate time series for a certain period.
        Period should be integer and not exceed the available data period_length.
        """
        if period > self.pl.years or not isinstance(period, int):
            raise TypeError(f'{period} is not a valid value for period')
        growth_ts = self.dividends_annual.pct_change().iloc[1:-1]  # Slice the last year for full dividends
        dt0 = self.last_date
        dt = Date.subtract_years(dt0, period)
        return ((growth_ts[dt:] + 1.).prod()) ** (1 / period) - 1.

    # index methods
    @property
    def tracking_difference(self):
        """
        Tracking difference for the rate of return of the assets.

        The assets are compared with the index or another benchmark.
        The index should be in the first position (first column).
        """
        wealth = Frame.get_wealth_indexes(self.ror)  # inflation is not required here
        return Index.tracking_difference(wealth)

    @property
    def tracking_difference_annualized(self):
        """
        Annualized values of the tracking difference time series.

        Annual values are available for periods of more than 12 months.
        Returns for less than 12 months can't be annualized.
        """
        difference = self.tracking_difference
        return Index.tracking_difference_annualized(difference)

    @property
    def tracking_error(self):
        """
        Tracking error for the rate of return time series of the assets.

        The assets are compared with the index or another benchmark.
        The index should be in the first position (first column).
        """
        monthly_returns = self.ror
        return Index.tracking_error(monthly_returns)

    @property
    def index_corr(self):
        """
        Expanding correlation of the assets with the index (or benchmark) time series.

        The index should be in the first position (first column).
        The period should be at least 12 months.
        """
        monthly_returns = self.ror
        return Index.cov_cor(monthly_returns, fn='corr')

    def index_rolling_corr(self, window: int = 60):
        """
        Rolling correlation of the assets with the index (or benchmark) time series.

        The index should be in the first position (first column).
        The period should be at least 12 months.
        window - the rolling window size in months (default is 5 years).
        """
        monthly_returns = self.ror
        return Index.rolling_cov_cor(monthly_returns, window=window, fn='corr')

    @property
    def index_beta(self):
        """
        Beta coefficient time series for the assets.

        The index (or benchmark) should be in the first position (first column).
        Rolling window size should be at least 12 months.
        """
        monthly_returns = self.ror
        return Index.beta(monthly_returns)

    # distributions
    @property
    def skewness(self):
        """
        Expanding skewness of the return time series for each asset.

        For normally distributed data, the skewness should be about zero.
        A skewness value greater than zero means that there is more weight in the right tail of the distribution.
        """
        monthly_returns = self.ror
        return Frame.skewness(monthly_returns)

    def skewness_rolling(self, window: int = 60):
        """
        Rolling skewness of the return time series for each asset.

        For normally distributed data, the skewness should be about zero.
        A skewness value greater than zero means that there is more weight in the right tail of the distribution.

        window - the rolling window size in months (default is 5 years).
        The window size should be at least 12 months.
        """
        monthly_returns = self.ror
        return Frame.skewness_rolling(monthly_returns, window=window)

    @property
    def kurtosis(self):
        """
        Expanding Fisher (normalized) kurtosis time series for each asset returns.

        Kurtosis is the fourth central moment divided by the square of the variance.
        Kurtosis should be close to zero for normal distribution.
        """
        monthly_returns = self.ror
        return Frame.kurtosis(monthly_returns)

    def kurtosis_rolling(self, window: int = 60):
        """
        Rolling Fisher (normalized) kurtosis time series for each asset returns.

        Kurtosis is the fourth central moment divided by the square of the variance.
        Kurtosis should be close to zero for normal distribution.

        window - the rolling window size in months (default is 5 years).
        The window size should be at least 12 months.
        """
        monthly_returns = self.ror
        return Frame.kurtosis_rolling(monthly_returns, window=window)

    @property
    def jarque_bera(self):
        """
        Jarque-Bera test for normality of the assets returns historical data.

        The test shows whether the returns have the skewness and kurtosis matching a normal distribution.

        Returns:
            (The test statistic, The p-value for the hypothesis test)
            Low statistic numbers correspond to normal distribution.
        """
        monthly_returns = self.ror
        return Frame.jarque_bera_dataframe(monthly_returns)

    def kstest(self, distr: str = 'norm') -> dict:
        """
        Kolmogorov-Smirnov test for goodness of fit of the asset returns to a given distribution.

        distr - the name of the distribution passed to Frame.kstest_dataframe ('norm' by default).

        Returns:
            (The test statistic, The p-value for the hypothesis test)
            Low statistic numbers correspond to normal distribution.
        """
        monthly_returns = self.ror
        return Frame.kstest_dataframe(monthly_returns, distr=distr)


class Portfolio:
    """
    Implementation of investment portfolio.
    Arguments are similar to AssetList (weights are added), but different behavior.
    Works with monthly end of day historical rate of return data.
    """
    def __init__(self,
                 symbols: Optional[List[str]] = None, *,
                 first_date: Optional[str] = None,
                 last_date: Optional[str] = None,
                 ccy: str = 'USD',
                 inflation: bool = True,
                 weights: Optional[List[float]] = None):
        """
        Create a portfolio on top of an AssetList built from the same arguments.

        Args:
            symbols: list of asset symbols, forwarded to AssetList.
            first_date: first date of the period, forwarded to AssetList.
            last_date: last date of the period, forwarded to AssetList.
            ccy: portfolio currency, forwarded to AssetList.
            inflation: when True the inflation name and time series are taken from the AssetList.
            weights: assets weights. None stands for an equally weighted portfolio.
        """
        asset_list = AssetList(symbols=symbols, first_date=first_date, last_date=last_date,
                               ccy=ccy, inflation=inflation)
        self._list: AssetList = asset_list
        self.currency: str = asset_list.currency.name
        self._ror: pd.DataFrame = asset_list.ror
        self.symbols: List[str] = asset_list.symbols
        self.tickers: List[str] = [symbol.split(".", 1)[0] for symbol in self.symbols]
        self.names: Dict[str, str] = asset_list.names
        # The weights setter relies on self.symbols, so it has to run after the symbols are known
        self._weights = None
        self.weights = weights
        self.assets_weights = dict(zip(self.symbols, self.weights))
        self.assets_first_dates: Dict[str, pd.Timestamp] = asset_list.assets_first_dates
        self.assets_last_dates: Dict[str, pd.Timestamp] = asset_list.assets_last_dates
        self.first_date = asset_list.first_date
        self.last_date = asset_list.last_date
        self.period_length = asset_list.period_length
        # Period length expressed as full years plus remaining months of the portfolio return series
        months_total = self.returns_ts.shape[0]
        full_years, extra_months = divmod(months_total, _MONTHS_PER_YEAR)
        self.pl = PeriodLength(full_years, extra_months)
        self._pl_txt = f'{self.pl.years} years, {self.pl.months} months'
        if inflation:
            self.inflation = asset_list.inflation
            self.inflation_ts: pd.Series = asset_list.inflation_ts

    def __repr__(self):
        dic = {
            'symbols': self.symbols,
            'weights': self.weights,
            'currency': self.currency,
            'first date': self.first_date.strftime("%Y-%m"),
            'last_date': self.last_date.strftime("%Y-%m"),
            'period length': self._pl_txt
        }
        return repr(pd.Series(dic))

    def __len__(self):
        return len(self.symbols)

    @property
    def weights(self):
        return self._weights

    @weights.setter
    def weights(self, weights: list):
        if weights is None:
            # Equally weighted portfolio
            n = len(self.symbols)  # number of assets
            weights = list(np.repeat(1/n, n))
        else:
            Frame.weights_sum_is_one(weights)
            if len(weights) != len(self.symbols):
                raise Exception(f'Number of tickers ({len(self.symbols)}) should be equal '
                                f'to the weights number ({len(weights)})')
        self._weights = weights

    @property
    def returns_ts(self) -> pd.Series:
        """
        Portfolio rate of return time series.

        The series is produced by Frame.get_portfolio_return_ts from the weights and the assets
        rate of return data and is named 'portfolio'.

        Returns:
            pd.Series
        """
        portfolio_ror = Frame.get_portfolio_return_ts(self.weights, self._ror)
        portfolio_ror.rename('portfolio', inplace=True)
        return portfolio_ror

    @property
    def wealth_index(self) -> pd.DataFrame:
        """
        Wealth index time series of the portfolio and, when available, of inflation.

        The portfolio returns (and the inflation series, joined on common dates) are passed to
        Frame.get_wealth_indexes.

        Returns:
            pd.DataFrame with a 'portfolio' column (plus the inflation column if inflation is set).
        """
        if hasattr(self, 'inflation'):
            # No `copy` argument: the string 'false' is truthy, so it never disabled copying.
            df = pd.concat([self.returns_ts, self.inflation_ts], axis=1, join='inner')
        else:
            df = self.returns_ts
        df = Frame.get_wealth_indexes(df)
        if isinstance(df, pd.Series):  # return should always be DataFrame
            df = df.to_frame()
            df.rename({1: 'portfolio'}, axis='columns', inplace=True)
        return df

    @property
    def wealth_index_with_assets(self) -> pd.DataFrame:
        """
        Wealth index time series of the portfolio, of every asset and, when available, of inflation.

        The series are joined on common dates into one table and passed to Frame.get_wealth_indexes.
        The input is always a DataFrame, so a DataFrame is expected back (wealth_index only converts
        the Series case) - NOTE(review): confirm against Frame.get_wealth_indexes.
        """
        series_list = [self.returns_ts, self._ror]
        if hasattr(self, 'inflation'):
            series_list.append(self.inflation_ts)
        # No `copy` argument: the string 'false' is truthy, so it never disabled copying.
        df = pd.concat(series_list, axis=1, join='inner')
        return Frame.get_wealth_indexes(df)

    def get_rebalanced_portfolio_return_ts(self, period='year') -> pd.Series:
        """
        Rate of return time series of the portfolio rebalanced with a given period.

        period - rebalancing period passed to Rebalance.rebalanced_portfolio_return_ts
        (this class uses 'year' and 'none').
        """
        weights, assets_ror = self.weights, self._ror
        return Rebalance.rebalanced_portfolio_return_ts(weights, assets_ror, period=period)

    @property
    def mean_return_monthly(self) -> float:
        """
        Mean monthly rate of return of the portfolio.
        Calculated by Frame.get_portfolio_mean_return from the weights and the assets returns.
        """
        weights, assets_ror = self.weights, self._ror
        return Frame.get_portfolio_mean_return(weights, assets_ror)

    @property
    def mean_return_annual(self) -> float:
        """
        Annualized mean rate of return of the portfolio.
        The monthly mean return is annualized with Float.annualize_return.
        """
        monthly_mean = self.mean_return_monthly
        return Float.annualize_return(monthly_mean)

    @property
    def cagr(self) -> Union[pd.Series, float]:
        """
        CAGR (compound annual growth rate) of the portfolio for the full period.

        When inflation is set, the portfolio returns and the inflation series are joined on
        common dates and Frame.get_cagr is applied to both columns; otherwise it is applied
        to the portfolio returns only.
        """
        if hasattr(self, 'inflation'):
            # No `copy` argument: the string 'false' is truthy, so it never disabled copying.
            df = pd.concat([self.returns_ts, self.inflation_ts], axis=1, join='inner')
        else:
            df = self.returns_ts
        return Frame.get_cagr(df)

    @property
    def annual_return_ts(self) -> pd.DataFrame:
        """
        Annual rate of return time series of the portfolio.
        Derived from the monthly portfolio returns by Frame.get_annual_return_ts_from_monthly.
        """
        monthly_returns = self.returns_ts
        return Frame.get_annual_return_ts_from_monthly(monthly_returns)

    @property
    def dividend_yield(self) -> pd.DataFrame:
        """
        Calculates dividend yield time series in all base currencies of portfolio assets.
        For every currency dividend yield is a weighted sum of the assets dividend yields.

        Weights of the assets sharing a currency are rescaled to sum to one within that currency.

        Returns:
            pd.DataFrame with one column per currency, in order of first appearance of the currency.
        """
        div_yield_assets = self._list.dividend_yield
        # Filter into a new dict: deleting the 'asset list' key in place would modify
        # the mapping handed out by the asset list.
        currencies_dict = {k: v for k, v in self._list.currencies.items() if k != 'asset list'}
        # dict.fromkeys removes duplicates and keeps a stable order (a set would not)
        currencies_list = list(dict.fromkeys(currencies_dict.values()))
        div_yield_df = pd.DataFrame(dtype=float)
        for currency in currencies_list:
            assets_with_the_same_currency = [x for x in currencies_dict if currencies_dict[x] == currency]
            df = div_yield_assets[assets_with_the_same_currency]
            # Weights follow the order of the selected columns so each weight matches its asset
            weights = np.asarray([self.assets_weights[k] for k in assets_with_the_same_currency])
            weighted_weights = weights / weights.sum()
            div_yield_series = Frame.get_portfolio_return_ts(weighted_weights, df)
            div_yield_series.rename(currency, inplace=True)
            div_yield_df = pd.concat([div_yield_df, div_yield_series], axis=1)
        return div_yield_df

    @property
    def real_mean_return(self) -> float:
        if not hasattr(self, 'inflation'):
            raise Exception('Real Return is not defined. Set inflation=True to calculate.')
        infl_mean = Float.annualize_return(self.inflation_ts.mean())
        ror_mean = Float.annualize_return(self.returns_ts.mean())
        return (1. + ror_mean) / (1. + infl_mean) - 1.

    @property
    def real_cagr(self) -> float:
        if not hasattr(self, 'inflation'):
            raise Exception('Real Return is not defined. Set inflation=True to calculate.')
        infl_cagr = Frame.get_cagr(self.inflation_ts)
        ror_cagr = Frame.get_cagr(self.returns_ts)
        return (1. + ror_cagr) / (1. + infl_cagr) - 1.

    @property
    def risk_monthly(self) -> float:
        """
        Monthly risk of the portfolio.
        Calculated by Frame.get_portfolio_risk from the weights and the assets returns.
        """
        weights, assets_ror = self.weights, self._ror
        return Frame.get_portfolio_risk(weights, assets_ror)

    @property
    def risk_annual(self) -> float:
        """
        Annualized risk of the portfolio.
        Float.annualize_risk is applied to the monthly risk and the monthly mean return.
        """
        monthly_risk = self.risk_monthly
        monthly_mean = self.mean_return_monthly
        return Float.annualize_risk(monthly_risk, monthly_mean)

    @property
    def semideviation_monthly(self) -> float:
        """
        Monthly semideviation of the portfolio returns, calculated by Frame.get_semideviation.
        """
        monthly_returns = self.returns_ts
        return Frame.get_semideviation(monthly_returns)

    @property
    def semideviation_annual(self) -> float:
        """
        Annualized semideviation of the portfolio returns.
        The monthly semideviation is scaled by the square root of 12.
        """
        monthly_semideviation = Frame.get_semideviation(self.returns_ts)
        annual_factor = 12 ** 0.5
        return monthly_semideviation * annual_factor

    def get_var_historic(self, level=5) -> float:
        """
        Historic VAR of the portfolio.

        Frame.get_cagr is applied to every rolling 12 months window of the portfolio returns
        and the result is passed to Frame.get_var_historic.

        level - forwarded to Frame.get_var_historic (default is 5).
        """
        yearly_windows = self.returns_ts.rolling(12)
        rolling_cagr = yearly_windows.apply(Frame.get_cagr)
        return Frame.get_var_historic(rolling_cagr, level)

    def get_cvar_historic(self, level=5) -> float:
        """
        Historic CVAR of the portfolio.

        Frame.get_cagr is applied to every rolling 12 months window of the portfolio returns
        and the result is passed to Frame.get_cvar_historic.

        level - forwarded to Frame.get_cvar_historic (default is 5).
        """
        yearly_windows = self.returns_ts.rolling(12)
        rolling_cagr = yearly_windows.apply(Frame.get_cagr)
        return Frame.get_cvar_historic(rolling_cagr, level)

    @property
    def drawdowns(self) -> pd.Series:
        """
        Drawdowns time series of the portfolio, calculated by Frame.get_drawdowns from the portfolio returns.
        """
        monthly_returns = self.returns_ts
        return Frame.get_drawdowns(monthly_returns)

    def describe(self, years: Tuple[int, ...] = (1, 5, 10)) -> pd.DataFrame:
        """
        Generate descriptive statistics for the portfolio.
        Statistics includes:
        - YTD compound return
        - CAGR for a given list of periods
        - CAGR for a full period (rebalanced yearly, rebalanced monthly, not rebalanced)
        - risk (std) for a full period
        - CVAR for a full period
        - max drawdowns (and dates) for a full period
        TODO: add dividend yield

        Args:
            years: period lengths in years, counted back from the last date, to calculate CAGR for.
                A period starting before the first date of the portfolio gets empty values.

        Returns:
            pd.DataFrame with one row per statistic. Columns are 'property', 'rebalancing', 'period',
            'portfolio' and 'inflation' (when inflation is set), ordered by Frame.change_columns_order.
        """
        has_inflation = hasattr(self, 'inflation')
        full_period = f'{self.period_length} years'
        # Rows are collected in a list and turned into a DataFrame once at the end:
        # DataFrame.append is deprecated since pandas 1.4 and removed in pandas 2.0.
        rows = []

        def add_row(values: dict, period: str, rebalancing: str, prop: str) -> None:
            # Store one statistic together with its labels
            rows.append({**values, 'period': period, 'rebalancing': rebalancing, 'property': prop})

        dt0 = self.last_date
        if has_inflation:
            df = pd.concat([self.returns_ts, self.inflation_ts], axis=1, join='inner')
        # YTD return
        year = dt0.year
        ts = Rebalance.rebalanced_portfolio_return_ts(self.weights, self._ror[str(year):], period='none')
        values = {'portfolio': Frame.get_compound_return(ts)}
        if has_inflation:
            values[self.inflation] = Frame.get_compound_return(df[str(year):].loc[:, self.inflation])
        add_row(values, 'YTD', 'Not rebalanced', 'compound return')
        # CAGR for a list of periods (rebalanced 1 year)
        for i in years:
            dt = Date.subtract_years(dt0, i)
            if dt >= self.first_date:
                ts = Rebalance.rebalanced_portfolio_return_ts(self.weights, self._ror[dt:], period='year')
                values = {'portfolio': Frame.get_cagr(ts)}
                if has_inflation:
                    values[self.inflation] = Frame.get_cagr(df[dt:].loc[:, self.inflation])
            else:
                # Not enough history for this period. Keys are spelled out because without
                # inflation there is only a returns Series, which has no `columns`.
                values = {'portfolio': None}
                if has_inflation:
                    values[self.inflation] = None
            add_row(values, f'{i} years', '1 year', 'CAGR')
        # CAGR for full period (rebalanced 1 year)
        ts = Rebalance.rebalanced_portfolio_return_ts(self.weights, self._ror, period='year')
        values = {'portfolio': Frame.get_cagr(ts)}
        if has_inflation:
            values[self.inflation] = Frame.get_cagr(df.loc[:, self.inflation])
        add_row(values, full_period, '1 year', 'CAGR')
        # CAGR rebalanced 1 month
        cagr = self.cagr
        if has_inflation:
            values = cagr.to_dict()
            full_inflation = cagr.loc[self.inflation]  # full period inflation is required for following calc
        else:
            values = {'portfolio': cagr}
        add_row(values, full_period, '1 month', 'CAGR')
        # CAGR not rebalanced
        values = {'portfolio': Frame.get_cagr(self.get_rebalanced_portfolio_return_ts(period='none'))}
        if has_inflation:
            values[self.inflation] = full_inflation
        add_row(values, full_period, 'Not rebalanced', 'CAGR')
        # risk (rebalanced 1 month)
        add_row({'portfolio': self.risk_annual}, full_period, '1 month', 'Risk')
        # CVAR (rebalanced 1 month)
        add_row({'portfolio': self.get_cvar_historic()}, full_period, '1 month', 'CVAR')
        # max drawdowns and their dates (rebalanced 1 month)
        drawdowns = self.drawdowns
        add_row({'portfolio': drawdowns.min()}, full_period, '1 month', 'Max drawdown')
        add_row({'portfolio': drawdowns.idxmin()}, full_period, '1 month', 'Max drawdown date')
        description = pd.DataFrame(rows)
        if has_inflation:
            description.rename(columns={self.inflation: 'inflation'}, inplace=True)
        description = Frame.change_columns_order(description, ['property', 'rebalancing', 'period', 'portfolio'])
        return description

    @property
    def table(self) -> pd.DataFrame:
        """
        Returns security name - ticker - weight DataFrame table.
        """
        x = pd.DataFrame(data={'asset name': list(self.names.values()), 'ticker': list(self.names.keys())})
        x['weights'] = self.weights
        return x

    def get_rolling_cagr(self, years: int = 1) -> pd.Series:
        """
        Rolling portfolio CAGR (annualized rate of return) time series.

        years - rolling window size in years. Returns are compounded over each window
        and annualized; incomplete windows are dropped.
        TODO: check if self.period_length is below 1 year
        """
        window_months = _MONTHS_PER_YEAR * years
        growth_factors = self.returns_ts + 1.
        compounded = growth_factors.rolling(window_months).apply(np.prod, raw=True)
        rolling_return = compounded ** (1 / years) - 1.
        return rolling_return.dropna()

    # Forecasting

    def _test_forecast_period(self, years):
        max_period_years = round(self.period_length / 2)
        if max_period_years < 1:
            raise ValueError(f'Time series does not have enough history to forecast. '
                             f'Period length is {self.period_length:.2f} years. At least 2 years are required.')
        if not isinstance(years, int) or years == 0:
            raise ValueError('years must be an integer number (not equal to zero).')
        if years > max_period_years:
            raise ValueError(f'Forecast period {years} years is not credible. '
                             f'It should not exceed 1/2 of portfolio history period length {self.period_length / 2} years')

    def percentile_inverse(self, distr: str = 'norm', years: int = 1, score: float = 0, n: Optional[int] = None) -> float:
        """
        Compute the percentile rank of a score relative to an array of CAGR values.
        A percentile_inverse of, for example, 80% means that 80% of the scores in distr are below the given score.

        Args:
            distr: norm, lognorm, hist - distribution type (normal or lognormal) or hist for CAGR array from history
            years: period length when CAGR is calculated
            score: score that is compared to the elements in CAGR array.
            n: number of random time series (for 'norm' or 'lognorm' only)

        Returns:
            Percentile-position of score (0-100) relative to distr.
        """
        if distr == 'hist':
            cagr_distr = self.get_rolling_cagr(years)
        elif distr in ['norm', 'lognorm']:
            if not n:
                n = 1000
            cagr_distr = self._get_monte_carlo_cagr_distribution(distr=distr, years=years, n=n)
        else:
            raise ValueError('distr should be one of "norm", "lognorm", "hist".')
        return scipy.stats.percentileofscore(cagr_distr, score, kind='rank')

    def percentile_from_history(self, years: int, percentiles: List[int] = [10, 50, 90]) -> pd.DataFrame:
        """
        Calculate given percentiles for portfolio CAGR (annualized rolling returns) distribution from the historical data.
        Each percentile is calculated for a period range from 1 year to 'years'.

        years - max window size for rolling CAGR (limited with half history of period length).
        percentiles - list of percentiles to be calculated
        """
        self._test_forecast_period(years)
        period_range = range(1, years + 1)
        returns_dict = {}
        for percentile in percentiles:
            percentile_returns_list = [self.get_rolling_cagr(years).quantile(percentile / 100) for years in period_range]
            returns_dict.update({percentile: percentile_returns_list})
        df = pd.DataFrame(returns_dict, index=list(period_range))
        df.index.rename('years', inplace=True)
        return df

    def forecast_wealth_history(self, years: int = 1, percentiles: List[int] = [10, 50, 90]) -> pd.DataFrame:
        """
        Compute accumulated wealth for each CAGR derived by 'percentile_from_history' method.
        CAGRs are taken from the historical data.

        Initial portfolio wealth is adjusted to the last known historical value (from wealth_index). It is useful
        for a chart with historical wealth index and forecasted values.

        Args:
            years: max forecast period in years (validated by percentile_from_history).
            percentiles: list of percentiles to be calculated.

        Returns:
            Dataframe of percentiles for period range from 1 to 'years'
        """
        last_wealth = self.wealth_index['portfolio'].values[-1]
        cagr_table = self.percentile_from_history(years=years, percentiles=percentiles)
        # Each row is compounded over its own number of years (the index of the table)
        horizons = cagr_table.index.values
        growth = (cagr_table + 1.).pow(horizons, axis=0)
        return last_wealth * growth

    def _forecast_preparation(self, years: int):
        self._test_forecast_period(years)
        period_months = years * _MONTHS_PER_YEAR
        # make periods index where the shape is max_period
        start_period = self.last_date.to_period('M')
        end_period = self.last_date.to_period('M') + period_months - 1
        ts_index = pd.period_range(start_period, end_period, freq='M')
        return period_months, ts_index

    def forecast_monte_carlo_returns(self, distr: str = 'norm', years: int = 1, n: int = 100) -> pd.DataFrame:
        """
        Generates N random monthly returns time series with normal or lognormal distributions.
        Forecast period should not exceed 1/2 of portfolio history period length.
        """
        period_months, ts_index = self._forecast_preparation(years)
        # random returns
        if distr == 'norm':
            random_returns = np.random.normal(self.mean_return_monthly, self.risk_monthly, (period_months, n))
        elif distr == 'lognorm':
            std, loc, scale = scipy.stats.lognorm.fit(self.returns_ts)
            random_returns = scipy.stats.lognorm(std, loc=loc, scale=scale).rvs(size=[period_months, n])
        else:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        return pd.DataFrame(data=random_returns, index=ts_index)

    def forecast_monte_carlo_wealth_indexes(self, distr: str = 'norm', years: int = 1, n: int = 100) -> pd.DataFrame:
        """
        Generates N future random wealth indexes.
        Random distribution could be normal or lognormal.

        First value for the forecasted wealth indexes is the last historical portfolio index value. It is useful
        for a chart with historical wealth index and forecasted values.
        """
        if distr not in ['norm', 'lognorm']:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        return_ts = self.forecast_monte_carlo_returns(distr=distr, years=years, n=n)
        first_value = self.wealth_index['portfolio'].values[-1]
        return Frame.get_wealth_indexes(return_ts, first_value)

    def _get_monte_carlo_cagr_distribution(self,
                                           distr: str = 'norm',
                                           years: int = 1,
                                           n: int = 100,
                                           ) -> pd.Series:
        """
        Generate random CAGR distribution.
        CAGR is calculated for each of N future random returns time series.
        Random distribution could be normal or lognormal.
        """
        if distr not in ['norm', 'lognorm']:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        return_ts = self.forecast_monte_carlo_returns(distr=distr, years=years, n=n)
        return Frame.get_cagr(return_ts)

    def forecast_monte_carlo_cagr(self,
                                  distr: str = 'norm',
                                  years: int = 1,
                                  percentiles: List[int] = [10, 50, 90],
                                  n: int = 10000,
                                  ) -> pd.Series:
        """
        Calculate percentiles for forecasted CAGR distribution.
        CAGR is calculated for each of N future random returns time series.
        Random distribution could be normal or lognormal.
        """
        if distr not in ['norm', 'lognorm']:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        cagr_distr = self._get_monte_carlo_cagr_distribution(distr=distr, years=years, n=n)
        results = {}
        for percentile in percentiles:
            value = cagr_distr.quantile(percentile / 100)
            results.update({percentile: value})
        return results

    def forecast_wealth(self,
                        distr: str = 'norm',
                        years: int = 1,
                        percentiles: List[int] = [10, 50, 90],
                        today_value: Optional[int] = None,
                        n: int = 1000,
                        ) -> Dict[int, float]:
        """
        Calculate percentiles of forecasted random accumulated wealth distribution.
        Random distribution could be normal or lognormal.

        today_value - the value of portfolio today (before forecast period). If today_value is None
        the last value of the historical wealth indexes is taken.
        """
        if distr == 'hist':
            results = self.forecast_wealth_history(years=years, percentiles=percentiles).iloc[-1].to_dict()
        elif distr in ['norm', 'lognorm']:
            results = {}
            wealth_indexes = self.forecast_monte_carlo_wealth_indexes(distr=distr, years=years, n=n)
            for percentile in percentiles:
                value = wealth_indexes.iloc[-1, :].quantile(percentile / 100)
                results.update({percentile: value})
        else:
            raise ValueError('distr should be "norm", "lognorm" or "hist".')
        if today_value:
            modifier = today_value / self.wealth_index['portfolio'].values[-1]
            results.update((x, y * modifier)for x, y in results.items())
        return results

    def plot_forecast(self,
                      distr: str = 'norm',
                      years: int = 5,
                      percentiles: List[int] = [10, 50, 90],
                      today_value: Optional[int] = None,
                      n: int = 1000,
                      figsize: Optional[tuple] = None,
                      ):
        """
        Plots the historical wealth index and forecasted ranges of wealth (lines) for a given set of percentiles.

        distr - the distribution model type:
            norm - normal distribution
            lognorm - lognormal distribution
            hist - percentiles are taken from historical data
        years - forecast period in years
        percentiles - list of percentiles to plot (percentile 50 is labeled as the median)
        today_value - the value of portfolio today (before forecast period)
        n - number of random wealth time series used to calculate percentiles (not needed if distr='hist')
        figsize - figure size passed to matplotlib

        Returns the matplotlib axes with the chart.
        """
        history = self.wealth_index
        start_date = self.last_date
        end_date = start_date.replace(year=start_date.year + years)
        start_value = history['portfolio'].iloc[-1]
        end_values = self.forecast_wealth(distr=distr,
                                          years=years,
                                          percentiles=percentiles,
                                          n=n)
        if today_value:
            # Rescale history and forecast so that the last historical point equals today_value
            modifier = today_value / start_value
            history *= modifier
            start_value = start_value * modifier
            end_values = {percentile: value * modifier for percentile, value in end_values.items()}
        fig, ax = plt.subplots(figsize=figsize)
        ax.plot(history.index.to_timestamp(), history['portfolio'], linewidth=1, label='Historical data')
        dates = [start_date, end_date]
        for percentile in percentiles:
            values = [start_value, end_values[percentile]]
            if percentile == 50:
                line_style = dict(color='blue', linestyle='-', linewidth=2, label='Median')
            else:
                line_style = dict(linestyle='dashed', linewidth=1, label=f'Percentile {percentile}')
            ax.plot(dates, values, **line_style)
        ax.legend(loc='upper left')
        return ax

    def plot_forecast_monte_carlo(self,
                                  distr: str = 'norm',
                                  years: int = 1,
                                  n: int = 20,
                                  figsize: Optional[tuple] = None,
                                  ):
        """
        Plots N random wealth indexes and historical wealth index.
        Forecasted indexes are generated accorded to a given distribution (Monte Carlo simulation).
        Normal and lognormal distributions could be used for Monte Carlo simulation.

        years - forecast period in years
        n - number of random wealth indexes to plot
        figsize - figure size passed to the plot of the historical wealth index
        """
        history = self.wealth_index
        forecasts = self.forecast_monte_carlo_wealth_indexes(distr=distr, years=years, n=n)
        history['portfolio'].plot(legend=None, figsize=figsize)
        for column in forecasts:
            forecasts[column].plot(legend=None)

    # distributions
    @property
    def skewness(self):
        """
        Expanding skewness of the portfolio return time series.

        For normally distributed data, the skewness should be about zero.
        A skewness value greater than zero means that there is more weight in the right tail of the distribution.
        """
        portfolio_returns = self.returns_ts
        return Frame.skewness(portfolio_returns)

    def skewness_rolling(self, window: int = 60):
        """
        Rolling skewness of the portfolio return time series.

        For normally distributed data, the skewness should be about zero.
        A skewness value greater than zero means that there is more weight in the right tail of the distribution.

        window - the rolling window size in months (default is 5 years).
        The window size should be at least 12 months.
        """
        portfolio_returns = self.returns_ts
        return Frame.skewness_rolling(portfolio_returns, window=window)

    @property
    def kurtosis(self):
        """
        Expanding Fisher (normalized) kurtosis time series for portfolio returns.

        Kurtosis is the fourth central moment divided by the square of the variance.
        For a normal distribution the value should be close to zero.
        """
        portfolio_returns = self.returns_ts
        return Frame.kurtosis(portfolio_returns)

    def kurtosis_rolling(self, window: int = 60):
        """
        Rolling Fisher (normalized) kurtosis time series for portfolio returns.

        Kurtosis is the fourth central moment divided by the square of the variance.
        For a normal distribution the value should be close to zero.

        window - the rolling window size in months (default is 5 years).
        The window size should be at least 12 months.
        """
        portfolio_returns = self.returns_ts
        return Frame.kurtosis_rolling(portfolio_returns, window=window)

    @property
    def jarque_bera(self):
        """
        Run the Jarque-Bera normality test on the portfolio returns.
        The test shows whether the returns have the skewness and kurtosis matching a normal distribution.

        Returns:
            (The test statistic, The p-value for the hypothesis test)
            Low statistic numbers correspond to normal distribution.
        """
        portfolio_returns = self.returns_ts
        return Frame.jarque_bera_series(portfolio_returns)

    def kstest(self, distr: str = 'norm') -> dict:
        """
        Run the Kolmogorov-Smirnov test on portfolio returns to evaluate goodness of fit.
        Test works with normal and lognormal distributions.

        distr - name of the distribution to test against (default is 'norm'), passed to Frame.kstest_series.

        Returns:
            (The test statistic, The p-value for the hypothesis test)
        """
        portfolio_returns = self.returns_ts
        return Frame.kstest_series(portfolio_returns, distr=distr)

    def plot_percentiles_fit(self, distr: str = 'norm', figsize: Optional[tuple] = None):
        """
        Generates a probability plot of portfolio returns against percentiles of a specified
        theoretical distribution (the normal distribution by default).
        Works with normal and lognormal distributions.
        """
        plt.figure(figsize=figsize)
        if distr == 'norm':
            scipy.stats.probplot(self.returns_ts, dist=distr, plot=plt)
        elif distr == 'lognorm':
            scipy.stats.probplot(self.returns_ts, sparams=(scipy.stats.lognorm.fit(self.returns_ts)), dist=distr, plot=plt)
        else:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        plt.show()

    def plot_hist_fit(self, distr: str = 'norm', bins: int = None):
        """
        Plots historical distribution histogram and theoretical PDF (Probability Distribution Function).
        Lognormal and normal distributions could be used.
        """
        data = self.returns_ts
        # Plot the histogram
        plt.hist(data, bins=bins, density=True, alpha=0.6, color='g')
        # Plot the PDF.Probability Density Function
        xmin, xmax = plt.xlim()
        x = np.linspace(xmin, xmax, 100)
        if distr == 'norm':  # Generate PDF
            mu, std = scipy.stats.norm.fit(data)
            p = scipy.stats.norm.pdf(x, mu, std)
        elif distr == 'lognorm':
            std, loc, scale = scipy.stats.lognorm.fit(data)
            mu = np.log(scale)
            p = scipy.stats.lognorm.pdf(x, std, loc, scale)
        else:
            raise ValueError('distr should be "norm" (default) or "lognorm".')
        plt.plot(x, p, 'k', linewidth=2)
        title = "Fit results: mu = %.3f,  std = %.3f" % (mu, std)
        plt.title(title)
        plt.show()

