# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/04-discharging.ipynb (unless otherwise specified).

# Names exported by `from <module> import *`; every entry is defined further down in this file.
__all__ = ['construct_df_discharge_features', 'estimate_daily_demand_quantiles', 'sample_random_day',
           'sample_random_days', 'reset_idx_dt', 'extract_evening_demand_profile', 'flatten_peak',
           'construct_discharge_profile', 'construct_discharge_s', 'extract_evening_datetimes',
           'normalise_total_discharge', 'clip_discharge_rate', 'post_pred_discharge_proc_func',
           'construct_peak_reduction_calculator', 'evaluate_discharge_models', 'prepare_training_input_data',
           'fit_and_save_model', 'load_trained_model', 'load_latest_submission_template', 'prepare_test_feature_data',
           'optimise_test_discharge_profile']

# Cell
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold
from sklearn.metrics import make_scorer, r2_score, mean_absolute_error, mean_squared_error
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

from statsmodels.tsa.stattools import acf
from moepy.lowess import Lowess, quantile_model

from batopt import clean, utils

import os
import random
import joblib
from ipypb import track
import FEAutils as hlp

# Cell
def construct_df_discharge_features(df, dt_rng=None):
    """
    Builds the feature table used by the discharge model.

    Parameters
    ----------
    df : pd.DataFrame
        Must have a timezone-aware DatetimeIndex (it is converted to
        'Europe/London' below), a `demand_MW` column and one or more columns
        whose name contains 'temp_location'.
    dt_rng : pd.DatetimeIndex, optional
        Index of the returned features. Defaults to a half-hourly range
        spanning `df.index.min()` to `df.index.max()`. Every label in
        `df.index` must be present in it, otherwise the `.loc` assignment
        below raises.

    Returns
    -------
    pd.DataFrame
        Temperature columns, their spatial and daily averages, 7 day lagged
        demand features and datetime features. Rows containing any NaN are
        dropped.
    """
    if dt_rng is None:
        # '30min' is the supported spelling of the deprecated '30T' alias
        dt_rng = pd.date_range(df.index.min(), df.index.max(), freq='30min')

    df_features = pd.DataFrame(index=dt_rng)

    # Filtering for the temperature weather data
    temp_loc_cols = df.columns[df.columns.str.contains('temp_location')]
    df_features.loc[df.index, temp_loc_cols] = df[temp_loc_cols].copy()
    df_features = df_features.ffill(limit=1)  # fills at most one consecutive missing row

    # Only the temperature columns exist at this point, so the row-wise mean is the spatial average
    df_features['spatial_avg_temp'] = df_features.mean(axis=1) # Should look into excluding temp_location5 and temp_location6
    df_features['daily_avg_temp'] = pd.Series(df_features.index.date, index=df_features.index).map(df_features['spatial_avg_temp'].groupby(df_features.index.date).mean().to_dict())

    # Adding lagged demand
    # NOTE(review): shifting by 48*7 rows equals 7 days only if `df` is a gap-free half-hourly series - confirm
    df_features['SP_demand_7d_lag'] = df['demand_MW'].shift(48*7)

    s_evening_demand = df['demand_MW'].between_time('15:30', '21:00')
    # NOTE(review): `.shift(7)` moves by 7 rows of the per-date aggregate, i.e. 7 days only if no date is missing - confirm
    dt_to_lagged_evening_avg = s_evening_demand.groupby(s_evening_demand.index.date).mean().shift(7).to_dict()
    dt_to_lagged_evening_max = s_evening_demand.groupby(s_evening_demand.index.date).max().shift(7).to_dict()
    df_features['evening_demand_avg_7d_lag'] = pd.Series(df_features.index.date, index=df_features.index).map(dt_to_lagged_evening_avg)
    df_features['evening_demand_max_7d_lag'] = pd.Series(df_features.index.date, index=df_features.index).map(dt_to_lagged_evening_max)

    # Adding datetime features
    dts = df_features.index.tz_convert('Europe/London') # We want to use the 'behavioural' timezone

    df_features['weekend'] = dts.dayofweek.isin([5, 6]).astype(int)
    df_features['hour'] = dts.hour + dts.minute/60  # decimal hour, e.g. 15:30 -> 15.5
    df_features['doy'] = dts.dayofyear
    df_features['dow'] = dts.dayofweek

    # Removing NaN values
    df_features = df_features.dropna()

    return df_features

