# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/optimizer.adan.ipynb.

# %% ../../nbs/optimizer.adan.ipynb 2
from __future__ import annotations
from typing import Optional, Dict

import numpy as np

from fastai.callback.core import Callback
from fastai.callback.schedule import combined_cos, ParamScheduler
from fastai.learner import Learner
from fastai.optimizer import Optimizer

from .foreach import ForEachOptimizer
from .torchscript import JitOptimizer
from ..imports import *

# %% auto 0
# Names exported by a star-import of this module
__all__ = ['Adan', 'adan', 'AdanLargeBatchLR']

# %% ../../nbs/optimizer.adan.ipynb 4
def debias(beta:float, step:int):
    "Bias-correction term `1 - beta**step` for a moving average with coefficient `beta` after `step` updates"
    decayed = beta**step
    return 1-decayed

# %% ../../nbs/optimizer.adan.ipynb 6
def adan_setup(p:Tensor, step:int=0, grad_avg:Tensor|None=None, diff_avg:Tensor|None=None, 
               sqr_avg:Tensor|None=None, prior_grad:Tensor|None=None, paper_init:bool=False, **kwargs):
    "Creates the Adan state for `p` on the first call (`step == 0`) and increments `step` on every call"
    # once the state exists only the step counter needs to be reported back
    if step != 0:
        return {'step':step+1}

    def _zeros():
        # fresh zero tensor shaped like `p`
        return torch.zeros_like(p, memory_format=torch.preserve_format)

    # `paper_init` seeds the prior gradient with a copy of the current gradient, otherwise with zeros
    first_prior = p.grad.clone() if paper_init else _zeros()
    return {'grad_avg':_zeros(), 'diff_avg':_zeros(), 'sqr_avg':_zeros(), 'prior_grad':first_prior, 'step':step+1}

# %% ../../nbs/optimizer.adan.ipynb 7
def adan_step(p:Tensor, lr:float, eps:float, wd:float, beta1:float, beta2:float, beta3:float, 
              step:int, grad_avg:Tensor, diff_avg:Tensor, sqr_avg:Tensor, prior_grad:Tensor, 
              do_wd:bool=True, **kwargs):
    "Updates the Adan moving averages in place, applies the Adan step with `lr` to `p`, and returns the state tensors"
    grad = p.grad.data

    # g_k - g_{k-1}
    grad_diff = torch.sub(grad, prior_grad)

    # m_k: moving average of the gradient
    grad_avg.mul_(beta1).add_(grad, alpha=1-beta1)

    # v_k: moving average of the gradient difference
    diff_avg.mul_(beta2).add_(grad_diff, alpha=1-beta2)

    # n_k: moving average of the squared, difference-adjusted gradient
    adjusted_grad = torch.add(grad, grad_diff, alpha=beta2)
    sqr_avg.mul_(beta3).addcmul_(adjusted_grad, adjusted_grad, value=1-beta3)

    # debias terms for the three moving averages
    db1, db2, db3 = [debias(beta, step) for beta in (beta1, beta2, beta3)]

    # λ is applied as a divisor; 1 leaves the parameter undecayed
    if wd!=0 and do_wd:
        decay = 1+lr*wd
    else:
        decay = 1

    # η_k: per-element step size lr/(sqrt(n_k/db3) + eps)
    step_size = lr/torch.sqrt(sqr_avg.div(db3)).add(eps)

    # Adan update: subtract the scaled direction, then divide by the decay term
    direction = torch.add(grad_avg.div(db1), diff_avg.div(db2), alpha=beta2)
    p.data.sub_(direction.mul_(step_size)).div_(decay)

    # a copy of the current grad becomes the next step's prior_grad
    return {'grad_avg':grad_avg, 'diff_avg':diff_avg, 'sqr_avg':sqr_avg, 'prior_grad':grad.clone()}

adan_step.defaults = dict(beta1=0.98, beta2=0.92, beta3=0.99)

