/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.PlatSMO;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.KernelTrick;
import jsat.distributions.kernels.LinearKernel;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.FakeExecutor;
import jsat.utils.PairedReturn;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

public class LSSVM
extends SupportVectorLearner
implements BinaryScoreClassifier,
Regressor,
Parameterized,
WarmRegressor,
WarmClassifier {
    private static final long serialVersionUID = -7569924400631719451L;
    protected double b = 0.0;
    protected double b_low;
    protected double b_up;
    private double C = 1.0;
    private int i_up;
    private int i_low;
    private double[] fcache;
    private double dualObjective;
    private static double epsilon = 1.0E-12;
    private static double tol = 0.001;

    public LSSVM() {
        this(new LinearKernel());
    }

    public LSSVM(KernelTrick kernel) {
        this(kernel, SupportVectorLearner.CacheMode.NONE);
    }

    public LSSVM(KernelTrick kernel, SupportVectorLearner.CacheMode cacheMode) {
        super(kernel, cacheMode);
    }

    public LSSVM(LSSVM toCopy) {
        super(toCopy.getKernel().clone(), toCopy.getCacheMode());
        this.b_low = toCopy.b_low;
        this.b_up = toCopy.b_up;
        this.i_up = toCopy.i_up;
        this.i_low = toCopy.i_low;
        this.C = toCopy.C;
        if (toCopy.alphas != null) {
            this.alphas = Arrays.copyOf(toCopy.alphas, toCopy.alphas.length);
        }
        if (toCopy.fcache != null) {
            this.fcache = Arrays.copyOf(toCopy.fcache, toCopy.fcache.length);
        }
    }

    @Parameter.WarmParameter(prefLowToHigh=true)
    public void setC(double C) {
        if (C <= 0.0 || Double.isNaN(C) || Double.isInfinite(C)) {
            throw new IllegalArgumentException("C must be in (0, Infty), not " + C);
        }
        this.C = C;
    }

    public double getC() {
        return this.C;
    }

    private boolean takeStep(int i1, int i2, ExecutorService ex, int P) throws InterruptedException, ExecutionException {
        double a1;
        double k22;
        double alph1 = this.alphas[i1];
        double alph2 = this.alphas[i2];
        double F1 = this.fcache[i1];
        double F2 = this.fcache[i2];
        double gamma = alph1 + alph2;
        double k11 = this.kEval(i1, i1);
        double k12 = this.kEval(i2, i1);
        double eta = 2.0 * k12 - k11 - (k22 = this.kEval(i2, i2));
        double a2 = alph2 - (F1 - F2) / eta;
        if (Math.abs(a2 - alph2) < epsilon * (a2 + alph2 + epsilon)) {
            return false;
        }
        this.alphas[i1] = a1 = gamma - a2;
        this.alphas[i2] = a2;
        double t = (F1 - F2) / eta;
        this.dualObjective -= eta / 2.0 * t * t;
        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        ArrayList<Future<PairedReturn<Integer, Integer>>> futures = new ArrayList<Future<PairedReturn<Integer, Integer>>>(P);
        for (int id = 0; id < P; ++id) {
            int n = ParallelUtils.getStartBlock(this.fcache.length, id, P);
            int to = ParallelUtils.getEndBlock(this.fcache.length, id, P);
            futures.add(ex.submit(new TakeStepLoop(n, to, i1, i2, alph1, alph2)));
        }
        for (Future future : futures) {
            PairedReturn pr = (PairedReturn)future.get();
            int i_up_cand = (Integer)pr.getFirstItem();
            int i_low_cand = (Integer)pr.getSecondItem();
            if (this.fcache[i_up_cand] > this.b_up) {
                this.b_up = this.fcache[i_up_cand];
                this.i_up = i_up_cand;
            }
            if (!(this.fcache[i_low_cand] < this.b_low)) continue;
            this.b_low = this.fcache[i_low_cand];
            this.i_low = i_low_cand;
        }
        return true;
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return true;
    }

    private double computeDualityGap(boolean fast, ExecutorService ex, int P) throws InterruptedException, ExecutionException {
        double gap = 0.0;
        if (fast) {
            this.b = (this.b_up + this.b_low) / 2.0;
        } else {
            this.b = 0.0;
            ArrayList<Future<Double>> bParts = new ArrayList<Future<Double>>(P);
            for (int id = 0; id < P; ++id) {
                bParts.add(ex.submit(new BiasGapCallable(ParallelUtils.getStartBlock(this.alphas.length, id, P), ParallelUtils.getEndBlock(this.alphas.length, id, P))));
            }
            for (Future future : bParts) {
                this.b += ((Double)future.get()).doubleValue();
            }
            this.b /= (double)this.alphas.length;
        }
        ArrayList<Future<Double>> gapParts = new ArrayList<Future<Double>>(P);
        for (int id = 0; id < P; ++id) {
            gapParts.add(ex.submit(new DualityGapCallable(ParallelUtils.getStartBlock(this.alphas.length, id, P), ParallelUtils.getEndBlock(this.alphas.length, id, P))));
        }
        for (Future future : gapParts) {
            gap += ((Double)future.get()).doubleValue();
        }
        return gap;
    }

    private void initializeVariables(double[] targets, LSSVM warmSolution, DataSet data) {
        this.alphas = new double[targets.length];
        this.fcache = new double[targets.length];
        this.dualObjective = 0.0;
        if (warmSolution != null) {
            if (warmSolution.alphas.length != this.alphas.length) {
                throw new FailedToFitException("Warm LS-SVM solution could not have been trained on the sama data, different number of alpha values present");
            }
            double C_ratio = this.C / warmSolution.C;
            for (int i = 0; i < targets.length; ++i) {
                this.alphas[i] = warmSolution.alphas[i];
                this.fcache[i] = warmSolution.fcache[i] - (C_ratio - 1.0) * warmSolution.alphas[i] / this.C;
                this.dualObjective += this.alphas[i] * (targets[i] - this.fcache[i]);
            }
            this.dualObjective /= 2.0;
        } else {
            for (int i = 0; i < targets.length; ++i) {
                this.fcache[i] = -targets[i];
            }
        }
        this.b_up = Double.NEGATIVE_INFINITY;
        this.b_low = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.fcache.length; ++i) {
            double Fi = this.fcache[i];
            if (Fi > this.b_up) {
                this.b_up = Fi;
                this.i_up = i;
            }
            if (!(Fi < this.b_low)) continue;
            this.b_low = Fi;
            this.i_low = i;
        }
        this.setCacheMode(this.getCacheMode());
    }

    @Override
    public double getScore(DataPoint dp) {
        return this.regress(dp);
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        if (this.regress(data) > 0.0) {
            cr.setProb(1, 1.0);
        } else {
            cr.setProb(0, 1.0);
        }
        return cr;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet, null, threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, (ExecutorService)null);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution, ExecutorService threadPool) {
        if (warmSolution != null && !(warmSolution instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + warmSolution.getClass());
        }
        double[] targets = dataSet.getTargetValues().arrayCopy();
        this.mainLoop(dataSet, (LSSVM)warmSolution, targets, threadPool);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        this.train(dataSet, warmSolution, null);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution, ExecutorService threadPool) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("LS-SVM only supports binary classification problems");
        }
        if (warmSolution != null && !(warmSolution instanceof LSSVM)) {
            throw new FailedToFitException("Warm solution must be an implementation of LS-SVM, not " + warmSolution.getClass());
        }
        double[] targets = new double[dataSet.getSampleSize()];
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            targets[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }
        this.mainLoop(dataSet, (LSSVM)warmSolution, targets, threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution) {
        this.trainC(dataSet, warmSolution, null);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public double regress(DataPoint data) {
        return this.kEvalSum(data.getNumericalValues()) - this.b;
    }

    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.train(dataSet, null, threadPool);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, (ExecutorService)null);
    }

    @Override
    public LSSVM clone() {
        return new LSSVM(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    private void mainLoop(DataSet dataSet, LSSVM warmSolution, double[] targets, ExecutorService ex) {
        int P;
        if (ex == null || ex instanceof FakeExecutor) {
            ex = new FakeExecutor();
            P = 1;
        } else {
            P = SystemInfo.LogicalCores;
        }
        try {
            this.vecs = dataSet.getDataVectors();
            this.initializeVariables(targets, warmSolution, dataSet);
            boolean change = true;
            double dualityGap = this.computeDualityGap(true, ex, P);
            int iter = 0;
            while (dualityGap > tol * this.dualObjective && change) {
                change = this.takeStep(this.i_up, this.i_low, ex, P);
                dualityGap = this.computeDualityGap(true, ex, P);
                ++iter;
            }
            this.setCacheMode(null);
            this.setAlphas(this.alphas);
        }
        catch (InterruptedException interruptedException) {
            throw new FailedToFitException(interruptedException);
        }
        catch (ExecutionException executionException) {
            throw new FailedToFitException(executionException);
        }
    }

    public static Distribution guessC(DataSet d) {
        return PlatSMO.guessC(d);
    }

    private class DualityGapCallable
    implements Callable<Double> {
        int from;
        int to;

        public DualityGapCallable(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public Double call() throws Exception {
            double gap = 0.0;
            for (int i = this.from; i < this.to; ++i) {
                double x_i = LSSVM.this.b + LSSVM.this.alphas[i] / LSSVM.this.C - LSSVM.this.fcache[i];
                gap += LSSVM.this.alphas[i] * (LSSVM.this.fcache[i] - 0.5 * LSSVM.this.alphas[i] / LSSVM.this.C) + LSSVM.this.C * x_i * x_i / 2.0;
            }
            return gap;
        }
    }

    private class BiasGapCallable
    implements Callable<Double> {
        int from;
        int to;

        public BiasGapCallable(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public Double call() throws Exception {
            double B = 0.0;
            for (int i = this.from; i < this.to; ++i) {
                B += LSSVM.this.fcache[i] - LSSVM.this.alphas[i] / LSSVM.this.C;
            }
            return B;
        }
    }

    private class TakeStepLoop
    implements Callable<PairedReturn<Integer, Integer>> {
        int from;
        int to;
        int i1;
        int i2;
        double alph1;
        double alph2;
        int i_low_p;
        int i_up_p;

        public TakeStepLoop(int from, int to, int i1, int i2, double alph1, double alph2) {
            this.from = from;
            this.to = to;
            this.i1 = i1;
            this.i2 = i2;
            this.alph1 = alph1;
            this.alph2 = alph2;
        }

        @Override
        public PairedReturn<Integer, Integer> call() throws Exception {
            double a1 = LSSVM.this.alphas[this.i1];
            double a2 = LSSVM.this.alphas[this.i2];
            double b_up_p = Double.NEGATIVE_INFINITY;
            double b_low_p = Double.POSITIVE_INFINITY;
            for (int i = this.from; i < this.to; ++i) {
                double k_i1 = LSSVM.this.kEval(this.i1, i);
                double k_i2 = LSSVM.this.kEval(this.i2, i);
                double[] dArray = LSSVM.this.fcache;
                int n = i;
                double d = dArray[n] = dArray[n] + ((a1 - this.alph1) * k_i1 + (a2 - this.alph2) * k_i2);
                double Fi = d;
                if (Fi > b_up_p) {
                    b_up_p = Fi;
                    this.i_up_p = i;
                }
                if (!(Fi < b_low_p)) continue;
                b_low_p = Fi;
                this.i_low_p = i;
            }
            return new PairedReturn<Integer, Integer>(this.i_up_p, this.i_low_p);
        }
    }
}