# Cell
def estimate_daily_demand_quantiles(x, y, x_pred = np.linspace(0, 23.5, 100), **model_kwargs):
    """
    Fits `quantile_model` to (x, y) and returns its predictions with tidied column names.

    Parameters
    ----------
    x, y :
        Passed straight through to `quantile_model`.
    x_pred : array-like
        Points at which the quantiles are evaluated, defaults to 100 evenly
        spaced values between 0 and 23.5.
    **model_kwargs :
        Forwarded to `quantile_model`.

    Returns
    -------
    DataFrame whose columns are renamed to 'p<percent>' and reversed in order.
    Assumes the columns returned by `quantile_model` are numeric quantile
    levels (they are multiplied by 100 below).
    """
    # Fitting the model
    df_quantiles = quantile_model(x, y, x_pred=x_pred, **model_kwargs)

    # Cleaning names and sorting for plotting
    # Rounding before the int cast avoids float truncation, e.g. 0.29*100 == 28.999... -> 'p29' rather than 'p28'
    df_quantiles.columns = [f'p{int(round(col*100))}' for col in df_quantiles.columns]
    df_quantiles = df_quantiles[df_quantiles.columns[::-1]]

    return df_quantiles

# Cell
def reset_idx_dt(s, dt='2020-01-01'):
    """
    Returns the index of `s` shifted so that its first entry lands on `dt`.

    `dt` is parsed with `utc=True`; the spacing between index entries is
    unchanged because the same offset is subtracted from every entry.
    """
    target_start = pd.to_datetime(dt, utc=True)
    offset = s.index[0] - target_start

    return s.index - offset

def sample_random_day(s):
    """
    Returns the observations of `s` for one date drawn at random.

    The date is drawn from `s.index.date`, i.e. once per observation, so
    dates with more observations are proportionally more likely to be picked.
    """
    candidate_dates = s.index.date
    chosen_date = random.choice(candidate_dates)

    # Partial-string indexing selects every row falling on the chosen date
    return s.loc[str(chosen_date)]

def sample_random_days(s, num_days=5):
    """
    Draws `num_days` random days from `s` and returns them side by side.

    Each sampled day is re-indexed with `reset_idx_dt` so the days share a
    common time axis, and is stored in a column labelled with its original
    date. Days are drawn independently, so a date drawn more than once
    overwrites its own column and the result can have fewer than `num_days`
    columns. Columns are returned in sorted order.
    """
    df_days = pd.DataFrame()
    draws_remaining = num_days

    while draws_remaining > 0:
        s_day = sample_random_day(s)

        # The label must be read before the index is shifted onto the common axis
        day_label = str(s_day.index[0].date())
        s_day.index = reset_idx_dt(s_day)

        df_days[day_label] = s_day
        draws_remaining -= 1

    return df_days.sort_index(axis=1)

# Cell
def extract_evening_demand_profile(s_demand_sample_dt, start_time='15:30', end_time='20:30'):
    """
    Returns the values of the series between `start_time` and `end_time`
    (both inclusive) on the date of its first index entry.
    """
    day = str(s_demand_sample_dt.index[0].date())

    window_start = f'{day} {start_time}'
    window_end = f'{day} {end_time}'
    s_evening = s_demand_sample_dt[window_start:window_end]

    return s_evening.values

