# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03-charging.ipynb (unless otherwise specified).

# Public API of this module: the names exported by a star-import.
__all__ = ['estimate_daily_solar_quantiles', 'extract_solar_profile', 'charge_profile_greedy', 'topup_charge_naive',
           'optimal_charge_profile', 'construct_charge_profile', 'construct_charge_s', 'charge_is_valid',
           'construct_df_charge_features', 'extract_charging_datetimes', 'prepare_training_input_data',
           'normalise_total_charge', 'clip_charge_rate', 'post_pred_charge_proc_func', 'score_charging',
           'max_available_solar', 'prop_max_solar', 'construct_solar_exploit_calculator', 'fit_and_save_charging_model',
           'prepare_test_feature_data', 'optimise_test_charge_profile']

# Cell
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

from moepy.lowess import quantile_model

from sklearn.pipeline import Pipeline
from sklearn.linear_model import LinearRegression
from sklearn.metrics import make_scorer, r2_score, mean_absolute_error, mean_squared_error
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor

from skopt.plots import plot_objective
from skopt.space import Real, Categorical, Integer

from batopt import clean, discharge, utils

import FEAutils as hlp

# Cell
def estimate_daily_solar_quantiles(x, y, x_pred = np.linspace(0, 23.5, 100), **model_kwargs):
    """
    Fit `quantile_model` to `(x, y)` and return its output evaluated at `x_pred`.

    Any extra keyword arguments are forwarded to `quantile_model`. The columns
    of the result are renamed to `p<int(col*100)>` (presumably the columns are
    quantile levels such as 0.1 -> 'p10' - verify against `quantile_model`)
    and returned in reversed order.
    """
    fitted = quantile_model(x, y, x_pred=x_pred, **model_kwargs)

    # Relabel the columns as percentile-style strings
    percentile_labels = [f'p{int(col*100)}' for col in fitted.columns]
    fitted.columns = percentile_labels

    # Reverse the column order so the frame is ready for plotting
    reversed_labels = fitted.columns[::-1]

    return fitted[reversed_labels]

# Cell
def extract_solar_profile(s_solar_sample_dt, start_time='00:00', end_time='15:00'):
    """
    Return the values of a datetime-indexed series between `start_time` and
    `end_time` on the date of its first index entry.

    The window is selected with label-based string slicing, so both ends are
    inclusive.
    """
    day = str(s_solar_sample_dt.index[0].date())
    window = slice(f'{day} {start_time}', f'{day} {end_time}')

    return s_solar_sample_dt.loc[window].values

def charge_profile_greedy(solar_profile, capacity=6, initial_charge=0, max_charge_rate=2.5, time_unit=0.5):
    """
    Greedily build a charge profile that draws from the periods with the most solar first.

    Periods are visited from highest to lowest solar output. In each one the
    charge rate is the smallest of the solar output, `max_charge_rate` and the
    rate that would exactly fill the remaining capacity, floored at zero so a
    negative solar reading never produces a negative charge rate.

    Parameters:
        solar_profile: 1-D array-like of solar output per period
        capacity: battery capacity (energy)
        initial_charge: energy already stored before the profile starts
        max_charge_rate: upper bound on the charge rate in any period
        time_unit: duration of one period, so that rate * time_unit is energy

    Returns:
        Float numpy array of charge rates with the same length as `solar_profile`.
    """
    solar_profile = np.asarray(solar_profile, dtype=float)
    solution = np.zeros(len(solar_profile))

    # The stored energy is tracked incrementally and includes the initial
    # charge for the whole loop (recomputing it from `solution` alone would
    # forget `initial_charge` after the first period and overfill the battery)
    charge = initial_charge

    for i in np.flip(np.argsort(solar_profile)):
        if charge >= capacity:
            break # battery is full, remaining periods stay at zero

        headroom_rate = (capacity - charge)/time_unit
        rate = max(min(solar_profile[i], max_charge_rate, headroom_rate), 0.0)

        solution[i] = rate
        charge += rate*time_unit

    return solution

