/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * AdaDelta adaptive per-coordinate step-size updater (Zeiler, 2012).
 * <p>
 * For every coordinate this keeps an exponentially decaying average of the squared
 * gradient, E[g^2], and of the squared update, E[dx^2]. Each step is
 * {@code dx = -sqrt((E[dx^2] + eps) / (E[g^2] + eps)) * g}, and the weight moves by
 * {@code eta * dx}. The bias term is handled the same way with scalar accumulators.
 * {@link #setup(int)} must be called before the first update.
 */
public class AdaDelta
implements GradientUpdater {
    private static final long serialVersionUID = 5855631993426837618L;
    /** Decay rate of the running averages, always in (0, 1). */
    private double rho;
    /** Running average of squared gradients, E[g^2], per coordinate. */
    private Vec gSqrd;
    /** Running average of squared updates, E[dx^2], per coordinate (despite the name, not a square root). */
    private Vec deltaXSqrt;
    /** Running average of the squared bias gradient. */
    private double biasGSqrd;
    /** Running average of the squared bias update (despite the name, not a square root). */
    private double deltaBiasSqrt;
    /** Smoothing constant added inside both square roots so the ratio is always defined. */
    private double eps = 1.0E-4;

    /**
     * Creates a new AdaDelta updater with {@code rho = 0.95}.
     */
    public AdaDelta() {
        this(0.95);
    }

    /**
     * Creates a new AdaDelta updater.
     *
     * @param rho decay rate of the running averages; must be in (0, 1)
     * @throws IllegalArgumentException if {@code rho} is not in (0, 1) or is NaN
     */
    public AdaDelta(double rho) {
        // Validate through a static helper instead of the overridable setRho,
        // so subclasses never observe a partially constructed object.
        this.rho = checkRho(rho);
    }

    /**
     * Copy constructor. Produces an independent deep copy of the updater's state.
     *
     * @param toCopy the updater to copy
     */
    public AdaDelta(AdaDelta toCopy) {
        this.rho = toCopy.rho;
        this.eps = toCopy.eps;
        if (toCopy.gSqrd != null) {
            this.gSqrd = toCopy.gSqrd.clone();
        }
        if (toCopy.deltaXSqrt != null) {
            this.deltaXSqrt = toCopy.deltaXSqrt.clone();
        }
        this.biasGSqrd = toCopy.biasGSqrd;
        this.deltaBiasSqrt = toCopy.deltaBiasSqrt;
    }

    /**
     * Validates a candidate decay rate.
     *
     * @param rho the candidate value
     * @return {@code rho}, unchanged
     * @throws IllegalArgumentException if {@code rho} is not in (0, 1) or is NaN
     */
    private static double checkRho(double rho) {
        if (rho <= 0.0 || rho >= 1.0 || Double.isNaN(rho)) {
            throw new IllegalArgumentException("Rho must be in (0, 1)");
        }
        return rho;
    }

    /**
     * Sets the decay rate of the running averages.
     *
     * @param rho the decay rate; must be in (0, 1)
     * @throws IllegalArgumentException if {@code rho} is not in (0, 1) or is NaN
     */
    public void setRho(double rho) {
        this.rho = checkRho(rho);
    }

    /**
     * @return the decay rate of the running averages
     */
    public double getRho() {
        return this.rho;
    }

    /**
     * Sets the smoothing constant added inside both square roots of the step size.
     *
     * @param eps the smoothing constant; must be positive and finite
     * @throws IllegalArgumentException if {@code eps} is not positive, is infinite, or is NaN
     */
    public void setEps(double eps) {
        // "!(eps > 0)" also rejects NaN; eps == 0 could make the ratio 0/0.
        if (!(eps > 0.0) || Double.isInfinite(eps)) {
            throw new IllegalArgumentException("eps must be a positive finite value, not " + eps);
        }
        this.eps = eps;
    }

    /**
     * @return the smoothing constant added inside both square roots of the step size
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Updates {@code x} in place using the AdaDelta rule. This is equivalent to calling
     * {@link #update(Vec, Vec, double, double, double)} with a zero bias gradient.
     *
     * @param x the weight vector to update in place
     * @param grad the gradient for {@code x}
     * @param eta multiplier applied to every AdaDelta step
     */
    @Override
    public void update(Vec x, Vec grad, double eta) {
        this.update(x, grad, eta, 0.0, 0.0);
    }

    /**
     * Updates {@code x} in place using the AdaDelta rule and computes the bias step.
     *
     * @param x the weight vector to update in place
     * @param grad the gradient for {@code x}; only its iterated (non-zero) entries are visited
     * @param eta multiplier applied to every AdaDelta step
     * @param bias the current bias value (not read by this implementation)
     * @param biasGrad the gradient for the bias term
     * @return {@code eta * sqrt((E[dx^2] + eps) / (E[g^2] + eps)) * biasGrad} for the bias.
     *         Note that, unlike the step applied to {@code x}, this value is not negated.
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        if (this.gSqrd == null || this.deltaXSqrt == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        // Decay every coordinate of E[g^2] at once; touched coordinates then receive (1-rho)*g^2.
        this.gSqrd.mutableMultiply(this.rho);
        this.biasGSqrd *= this.rho;
        for (IndexValue iv : grad) {
            int indx = iv.getIndex();
            double grad_i = iv.getValue();
            this.gSqrd.increment(indx, grad_i * grad_i * (1.0 - this.rho));
            double gSqrd_i = this.gSqrd.get(indx);
            double deltaX_i = this.deltaXSqrt.get(indx);
            double newDeltaX_i = -Math.sqrt((deltaX_i + this.eps) / (gSqrd_i + this.eps)) * grad_i;
            x.increment(indx, eta * newDeltaX_i);
            // Pre-divide by rho so that the single mutableMultiply(rho) below yields
            // rho*E[dx^2] + (1-rho)*dx^2 for touched coordinates and plain decay for the rest.
            this.deltaXSqrt.increment(indx, (1.0 - this.rho) / this.rho * newDeltaX_i * newDeltaX_i);
        }
        this.deltaXSqrt.mutableMultiply(this.rho);
        this.biasGSqrd += biasGrad * biasGrad * (1.0 - this.rho);
        double newDeltaBias = Math.sqrt((this.deltaBiasSqrt + this.eps) / (this.biasGSqrd + this.eps)) * biasGrad;
        double biasUpdate = eta * newDeltaBias;
        // Same pre-divide-then-decay trick as for deltaXSqrt above.
        this.deltaBiasSqrt += (1.0 - this.rho) / this.rho * newDeltaBias * newDeltaBias;
        this.deltaBiasSqrt *= this.rho;
        return biasUpdate;
    }

    /**
     * @return an independent deep copy of this updater
     */
    @Override
    public AdaDelta clone() {
        return new AdaDelta(this);
    }

    /**
     * Allocates, or resets, the running averages for {@code d} weights plus the bias.
     *
     * @param d the number of weights that will be updated
     */
    @Override
    public void setup(int d) {
        this.gSqrd = new ScaledVector(new DenseVector(d));
        this.deltaXSqrt = new ScaledVector(new DenseVector(d));
        this.biasGSqrd = 0.0;
        this.deltaBiasSqrt = 0.0;
    }
}

