/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.LUPDecomposition;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.optimization.Optimizer;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

public class IterativelyReweightedLeastSquares
implements Optimizer {
    private static final long serialVersionUID = -6872953184371630318L;
    private DenseMatrix hessian;
    private DenseMatrix coefficentMatrix;
    private DenseVector derivatives;
    private DenseVector errors;
    private DenseVector gradiant;

    @Override
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs) {
        return this.optimize(eps, iterationLimit, f, fd, vars, inputs, outputs, null);
    }

    @Override
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs, ExecutorService threadpool) {
        this.hessian = new DenseMatrix(vars.length(), vars.length());
        this.coefficentMatrix = new DenseMatrix(inputs.size(), vars.length());
        for (int i = 0; i < inputs.size(); ++i) {
            Vec x_i = inputs.get(i);
            this.coefficentMatrix.set(i, 0, 1.0);
            for (int j = 1; j < vars.length(); ++j) {
                this.coefficentMatrix.set(i, j, x_i.get(j - 1));
            }
        }
        this.derivatives = new DenseVector(inputs.size());
        this.errors = new DenseVector(outputs.length());
        this.gradiant = new DenseVector(vars.length());
        double maxChange = Double.MAX_VALUE;
        if (threadpool != null && !(threadpool instanceof FakeExecutor)) {
            while (!Double.isNaN(maxChange = this.iterationStep(f, fd, vars, inputs, outputs, threadpool)) && maxChange > eps && iterationLimit-- > 0) {
            }
        } else {
            while (!Double.isNaN(maxChange = this.iterationStep(f, fd, vars, inputs, outputs)) && maxChange > eps && iterationLimit-- > 0) {
            }
        }
        return vars;
    }

    private double iterationStep(Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs) {
        Vec delta = null;
        for (int i = 0; i < inputs.size(); ++i) {
            Vec x_i = inputs.get(i);
            double y = f.f(x_i);
            double error = y - outputs.get(i);
            this.errors.set(i, error);
            this.derivatives.set(i, fd.f(x_i));
        }
        for (int j = 0; j < this.hessian.rows(); ++j) {
            double gradTmp = 0.0;
            for (int k = 0; k < this.coefficentMatrix.rows(); ++k) {
                double coefficient_kj = this.coefficentMatrix.get(k, j);
                gradTmp += coefficient_kj * this.errors.get(k);
                double multFactor = this.derivatives.get(k) * coefficient_kj;
                for (int i = 0; i < this.hessian.rows(); ++i) {
                    this.hessian.increment(j, i, this.coefficentMatrix.get(k, i) * multFactor);
                }
            }
            this.gradiant.set(j, gradTmp);
        }
        LUPDecomposition lupDecomp = new LUPDecomposition(this.hessian.clone());
        if (Math.abs(lupDecomp.det()) < 1.0E-14) {
            return Double.NaN;
        }
        delta = lupDecomp.solve(this.gradiant);
        vars.mutableSubtract(delta);
        return Math.max(delta.max(), Math.abs(delta.min()));
    }

    private double iterationStep(Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs, ExecutorService threadpool) {
        Vec delta = null;
        for (int i = 0; i < inputs.size(); ++i) {
            Vec x_i = inputs.get(i);
            double y = f.f(x_i);
            double error = y - outputs.get(i);
            this.errors.set(i, error);
            this.derivatives.set(i, fd.f(x_i));
        }
        int overFlow = this.hessian.rows() % SystemInfo.LogicalCores;
        int size = this.hessian.rows() / SystemInfo.LogicalCores;
        int start = 0;
        final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
        for (int t = 0; t < SystemInfo.LogicalCores; ++t) {
            int TO;
            final int START = start;
            start = TO = (overFlow-- > 0 ? 1 : 0) + START + size;
            threadpool.submit(new Runnable(){

                @Override
                public void run() {
                    for (int j = START; j < TO; ++j) {
                        double gradTmp = 0.0;
                        for (int k = 0; k < IterativelyReweightedLeastSquares.this.coefficentMatrix.rows(); ++k) {
                            double coefficient_kj = IterativelyReweightedLeastSquares.this.coefficentMatrix.get(k, j);
                            gradTmp += coefficient_kj * IterativelyReweightedLeastSquares.this.errors.get(k);
                            double multFactor = IterativelyReweightedLeastSquares.this.derivatives.get(k) * coefficient_kj;
                            for (int i = 0; i < IterativelyReweightedLeastSquares.this.hessian.rows(); ++i) {
                                IterativelyReweightedLeastSquares.this.hessian.increment(j, i, IterativelyReweightedLeastSquares.this.coefficentMatrix.get(k, i) * multFactor);
                            }
                        }
                        IterativelyReweightedLeastSquares.this.gradiant.set(j, gradTmp);
                    }
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            ex.printStackTrace();
        }
        LUPDecomposition lupDecomp = new LUPDecomposition(this.hessian.clone(), threadpool);
        if (Math.abs(lupDecomp.det()) < 1.0E-14) {
            return Double.NaN;
        }
        delta = lupDecomp.solve(this.gradiant);
        vars.mutableSubtract(delta);
        return Math.max(delta.max(), Math.abs(delta.min()));
    }
}