# Cell
def flatten_peak(evening_demand_profile, charge=6, time_unit=0.5):
    """
    Lowers the highest values of a demand profile until `charge` is used up.

    The profile is flattened level by level: the current plateau (all periods
    at the maximum) is lowered to the next highest value for as long as the
    remaining charge covers it; the remainder is then spread equally over the
    plateau. Energy is accounted for as `rate * time_unit` per period, so the
    total removed equals `charge` whenever `charge > 0`.

    Parameters
    ----------
    evening_demand_profile : array-like
        Demand for each period. Not modified. NaN values are not handled
        specially.
    charge : float
        Total energy available for discharging.
    time_unit : float
        Duration of a single period, in the time unit of `charge`.

    Returns
    -------
    Float copy of the profile after flattening (a Series if a Series was
    passed, otherwise an ndarray). An empty profile is returned unchanged.
    """
    # Working on a float copy keeps the caller's data intact and lets the
    # in-place subtraction below work for integer input
    if isinstance(evening_demand_profile, pd.Series):
        adj_evening_demand_profile = evening_demand_profile.astype(float)
    else:
        adj_evening_demand_profile = np.array(evening_demand_profile, dtype=float)

    num_periods = adj_evening_demand_profile.size

    if num_periods == 0:
        return adj_evening_demand_profile

    while charge > 0:
        # The plateau has to be measured on the adjusted profile; counting it
        # on the original profile mis-sizes the final step whenever values tie
        peak = adj_evening_demand_profile.max()
        on_plateau = adj_evening_demand_profile >= peak
        num_periods_plateaued = on_plateau.sum()

        # If the evening demand profile has been fully flattened
        # then split up the remaining charge equally across all SPs
        if num_periods_plateaued == num_periods:
            adj_evening_demand_profile -= (1/time_unit)*charge/num_periods
            break

        # If there is still a peak then determine the next highest value
        highest_non_peak = adj_evening_demand_profile[~on_plateau].max()
        proposed_additional_discharge = time_unit*(adj_evening_demand_profile.sum() - np.minimum(adj_evening_demand_profile, highest_non_peak).sum())

        # if its possible to reduce the peak to the next highest value do so
        if charge >= proposed_additional_discharge:
            adj_evening_demand_profile = np.minimum(adj_evening_demand_profile, highest_non_peak)
            charge -= proposed_additional_discharge

        # If the capacity constraints are broken when reducing to the next
        # highest value then just lower the current peak as far as possible,
        # sharing the remaining charge between the plateaued periods only
        else:
            new_peak = peak - (1/time_unit)*charge/num_periods_plateaued
            adj_evening_demand_profile = np.minimum(adj_evening_demand_profile, new_peak)
            charge = 0

    return adj_evening_demand_profile

# Cell
def construct_discharge_profile(evening_demand_profile, adj_evening_demand_profile):
    """
    Returns the discharge implied by a flattened profile: the negated
    difference between the original and the adjusted demand, so periods where
    demand was lowered carry a negative value.
    """
    demand_removed = evening_demand_profile - adj_evening_demand_profile

    return -demand_removed

# Cell
def construct_discharge_s(s_demand, start_time='15:30', end_time='20:30'):
    """
    Builds the peak-flattening discharge series for every date in `s_demand`.

    For each date the demand between `start_time` and `end_time` is extracted,
    flattened with `flatten_peak` (default charge and time unit) and the
    resulting discharge is written into the same window of the output.

    Parameters
    ----------
    s_demand : pd.Series
        Demand with a DatetimeIndex.
    start_time, end_time : str
        'HH:MM' bounds (inclusive) of the daily discharge window.

    Returns
    -------
    pd.Series
        Float series on the index of `s_demand`, zero outside the window.
    """
    s_discharge = pd.Series(0.0, index=s_demand.index)

    for dt in s_demand.index.strftime('%Y-%m-%d').unique():
        # The window has to be forwarded, otherwise the extracted profile uses
        # the default window and no longer matches the slice assigned below
        evening_demand_profile = extract_evening_demand_profile(s_demand[dt], start_time=start_time, end_time=end_time)
        adj_evening_demand_profile = flatten_peak(evening_demand_profile)

        discharge_profile = construct_discharge_profile(evening_demand_profile, adj_evening_demand_profile)
        s_discharge[f'{dt} {start_time}':f'{dt} {end_time}'] = discharge_profile

    return s_discharge

# Cell
def extract_evening_datetimes(df):
    """
    Returns the entries of `df.index` whose time of day lies between 15:30
    and 20:30 inclusive, evaluated in the index's own timezone.
    """
    idx = df.index
    decimal_hour = idx.hour + idx.minute/60

    in_evening = (decimal_hour >= 15.5) & (decimal_hour <= 20.5)

    return idx[in_evening]