# %% ../../nbs/optimizer.adan.ipynb 9
@torch.jit.script
def adan_jit_step(p:Tensor, g:Tensor, lr:float, wd:float, beta1:float, beta2:float, beta3:float, eps:float,
                  paper_init:bool, grad_avg:Optional[Tensor]=None, diff_avg:Optional[Tensor]=None, 
                  sqr_avg:Optional[Tensor]=None, prior_grad:Optional[Tensor]=None, do_wd:bool=True, step:int=0, 
                  force_train:Optional[bool]=None, mom:Optional[float]=None, decouple_wd:bool=False):
    """TorchScript Adan step for parameter `p` with gradient `g`.

    State tensors passed as `None` are created here; `paper_init` selects whether the
    first `prior_grad` is a copy of the gradient or zeros. The moving averages are
    computed with out-of-place ops, the new parameter values are written back with
    `p.set_`, and the updated state (including the incremented `step`) is returned
    as a dict. `force_train`, `mom` and `decouple_wd` are accepted but not used in
    the body -- presumably so the caller can pass its whole state/hyperparameter
    dict; confirm against `JitOptimizer`.
    """
    dp = p
    grad = g
    # the incoming `step` counts completed steps, so increment before debiasing
    step += 1

    # lazily create any state that has not been initialized yet
    if grad_avg is None: 
        grad_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if diff_avg is None: 
        diff_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if sqr_avg is None: 
        sqr_avg = torch.zeros_like(dp, memory_format=torch.preserve_format)
    if prior_grad is None:
        if paper_init:
            prior_grad = grad.clone()
        else:
            prior_grad = torch.zeros_like(dp, memory_format=torch.preserve_format)

    # difference between current and previous gradients
    diff_grad = grad.sub(prior_grad)

    # update m_k
    grad_avg = grad_avg.mul(beta1).add(grad, alpha=1-beta1)

    # update v_k
    diff_avg = diff_avg.mul(beta2).add(diff_grad, alpha=1-beta2)
    
    # update n_k
    adjusted_grad = grad.add(diff_grad, alpha=beta2)
    sqr_avg = sqr_avg.mul(beta3).addcmul(adjusted_grad, adjusted_grad, value=1-beta3)

    # calculate debias terms
    db1 = debias(beta1, step)
    db2 = debias(beta2, step)
    db3 = debias(beta3, step)

    # calculate applied λ; both branches assign a float so `wd` keeps a single type
    if wd!=0 and do_wd:
        wd = (1+lr*wd) 
    else:
        wd = 1. 

    # calculate η_k; from here on `lr` is a Tensor of per-element step sizes
    lr = lr/torch.sqrt(sqr_avg.div(db3)).add(eps)

    # perform Adan step: subtract the scaled update, then divide by the decay term
    dp = dp.sub(torch.add(grad_avg.div(db1), diff_avg.div(db2), alpha=beta2).mul(lr)).div(wd)

    # set current grad as next step's prior_grad
    prior_grad = grad.clone()

    # apply results to parameter p
    p.set_(dp)
    # `grad` is never rebound above, so this sets `g` to itself
    g.set_(grad)

    return torch.jit.annotate(Dict[str, Union[Tensor, int]], {'grad_avg':grad_avg, 'diff_avg':diff_avg, 'sqr_avg':sqr_avg, 'prior_grad':prior_grad, 'step':step})