def topup_charge_naive(charge_profile, capacity=6, time_unit=0.5, period_start=16, period_end=30):
    """
    Spread the unused battery capacity uniformly over a window of periods.

    The spare capacity is `capacity` minus the energy already in
    `charge_profile` (sum of rates times `time_unit`). It is converted to a
    constant rate and added to every period in `[period_start, period_end)`.
    The input is not modified; a copy is returned. No maximum charge rate is
    enforced here.
    """
    already_charged = np.sum(charge_profile)*time_unit
    window_duration = (period_end-period_start)*time_unit
    topup_rate = (capacity - already_charged)/window_duration

    topped_up = np.array(charge_profile)
    topped_up[period_start:period_end] += topup_rate

    return topped_up

def optimal_charge_profile(solar_profile, capacity=6, time_unit=0.5, max_charge_rate=2.5):
    """
    Build a charge profile that fills the battery, preferring the sunniest periods.

    The greedy profile is computed first and any remaining capacity is then
    topped up uniformly over the default top-up window. `capacity`,
    `time_unit` and `max_charge_rate` are forwarded to the helpers so that
    the constraints checked below are the same ones used to build the profile.

    Raises:
        AssertionError: if the profile does not sum to `capacity/time_unit`
            or exceeds `max_charge_rate` in any period.
    """
    solution = charge_profile_greedy(
        solar_profile,
        capacity=capacity,
        max_charge_rate=max_charge_rate,
        time_unit=time_unit
    )
    solution = topup_charge_naive(solution, capacity=capacity, time_unit=time_unit)

    assert np.isclose(np.sum(solution), capacity/time_unit), "Does not meet capacity constraint. Total is {}".format(np.sum(solution))
    assert np.all(solution <= max_charge_rate), "Does not meet max charge rate constraint. Max is {}".format(np.max(solution))

    return solution

# Cell
def construct_charge_profile(solar_profile, adj_solar_profile):
    """
    Return the difference between a solar profile and its adjusted counterpart.
    """
    return solar_profile - adj_solar_profile

# Cell
def construct_charge_s(s_pv, start_time='00:00', end_time='15:00'):
    """
    Build a charge series on the same index as `s_pv`.

    For every calendar day in the index, the solar profile between
    `start_time` and `end_time` is extracted, passed through
    `discharge.flatten_peak`, and the difference between the two is written
    into that day's window. All other timestamps are left at zero.
    """
    s_charge = pd.Series(0.0, index=s_pv.index, dtype=float)
    unique_days = s_pv.index.strftime('%Y-%m-%d').unique()

    for dt in unique_days:
        solar_profile = extract_solar_profile(s_pv[dt], start_time=start_time, end_time=end_time)
        adj_solar_profile = discharge.flatten_peak(solar_profile)

        window_start, window_end = f'{dt} {start_time}', f'{dt} {end_time}'
        s_charge[window_start:window_end] = construct_charge_profile(solar_profile, adj_solar_profile)

    return s_charge

def charge_is_valid(charge_profile, capacity=6, max_charge_rate=2.5, time_unit=0.5):
    """
    Function determining if a charge profile is valid (and fully charges the battery)

    `charge_profile` is a datetime-indexed series of charge rates. It is valid
    when, for every calendar day in the index, the rates sum to
    `capacity/time_unit` (within `np.isclose` tolerance) and no rate exceeds
    `max_charge_rate`.

    Returns:
        bool: True if both conditions hold for every day, otherwise False.
    """
    daily = charge_profile.groupby(charge_profile.index.date)

    # Truthiness is used rather than `is False`: numpy returns its own boolean
    # type, which is never identical to the built-in `False` singleton
    fully_charged = np.all(np.isclose(capacity/time_unit, daily.sum()))
    within_rate_limit = np.all(daily.max() <= max_charge_rate)

    return bool(fully_charged and within_rate_limit)