# Cell
def normalise_total_discharge(s_pred, charge=6, time_unit=0.5):
    """
    Rescales each day's predictions so the energy discharged equals `charge`.

    Every value is multiplied by `-charge/(time_unit*daily_total)`, where
    `daily_total` is the sum of the predictions on that value's date, so each
    rescaled day sums to `-charge/time_unit`.

    Days whose predictions sum to zero cannot be rescaled and are returned
    unchanged rather than being turned into NaN by a division by zero.

    Parameters
    ----------
    s_pred : pd.Series
        Predicted discharge with a DatetimeIndex. Not modified.
    charge : float
        Energy to be discharged per day.
    time_unit : float
        Duration of a single period, in the time unit of `charge`.

    Returns
    -------
    pd.Series
        New series with the same index as `s_pred`.
    """
    if s_pred.empty:
        return s_pred.copy()

    # Broadcasting each date's total back onto its rows removes the per-day loop
    daily_totals = s_pred.groupby(s_pred.index.date).transform('sum').to_numpy(dtype=float)

    # A scale of 1 is kept wherever the daily total is zero
    scale = np.ones(len(s_pred))
    np.divide(-charge, time_unit*daily_totals, out=scale, where=daily_totals != 0)

    return s_pred * scale

# Cell
def clip_discharge_rate(s_pred, max_rate=-2.5, min_rate=0):
    """
    Clips the predictions to the range [max_rate, min_rate].

    Discharge is expressed as a negative number, so `max_rate` (the strongest
    permitted discharge) is the lower bound and `min_rate` the upper bound.
    """
    bounds = {'lower': max_rate, 'upper': min_rate}

    return s_pred.clip(**bounds)

# Cell
def post_pred_discharge_proc_func(s_pred):
    """
    Post-processes predicted discharge: the rates are first clipped with
    `clip_discharge_rate`, then rescaled with `normalise_total_discharge`
    (both using their default arguments).
    """
    s_clipped = clip_discharge_rate(s_pred)
    s_normalised = normalise_total_discharge(s_clipped)

    return s_normalised

# Cell
def construct_peak_reduction_calculator(s_demand, evening_datetimes=None, scorer=False):
    """
    Creates a function scoring predicted discharge against the optimal one.

    Parameters
    ----------
    s_demand : pd.Series
        Demand with a DatetimeIndex.
    evening_datetimes : pd.DatetimeIndex, optional
        Datetimes the predictions refer to when they carry no datetime index
        of their own. Defaults to the evening datetimes of `s_demand`.
    scorer : bool
        If truthy the function is wrapped with sklearn's `make_scorer`.

    Returns
    -------
    `calc_peak_reduction(y, y_pred)`, or a scorer wrapping it. It returns the
    mean percentage peak reduction achieved by `y_pred` as a percentage of
    the mean percentage peak reduction achieved by `y`.
    """
    if evening_datetimes is None:
        evening_datetimes = extract_evening_datetimes(s_demand)

    def calc_peak_reduction(y, y_pred):
        # Checking evening datetimes
        # A separate local name is used here: assigning to `evening_datetimes`
        # would make it local to this function and raise UnboundLocalError
        # whenever the predictions have no datetime index (e.g. ndarrays)
        has_dt_index = isinstance(getattr(y_pred, 'index', None), pd.DatetimeIndex)

        if has_dt_index:
            pred_datetimes = extract_evening_datetimes(y_pred)
        else:
            pred_datetimes = evening_datetimes

        s_evening_demand = s_demand.loc[pred_datetimes]

        assert y_pred.shape[0] == s_evening_demand.shape[0], f'The prediction series must be the same length as the number of evening datetimes in the main dataframe, {y_pred.shape[0]} {s_demand.loc[evening_datetimes].shape[0]}'

        # The post-processing groups by date, so index-less predictions are
        # given the evening datetimes they correspond to
        if not has_dt_index:
            y_pred = pd.Series(np.asarray(y_pred), index=pred_datetimes)

        # Post-processing the discharge profile to handle constraints
        y_pred = post_pred_discharge_proc_func(y_pred)

        # Identifying daily peaks
        s_old_peaks = s_evening_demand.groupby(pred_datetimes.date).max()
        s_new_peaks = (s_evening_demand+y_pred).groupby(pred_datetimes.date).max()
        s_optimal_peaks = (s_evening_demand+y).groupby(pred_datetimes.date).max()

        # Calculating the peak reduction
        s_new_pct_peak_reduction = 100*(s_old_peaks-s_new_peaks)/s_old_peaks
        s_optimal_pct_peak_reduction = 100*(s_old_peaks-s_optimal_peaks)/s_old_peaks

        # after cleaning anomalous demand data should add an assert to check for non finite values

        # Infinite reductions (days whose old peak is zero) are excluded from the means
        pct_of_max_possible_reduction = 100*(s_new_pct_peak_reduction.replace(np.inf, np.nan).dropna().mean()/
                                             s_optimal_pct_peak_reduction.replace(np.inf, np.nan).dropna().mean())

        return pct_of_max_possible_reduction

    if scorer:
        return make_scorer(calc_peak_reduction)

    return calc_peak_reduction