# %% ../../nbs/optimizer.adan.ipynb 11
def adan_foreach_step(p:list[Tensor], grad:list[Tensor], grad_avg:list[Tensor], diff_avg:list[Tensor], sqr_avg:list[Tensor], 
                      prior_grad:list[Tensor], steps:np.ndarray[Any, int], do_wd:np.ndarray[Any, bool], lr:float, 
                      wd:float, beta1:float, beta2:float, beta3:float, eps:float, **kwargs):
    """Performs the Adan step on a list of parameters using `torch._foreach_*` ops.

    All list arguments are parallel: entry `i` of `grad`, `grad_avg`, `diff_avg`,
    `sqr_avg`, `prior_grad`, `steps` and `do_wd` belongs to parameter `p[i]`.
    `p`, `grad_avg`, `diff_avg`, `sqr_avg` and `prior_grad` are modified in place;
    nothing is returned. An empty `p` is a no-op.
    """
    # a group without gradients has nothing to update; return before touching the
    # foreach kernels, which may reject empty tensor lists
    if len(p) == 0:
        return

    # difference between current and previous gradients
    grad_diff = torch._foreach_sub(grad, prior_grad)

    # update m_k
    torch._foreach_mul_(grad_avg, beta1)
    torch._foreach_add_(grad_avg, grad, alpha=1-beta1)

    # update v_k
    torch._foreach_mul_(diff_avg, beta2)
    torch._foreach_add_(diff_avg, grad_diff, alpha=1-beta2)

    # update n_k
    adjusted_grad = torch._foreach_add(grad, grad_diff, alpha=beta2)
    torch._foreach_mul_(sqr_avg, beta3)
    torch._foreach_addcmul_(sqr_avg, adjusted_grad, adjusted_grad, value=1-beta3)

    # per-parameter debias terms (each parameter carries its own step count)
    db1 = (1 - beta1**steps).tolist()
    db2 = (1 - beta2**steps).tolist()
    db3 = (1 - beta3**steps).tolist()

    # calculate η_k = lr/(sqrt(n_k/db3) + eps) as the reciprocal of (sqrt(n_k/db3) + eps)/lr,
    # since foreach_div currently doesn't allow a scalar as the first arg
    denom = torch._foreach_div(sqr_avg, scalars=db3)
    torch._foreach_sqrt_(denom)
    torch._foreach_add_(denom, eps)
    p_lrs = torch._foreach_div(denom, lr)
    torch._foreach_reciprocal_(p_lrs)

    # perform Adan step
    grad_avg_hat = torch._foreach_div(grad_avg, scalars=db1)
    diff_avg_hat = torch._foreach_div(diff_avg, scalars=db2)
    update = torch._foreach_mul(torch._foreach_add(grad_avg_hat, diff_avg_hat, alpha=beta2), p_lrs)
    torch._foreach_sub_(p, update)

    # calculate and apply λ; parameters with `do_wd` False are divided by 1
    if wd != 0:
        decay = np.where(do_wd, 1+lr*wd, 1.)
        torch._foreach_div_(p, scalars=decay.tolist())

    # set current grad as next step's prior_grad, currently no foreach_set method
    for prior, current in zip(prior_grad, grad):
        prior.set_(current.clone())