# Cell
def construct_df_charge_features(df, dt_rng=None):
    """
    Build the feature dataframe used by the charging model.

    Parameters:
        df: dataframe with a datetime index, a `pv_power_mw` column and
            columns whose names contain `temp_location` / `solar_location`.
            The index is passed to `tz_convert`, so it is expected to be
            timezone-aware.
        dt_rng: index of the returned features; defaults to a half-hourly
            range spanning `df.index`.

    Returns:
        Dataframe containing the temperature columns, the PV output lagged by
        one week (48*7 rows), the solar columns (excluding any whose name
        contains `solar_location1` or `solar_location4`), calendar and
        cyclical time features in Europe/London time, and 3-row rolling means
        of the solar and temperature columns. Rows with any NaN are dropped.
    """
    if dt_rng is None:
        # '30min' is used instead of the deprecated '30T' alias
        dt_rng = pd.date_range(df.index.min(), df.index.max(), freq='30min')

    df_features = pd.DataFrame(index=dt_rng)

    # Filtering for the temperature weather data
    temp_loc_cols = df.columns[df.columns.str.contains('temp_location')]
    df_features.loc[df.index, temp_loc_cols] = df[temp_loc_cols].copy()
    df_features = df_features.ffill(limit=1)

    # Adding lagged solar
    df_features['pv_7d_lag'] = df['pv_power_mw'].shift(48*7)

    # Adding solar irradiance data
    solar_loc_cols = df.columns[df.columns.str.contains('solar_location')]
    df_features.loc[df.index, solar_loc_cols] = df[solar_loc_cols].copy()
    df_features = df_features.ffill(limit=1)

    # Adding datetime features
    dts = df_features.index.tz_convert('Europe/London') # We want to use the 'behavioural' timezone

    df_features['weekend'] = dts.dayofweek.isin([5, 6]).astype(int)
    df_features['dow'] = dts.dayofweek

    hour = dts.hour + dts.minute/60
    df_features['sin_hour'] = np.sin(2*np.pi*hour/24)
    df_features['cos_hour'] = np.cos(2*np.pi*hour/24)

    df_features['sin_doy'] = np.sin(2*np.pi*dts.dayofyear/365)
    df_features['cos_doy'] = np.cos(2*np.pi*dts.dayofyear/365)

    # Removing some extraneous features (substring match on the column name)
    cols = [c for c in df_features.columns if 'solar_location4' not in c and 'solar_location1' not in c]
    df_features = df_features.filter(cols)

    # Add rolling solar - only the relevant columns are rolled
    solar_cols = [c for c in df_features.columns if 'solar_location' in c]
    if solar_cols:
        df_features[[col+'_rolling' for col in solar_cols]] = df_features[solar_cols].rolling(3).mean()

    # Add rolling temp
    temp_cols = [c for c in df_features.columns if 'temp_location' in c]
    if temp_cols:
        df_features[[col+'_rolling' for col in temp_cols]] = df_features[temp_cols].rolling(3).mean()

    # Removing NaN values
    df_features = df_features.dropna()

    return df_features

#exports
def extract_charging_datetimes(df, start_hour=4, end_hour=15):
    """
    Return the index entries of `df` whose time of day lies between
    `start_hour` and `end_hour` (both inclusive, in fractional hours).
    """
    idx = df.index
    fractional_hour = idx.hour + idx.minute/60
    in_window = (fractional_hour>=start_hour) & (fractional_hour<=end_hour)

    return idx[in_window]

# Cell
def prepare_training_input_data(intermediate_data_dir, start_hour=4):
    """
    Load the combined training data and return the model inputs and targets.

    Parameters:
        intermediate_data_dir: directory passed to `clean.combine_training_datasets`
        start_hour: hour of day at which the charging window starts; the
            window always ends at 15:00

    Returns:
        X: feature dataframe restricted to the charging datetimes
        y: charge series restricted to the same datetimes
    """
    # Loading input data
    df = clean.combine_training_datasets(intermediate_data_dir).interpolate(limit=1)
    df_features = construct_df_charge_features(df)

    # Filtering for overlapping feature and target data
    # ('30min' is used instead of the deprecated '30T' alias)
    last_pv_dt = df['pv_power_mw'].dropna().index.max()
    dt_idx = pd.date_range(df_features.index.min(), last_pv_dt-pd.Timedelta(minutes=30), freq='30min')

    s_pv = df.loc[dt_idx, 'pv_power_mw']
    df_features = df_features.loc[dt_idx]

    # Constructing the charge series
    # The start time is zero-padded explicitly so two-digit and fractional
    # hours still give a valid HH:MM string
    hours, minutes = divmod(int(round(start_hour*60)), 60)
    start_time = f'{hours:02d}:{minutes:02d}'
    s_charge = construct_charge_s(s_pv, start_time=start_time, end_time='15:00')

    # Filtering for charging datetimes
    charging_datetimes = extract_charging_datetimes(df_features, start_hour=start_hour)

    X = df_features.loc[charging_datetimes]
    y = s_charge.loc[charging_datetimes]

    return X, y

