/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import java.util.Objects;
import java.util.concurrent.ExecutorService;
import jsat.linear.DenseMatrix;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionP;
import jsat.math.FunctionVec;
import jsat.math.optimization.BacktrackingArmijoLineSearch;
import jsat.math.optimization.LineSearch;
import jsat.math.optimization.Optimizer2;

public class BFGS
implements Optimizer2 {
    private LineSearch lineSearch;
    private int maxIterations;
    private boolean inftNormCriterion = true;

    public BFGS() {
        this(250, new BacktrackingArmijoLineSearch());
    }

    public BFGS(int maxIterations, LineSearch lineSearch) {
        this.setMaximumIterations(maxIterations);
        this.setLineSearch(lineSearch);
    }

    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp) {
        this.optimize(tolerance, w, x0, f, fp, fpp, null);
    }

    @Override
    public void optimize(double tolerance, Vec w, Vec x0, Function f, FunctionVec fp, FunctionVec fpp, ExecutorService ex) {
        LineSearch search = this.lineSearch.clone();
        DenseMatrix H = Matrix.eye(x0.length());
        Vec x_prev = x0.clone();
        Vec x_cur = x0.clone();
        double[] f_xVal = new double[1];
        Vec x_grad = x0.clone();
        x_grad.zeroOut();
        Vec x_gradPrev = x_grad.clone();
        Vec p_k = x_grad.clone();
        Vec s_k = x_grad.clone();
        Vec y_k = x_grad.clone();
        f_xVal[0] = ex != null && f instanceof FunctionP ? ((FunctionP)f).f(x_cur, ex) : f.f(x_cur);
        x_grad = ex != null ? fp.f(x_cur, x_grad, ex) : fp.f(x_cur, x_grad);
        int iter = 0;
        while (this.gradConvgHelper(x_grad) > tolerance && iter < this.maxIterations) {
            int i;
            ++iter;
            p_k.zeroOut();
            ((Matrix)H).multiply(x_grad, -1.0, p_k);
            x_cur.copyTo(x_prev);
            x_grad.copyTo(x_gradPrev);
            double alpha_k = search.lineSearch(1.0, x_prev, x_gradPrev, p_k, f, fp, f_xVal[0], x_gradPrev.dot(p_k), x_cur, f_xVal, x_grad, ex);
            if (alpha_k < 1.0E-12 && iter > 5) break;
            if (!search.updatesGrad()) {
                if (ex != null) {
                    fp.f(x_cur, x_grad, ex);
                } else {
                    fp.f(x_cur, x_grad);
                }
            }
            x_cur.copyTo(s_k);
            s_k.mutableSubtract(x_prev);
            x_grad.copyTo(y_k);
            y_k.mutableSubtract(x_gradPrev);
            double skyk = s_k.dot(y_k);
            if (skyk <= 0.0) {
                ((Matrix)H).zeroOut();
                for (i = 0; i < ((Matrix)H).rows(); ++i) {
                    ((Matrix)H).set(i, i, 1.0);
                }
                continue;
            }
            if (iter == 0 && skyk > 1.0E-12) {
                for (i = 0; i < ((Matrix)H).rows(); ++i) {
                    ((Matrix)H).set(i, i, skyk / y_k.dot(y_k));
                }
            }
            Vec Hkyk = H.multiply(y_k);
            Vec ykHk = y_k.multiply(H);
            double b = (1.0 + y_k.dot(Hkyk) / skyk) / skyk;
            Matrix.OuterProductUpdate(H, s_k, ykHk, -1.0 / skyk);
            Matrix.OuterProductUpdate(H, Hkyk, s_k, -1.0 / skyk);
            Matrix.OuterProductUpdate(H, s_k, s_k, b);
        }
        x_cur.copyTo(w);
    }

    public void setInftNormCriterion(boolean inftNormCriterion) {
        this.inftNormCriterion = inftNormCriterion;
    }

    public boolean isInftNormCriterion() {
        return this.inftNormCriterion;
    }

    private double gradConvgHelper(Vec grad) {
        if (!this.inftNormCriterion) {
            return grad.pNorm(2.0);
        }
        double max = 0.0;
        for (IndexValue iv : grad) {
            max = Math.max(max, Math.abs(iv.getValue()));
        }
        return max;
    }

    @Override
    public void setMaximumIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + iterations);
        }
        this.maxIterations = iterations;
    }

    @Override
    public int getMaximumIterations() {
        return this.maxIterations;
    }

    public void setLineSearch(LineSearch lineSearch) {
        this.lineSearch = lineSearch;
    }

    public LineSearch getLineSearch() {
        return this.lineSearch;
    }

    @Override
    public Optimizer2 clone() {
        return new BFGS(this.maxIterations, this.lineSearch.clone());
    }
}