# %% ../../nbs/optimizer.adan.ipynb 12
class AdanForEachOptimizer(ForEachOptimizer):
    "An `Optimizer` with a modified step for Adan ForEach"
    def __init__(self,
        params:listified[Tensor], # Model parameters
        opt_step:Callable, # `ForEachOptimizer` optimizer step
        paper_init:bool=False, # Initialize first prior_grad to grad following paper or zeros
        **defaults # Optimizer specific hyper parameters
    ):
        super().__init__(params, opt_step, **defaults)
        self.paper_init = paper_init

    def _init_state(self, p:Tensor, state:dict):
        "Creates the Adan moving averages, prior gradient and step counter for `p` in `state`"
        for name in ('grad_avg', 'diff_avg', 'sqr_avg'):
            state[name] = torch.zeros_like(p, memory_format=torch.preserve_format)
        if self.paper_init:
            state['prior_grad'] = p.grad.clone()
        else:
            state['prior_grad'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        state['step'] = 0

    @torch.no_grad()
    def step(self, closure=None):
        "Gathers parameters with gradients and their state per group, then calls `opt_step` once per non-empty group"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        for pg, hyper in zip(self.param_lists, self.hypers):
            pl, gl, grad_avg, diff_avg, sqr_avg, prior_grad, steps, do_wd = [], [], [], [], [], [], [], []

            for p in pg:
                # parameters without a gradient are left untouched
                if getattr(p, 'grad', None) is None: continue
                state = self.state[p]

                # state is created the first time a parameter is seen with a gradient
                if 'step' not in state: self._init_state(p, state)

                state['step'] += 1
                pl.append(p)
                gl.append(p.grad)
                grad_avg.append(state['grad_avg'])
                diff_avg.append(state['diff_avg'])
                sqr_avg.append(state['sqr_avg'])
                prior_grad.append(state['prior_grad'])
                do_wd.append(state.get('do_wd', True))
                steps.append(state['step'])

            # a group where no parameter has a gradient has nothing to update;
            # skip it so the foreach step is never handed empty tensor lists
            if not pl: continue

            self.opt_step(p=pl, grad=gl, grad_avg=grad_avg, diff_avg=diff_avg, sqr_avg=sqr_avg, 
                          prior_grad=prior_grad, steps=np.array(steps, dtype=np.int32), do_wd=np.array(do_wd, dtype=bool), **hyper)

# %% ../../nbs/optimizer.adan.ipynb 15
def Adan(
    params:listified[Tensor], # Model parameters or parameter groups
    lr:float, # Default learning rate
    beta1:float=0.98, # Gradient moving average (β1) coefficient
    beta2:float=0.92, # Gradient difference moving average (β2) coefficient
    beta3:float=0.99, # Gradient squared moving average (β3) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0.02, # True weight decay
    paper_init:bool=False, # Initialize prior gradient with current gradient per paper, or zeroes
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|AdanForEachOptimizer|JitOptimizer:
    "Builds a fastai Adan optimizer: ForEach if `foreach`, else TorchScript if `jit`, else the standard implementation"
    # hyperparameters shared by all three implementations
    hypers = dict(lr=lr, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, wd=wd)

    # `foreach` is checked first, so it takes precedence when both flags are set
    if foreach:
        return AdanForEachOptimizer(params, adan_foreach_step, paper_init=paper_init, **hypers)
    if jit:
        jit_step = partial(adan_jit_step, paper_init=paper_init)
        return JitOptimizer(params, jit_step, **hypers)

    # standard path: state setup callback followed by the step callback
    setup = partial(adan_setup, paper_init=paper_init)
    return Optimizer(params, [setup, adan_step], **hypers)

# %% ../../nbs/optimizer.adan.ipynb 16
def adan(
    beta1:float=0.98, # Gradient moving average (β1) coefficient
    beta2:float=0.92, # Gradient difference moving average (β2) coefficient
    beta3:float=0.99, # Gradient squared moving average (β3) coefficient
    eps:float=1e-8, # Added for numerical stability
    wd:float=0.02, # True weight decay
    paper_init:bool=False, # Initialize prior gradient with current gradient per paper, or zeroes
    foreach:bool=False, # Use fused ForEach implementation
    jit:bool=False # Use fused TorchScript implementation
) -> Optimizer|AdanForEachOptimizer|JitOptimizer:
    "Binds the given hyperparameters and implementation flags to `Adan` with `partialler`, leaving `params` and `lr` to the caller"
    # NOTE(review): the value returned here is the `partialler` wrapper around `Adan`;
    # the return annotation names the optimizer types `Adan` itself returns
    hypers = dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, wd=wd)
    variant = dict(paper_init=paper_init, foreach=foreach, jit=jit)
    return partialler(Adan, **hypers, **variant)

# %% ../../nbs/optimizer.adan.ipynb 19
def AdanLargeBatchLR(bs:int) -> float:
    "Learning rate for `Adan` at batch size `bs`: 6.25e-3 scaled by the square root of `bs/256`"
    # square root rule: a batch size of 256 maps to exactly 6.25e-3
    scale = math.sqrt(bs/256)
    return scale*6.25e-3

# %% ../../nbs/optimizer.adan.ipynb 21
@patch
def fit_adan_cycle(self:Learner, 
    n_epoch:int, # Number of epochs
    lr_max:float|None=None, # Maximum learning rate
    div:Number=25., # Initial learning rate: `lr_max/div`
    div_final:Number=1e5, # Final learning rate: `lr_max/div_final`
    pct_start=0.25, # Finish learning rate warmup and start cosine annealing
    wd:float|None=None, # Weight decay, defaults to `Optimizer` weight decay
    beta1s:tuple[float,float,float]|None=None, # `beta1` schedule values for `combined_cos`, not scheduled if None
    beta3s:tuple[float,float,float]|None=None, # `beta3` schedule values for `combined_cos`, not scheduled if None
    cbs:listified[Callback]|None=None, # Temporary Callbacks to apply during fit
    reset_opt:bool=False # Reset `Optimizer` before fit
):
    "Fit `self.model` for `n_epoch` using the 1cycle policy with `Adan` hyperparams."
    if self.opt is None: self.create_opt()

    # push the peak lr into the optimizer, then read it back per parameter group
    peak_lr = self.lr if lr_max is None else lr_max
    self.opt.set_hyper('lr', peak_lr)
    lr_max = np.array([h['lr'] for h in self.opt.hypers])

    # lr goes lr_max/div -> lr_max -> lr_max/div_final, switching phase at `pct_start`
    scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final)}

    # the betas are only scheduled when values were passed
    for hyper_name, values in (('beta1', beta1s), ('beta3', beta3s)):
        if values is not None:
            scheds[hyper_name] = combined_cos(pct_start, *values)

    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, start_epoch=0)