# Cell
def normalise_total_charge(s_pred, charge=6., time_unit=0.5):
    """
    Rescale each day's predictions so that they deliver `charge` in total.

    For every calendar day in the index, the values are multiplied by
    `charge/(time_unit*daily_sum)`. Days whose predictions sum to zero cannot
    be rescaled and are left unchanged instead of raising a division error.

    `s_pred` is modified in place and also returned.
    """
    s_daily_charge = s_pred.groupby(s_pred.index.date).sum()

    with np.errstate(divide='ignore', invalid='ignore'):
        for date, total_charge in s_daily_charge.items():
            if total_charge == 0:
                continue # nothing to scale for this day

            s_pred.loc[str(date)] *= charge/(time_unit*total_charge)

    return s_pred

def clip_charge_rate(s_pred, max_rate=2.5, min_rate=0):
    """
    Clip the predicted charge rates to the range `[min_rate, max_rate]`.
    """
    return s_pred.clip(lower=min_rate, upper=max_rate)

def post_pred_charge_proc_func(s_pred):
    """
    Post-process predicted charge rates: clip them with `clip_charge_rate`,
    then rescale with `normalise_total_charge` (both using their defaults).
    """
    clipped = clip_charge_rate(s_pred)

    return normalise_total_charge(clipped)


# Cell
def score_charging(schedule, solar_profile):
    """
    Return the fraction of the scheduled charge that is covered by solar.

    In each period the charge coming from PV is the smaller of the scheduled
    charge and the solar output; the score is the total of these divided by
    the total scheduled charge.
    """
    pv_charge_total = np.sum(np.minimum(schedule, solar_profile))
    scheduled_total = np.sum(schedule)

    return pv_charge_total/scheduled_total

# Cell
def max_available_solar(solar_profile, max_charge_rate=2.5, capacity_mwh=6, time_unit=0.5):
    """
    Return the solar PV potential available to the battery.

    That is, the total PV potential with each period limited to
    `max_charge_rate` and each day capped at `capacity_mwh`.

    Parameters:
        solar_profile: datetime-indexed series of solar output
        max_charge_rate: maximum rate the battery can charge at in a period
        capacity_mwh: daily cap on the energy the battery can take
        time_unit: duration of one period, converting rates to energy

    Returns:
        Total available energy summed over all days in the index.
    """
    usable_rate = solar_profile.clip(0, max_charge_rate)
    available = usable_rate.groupby(solar_profile.index.date).sum() * time_unit

    clipped = np.clip(available.values, 0, capacity_mwh)
    total = np.sum(clipped)

    return total

# Cell
def prop_max_solar(schedule, solar_profile, time_unit=0.5):
    """
    Get the proportion of maximum solar exploitation for charging schedule, given a solar PV profile

    The numerator is the energy actually charged from PV (per-period minimum
    of schedule and solar output, times `time_unit`); the denominator is
    `max_available_solar` evaluated with the same `time_unit`, so both sides
    are expressed in the same units.
    """
    actual_pv_charge = np.sum(np.minimum(schedule, solar_profile)*time_unit)
    max_pv_charge = max_available_solar(solar_profile, time_unit=time_unit)

    return actual_pv_charge/max_pv_charge