def evaluate_discharge_models(df, models, features_kwargs=None):
    """
    Scores each model on out-of-fold predictions of the optimal discharge.

    Parameters
    ----------
    df : pd.DataFrame
        Input data containing a `demand_MW` column plus whatever
        `construct_df_discharge_features` requires.
    models : dict
        Maps a model name to the model handed to `clean.generate_kfold_preds`.
    features_kwargs : dict, optional
        Extra keyword arguments for `construct_df_discharge_features`.

    Returns
    -------
    pd.DataFrame
        One column per model and one row per metric: 'pct_optimal_reduction',
        'optimal_discharge_mae' and 'optimal_discharge_rmse'.
    """
    # A None sentinel avoids sharing one mutable default dict between calls
    if features_kwargs is None:
        features_kwargs = {}

    df_features = construct_df_discharge_features(df, **features_kwargs)
    s_discharge = construct_discharge_s(df['demand_MW'], start_time='15:30', end_time='20:30')

    evening_datetimes = extract_evening_datetimes(df_features)

    X = df_features.loc[evening_datetimes].values
    y = s_discharge.loc[evening_datetimes].values

    model_scores = dict()
    peak_reduction_calc = construct_peak_reduction_calculator(s_demand=df['demand_MW'], evening_datetimes=evening_datetimes)

    for model_name, model in track(models.items()):
        df_pred = clean.generate_kfold_preds(X, y, model, index=evening_datetimes)
        # Applied here so the MAE/RMSE below are computed on the constrained predictions
        df_pred['pred'] = post_pred_discharge_proc_func(df_pred['pred'])

        model_scores[model_name] = {
            'pct_optimal_reduction': peak_reduction_calc(df_pred['true'], df_pred['pred']),
            'optimal_discharge_mae': mean_absolute_error(df_pred['true'], df_pred['pred']),
            'optimal_discharge_rmse': np.sqrt(mean_squared_error(df_pred['true'], df_pred['pred']))
        }

    df_model_scores = pd.DataFrame(model_scores)

    df_model_scores.index.name = 'metric'
    df_model_scores.columns.name = 'model'

    return df_model_scores

# Cell
def prepare_training_input_data(intermediate_data_dir):
    """
    Loads the combined training data and returns the model inputs and targets.

    Parameters
    ----------
    intermediate_data_dir : str
        Directory passed to `clean.combine_training_datasets`.

    Returns
    -------
    X : pd.DataFrame
        Features for the evening datetimes.
    y : pd.Series
        Peak-flattening discharge for the same datetimes.
    """
    # Loading input data
    df = clean.combine_training_datasets(intermediate_data_dir).interpolate(limit=1)
    df_features = construct_df_discharge_features(df)

    # Filtering for overlapping feature and target data
    # The range ends one period before the last available demand value;
    # '30min' is the supported spelling of the deprecated '30T' alias
    dt_idx = pd.date_range(df_features.index.min(), df['demand_MW'].dropna().index.max()-pd.Timedelta(minutes=30), freq='30min')

    # NOTE(review): both selections raise a KeyError if any half-hour in the range is missing - confirm the data is gap-free
    s_demand = df.loc[dt_idx, 'demand_MW']
    df_features = df_features.loc[dt_idx]

    # Constructing the discharge series
    s_discharge = construct_discharge_s(s_demand, start_time='15:30', end_time='20:30')

    # Filtering for evening datetimes
    evening_datetimes = extract_evening_datetimes(df_features)

    X = df_features.loc[evening_datetimes]
    y = s_discharge.loc[evening_datetimes]

    return X, y