def construct_solar_exploit_calculator(solar_profile, charging_datetimes=None, scorer=False):
    """
    Build a function that scores predictions by their solar exploitation percentage.

    Parameters:
        solar_profile: datetime-indexed solar output used as the reference
        charging_datetimes: datetimes the predictions refer to when they carry
            no index of their own; defaults to the charging datetimes of
            `solar_profile`
        scorer: if truthy, the function is wrapped with `make_scorer`

    Returns:
        `calc_solar_exploitation(y, y_pred)`, or a scorer wrapping it.
    """
    if charging_datetimes is None:
        charging_datetimes = extract_charging_datetimes(solar_profile)

    def calc_solar_exploitation(y, y_pred):
        # Use the prediction's own index when it has one, otherwise fall back
        # to the datetimes captured when the calculator was constructed. A
        # separate local name is used so the enclosing variable stays readable
        # from this scope.
        if hasattr(y_pred, 'index'):
            pred_datetimes = extract_charging_datetimes(y_pred)
        else:
            pred_datetimes = charging_datetimes

        s_solar = solar_profile.loc[pred_datetimes]

        assert y_pred.shape[0] == s_solar.shape[0], f'The prediction series must be the same length as the number of charging datetimes in the solar profile, {y_pred.shape[0]} {s_solar.shape[0]}'

        exploitation_pct = 100 * prop_max_solar(y_pred, s_solar)

        return exploitation_pct

    if scorer:
        return make_scorer(calc_solar_exploitation)

    return calc_solar_exploitation

# Cell
def fit_and_save_charging_model(X, y, charge_opt_model_fp, model_class=RandomForestRegressor, **model_params):
    """
    Fit a model on `(X, y)` and write it to `charge_opt_model_fp` with joblib.

    The model is created as `model_class(**model_params)`. Nothing is returned.
    """
    charging_model = model_class(**model_params)
    charging_model.fit(X, y)

    with open(charge_opt_model_fp, 'wb') as model_file:
        joblib.dump(charging_model, model_file)

# Cell
def prepare_test_feature_data(raw_data_dir, intermediate_data_dir, test_start_date=None, test_end_date=None, start_time='08:00', end_time='23:59'):
    """
    Build the feature dataframe for the test period.

    Features are constructed from the combined training datasets. When both
    `test_start_date` and `test_end_date` are given, the rows between them
    are used; otherwise the index of the latest submission template is used.
    The result is then restricted to times of day between `start_time` and
    `end_time`.
    """
    # Loading input data
    df_combined = clean.combine_training_datasets(intermediate_data_dir)
    df_features = construct_df_charge_features(df_combined.interpolate(limit=1))

    # Choosing the datetimes to predict for
    if test_start_date is not None and test_end_date is not None:
        index = df_features[test_start_date:test_end_date].index
    else:
        # Default index (latest submission)
        index = discharge.load_latest_submission_template(raw_data_dir).index

    # Filtering feature data on the chosen datetimes and time-of-day window
    return df_features.loc[index].between_time(start_time, end_time)

# Cell
def optimise_test_charge_profile(raw_data_dir, intermediate_data_dir, charge_opt_model_fp, test_start_date=None, test_end_date=None, start_time='08:00', end_time='23:59'):
    """
    Predict and post-process the charge profile for the test period.

    The saved model is applied to the feature rows that fall on charging
    datetimes; every other timestamp in the feature index is set to zero
    before the profile is post-processed.

    Raises:
        AssertionError: if `charge_is_valid` rejects the resulting profile.
    """
    df_features = prepare_test_feature_data(
        raw_data_dir,
        intermediate_data_dir,
        test_start_date=test_start_date,
        test_end_date=test_end_date,
        start_time=start_time,
        end_time=end_time
    )
    charging_datetimes = extract_charging_datetimes(df_features)
    X_test = df_features.loc[charging_datetimes].values

    model = discharge.load_trained_model(charge_opt_model_fp)
    predicted_rates = model.predict(X_test)

    # Non-charging datetimes are filled with zero charge
    s_raw_profile = (pd.Series(predicted_rates, index=charging_datetimes)
                     .reindex(df_features.index)
                     .fillna(0)
                    )
    s_charge_profile = post_pred_charge_proc_func(s_raw_profile)

    assert charge_is_valid(s_charge_profile), "Charging profile is invalid"

    return s_charge_profile