# Cell
def fit_and_save_model(X, y, discharge_opt_model_fp, model_class=RandomForestRegressor, **model_params):
    """
    Instantiates `model_class(**model_params)`, fits it to (X, y) and dumps
    the fitted model to `discharge_opt_model_fp` with joblib.

    Nothing is returned.
    """
    estimator = model_class(**model_params)
    estimator.fit(X, y)

    with open(discharge_opt_model_fp, 'wb') as model_file:
        joblib.dump(estimator, model_file)

# Cell
def load_trained_model(discharge_opt_model_fp):
    """
    Loads and returns the model stored at `discharge_opt_model_fp` using joblib.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects - only point this at trusted files
    with open(discharge_opt_model_fp, 'rb') as model_file:
        return joblib.load(model_file)

# Cell
def load_latest_submission_template(raw_data_dir, latest_submission_template_name=None):
    """
    Reads a submission template CSV and indexes it by its UTC `datetime` column.

    If no filename is given, the file in `raw_data_dir` whose name contains
    'teamname_set' and sorts last alphabetically is used.
    """
    if latest_submission_template_name is None:
        candidate_names = (fname for fname in os.listdir(raw_data_dir) if 'teamname_set' in fname)
        latest_submission_template_name = max(candidate_names)

    template_fp = f'{raw_data_dir}/{latest_submission_template_name}'
    df_template = pd.read_csv(template_fp)

    utc_datetimes = pd.to_datetime(df_template['datetime'], utc=True)

    return df_template.assign(datetime=utc_datetimes).set_index('datetime')

def prepare_test_feature_data(raw_data_dir, intermediate_data_dir, test_start_date=None, test_end_date=None):
    """
    Builds the feature table for the test period.

    The features are constructed from the combined training datasets and then
    restricted to either the index of the latest submission template (when
    either date is missing) or the `test_start_date:test_end_date` slice.
    """
    # Loading input data
    df_combined = clean.combine_training_datasets(intermediate_data_dir)
    df_features = construct_df_discharge_features(df_combined.interpolate(limit=1))

    # Both dates are needed for an explicit range, otherwise the latest submission template is used
    use_template_index = (test_start_date is None) or (test_end_date is None)

    if use_template_index:
        index = load_latest_submission_template(raw_data_dir).index
    else:
        index = df_features[test_start_date:test_end_date].index

    # Filtering feature data on submission datetimes
    return df_features.loc[index]

# Cell
def optimise_test_discharge_profile(raw_data_dir, intermediate_data_dir, discharge_opt_model_fp, test_start_date=None, test_end_date=None):
    """
    Predicts the discharge profile for the test period with a saved model.

    Predictions are made for the evening datetimes only; all other datetimes
    of the test period are set to zero before the profile is post-processed
    with `post_pred_discharge_proc_func`.
    """
    df_features = prepare_test_feature_data(
        raw_data_dir,
        intermediate_data_dir,
        test_start_date=test_start_date,
        test_end_date=test_end_date
    )

    evening_datetimes = extract_evening_datetimes(df_features)
    X_test = df_features.loc[evening_datetimes].values

    trained_model = load_trained_model(discharge_opt_model_fp)
    evening_preds = trained_model.predict(X_test)

    s_discharge_profile = (pd.Series(evening_preds, index=evening_datetimes)
                           .reindex(df_features.index)  # non-evening periods become NaN ...
                           .fillna(0)                   # ... and are then set to no discharge
                           .pipe(post_pred_discharge_proc_func)
                          )

    return s_discharge_profile