/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.PlatSMO;
import jsat.distributions.Distribution;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/**
 * Linear SVM classifier and linear SVR regressor trained by dual coordinate
 * descent with shrinking.
 * <p>
 * When {@code useL1} is true, every dual variable is box-constrained by
 * {@code C * weight} of its data point and no diagonal term is added. Otherwise
 * the dual variables are unbounded and a diagonal term of
 * {@code 1 / (2 * C * weight)} is added (the squared-loss variant). When
 * {@code useBias} is true, the bias is learned by treating it as an extra
 * constant feature of value 1.
 */
public class DCDs implements BinaryScoreClassifier, Regressor, Parameterized,
        SingleWeightVectorModel, WarmClassifier, WarmRegressor {
    private static final long serialVersionUID = -1686294187234524696L;
    private int maxIterations;
    private double tolerance;
    /** Training vectors; only non-null while a training call is running. */
    private Vec[] vecs;
    /** Dual variables from the most recent training run, kept for warm starts. */
    private double[] alpha;
    /** Training labels (+/-1) or regression targets; only non-null while training. */
    private double[] y;
    private double bias;
    private Vec w;
    private double C;
    private boolean useL1;
    /** Half-width of the insensitive region used when training a regressor. */
    private double eps = 0.001;
    private boolean useBias = true;
    private final List<Parameter> params = Collections.unmodifiableList(Parameter.getParamsFromMethods(this));
    private final Map<String, Parameter> paramMap = Parameter.toParameterMap(this.params);

    /**
     * Creates a solver with 10000 maximum iterations using the squared-loss
     * (non-L1) formulation.
     */
    public DCDs() {
        this(10000, false);
    }

    /**
     * Creates a solver with a tolerance of 0.001 and {@code C = 1}.
     *
     * @param maxIterations maximum number of passes over the data
     * @param useL1 {@code true} for the box-constrained (L1) formulation
     */
    public DCDs(int maxIterations, boolean useL1) {
        this(maxIterations, 0.001, 1.0, useL1);
    }

    /**
     * Creates a fully specified solver.
     *
     * @param maxIterations maximum number of passes over the data
     * @param tolerance convergence tolerance
     * @param C positive penalty parameter
     * @param useL1 {@code true} for the box-constrained (L1) formulation
     */
    public DCDs(int maxIterations, double tolerance, double C, boolean useL1) {
        this.setMaxIterations(maxIterations);
        this.setTolerance(tolerance);
        this.setC(C);
        this.setUseL1(useL1);
    }

    /**
     * Sets the penalty parameter.
     *
     * @param C a finite, strictly positive value
     * @throws ArithmeticException if {@code C} is NaN, infinite, or not positive
     */
    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0) {
            throw new ArithmeticException("Penalty parameter must be a positive value, not " + C);
        }
        this.C = C;
    }

    /** @return the penalty parameter */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the half-width of the insensitive region used for regression.
     *
     * @param eps a finite, non-negative value
     * @throws IllegalArgumentException if {@code eps} is NaN, negative, or infinite
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps) || eps < 0.0 || Double.isInfinite(eps)) {
            throw new IllegalArgumentException("eps must be non-negative, not " + eps);
        }
        this.eps = eps;
    }

    /** @return the half-width of the insensitive region used for regression */
    public double getEps() {
        return this.eps;
    }

    /** @param tolerance the convergence tolerance */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /** @return the convergence tolerance */
    public double getTolerance() {
        return this.tolerance;
    }

    /** @param useL1 {@code true} for the box-constrained (L1) formulation */
    public void setUseL1(boolean useL1) {
        this.useL1 = useL1;
    }

    /** @return {@code true} if the box-constrained (L1) formulation is used */
    public boolean isUseL1() {
        return this.useL1;
    }

    /**
     * Sets the maximum number of passes over the data.
     *
     * @param maxIterations a positive value
     * @throws IllegalArgumentException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be positive, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of passes over the data */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @param useBias whether to learn a bias term */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Assigns probability 1 to class 0 when the score is negative, otherwise to
     * class 1.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        requireTrained();
        CategoricalResults cr = new CategoricalResults(2);
        if (this.getScore(data) < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Returns {@code w . x + bias}.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        requireTrained();
        return this.w.dot(dp.getNumericalValues()) + this.bias;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, (Classifier) null);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution, ExecutorService threadPool) {
        this.trainC(dataSet, warmSolution);
    }

    /**
     * Trains a binary linear SVM by dual coordinate descent with shrinking.
     *
     * @param dataSet a data set with exactly two classes
     * @param warmSolution a {@code DCDs} previously trained on the same data, or null
     * @throws FailedToFitException if the data set is not binary or is empty, or
     *         if the warm solution is unusable
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SVM only supports binary classificaiton problems");
        }
        final int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0;
        try {
            double[] Qhs = new double[n];
            double[] U = new double[n];
            double[] D = new double[n];
            for (int i = 0; i < n; ++i) {
                DataPoint dp = dataSet.getDataPoint(i);
                this.vecs[i] = dp.getNumericalValues();
                this.y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
                U[i] = this.getU(dp.getWeight());
                D[i] = this.getD(dp.getWeight());
                Qhs[i] = this.vecs[i].dot(this.vecs[i]) + D[i];
                if (this.useBias) {
                    // the bias behaves like an extra feature of constant value 1
                    Qhs[i] += 1.0;
                }
            }
            this.w = new DenseVector(this.vecs[0].length());
            if (warmSolution != null) {
                applyWarmStart(warmSolution, U, this.y);
            }

            IntList A = new IntList(n);
            ListUtils.addRange(A, 0, n, 1);

            // Shrinking thresholds taken from the PREVIOUS pass; infinite values
            // mean "do not shrink" (first pass, or right after reactivation).
            double prevMax = Double.POSITIVE_INFINITY;
            double prevMin = Double.NEGATIVE_INFINITY;
            XORWOW rand = new XORWOW();
            for (int t = 0; t < this.maxIterations; ++t) {
                Collections.shuffle(A, rand);
                // max / min projected gradient seen during THIS pass
                double M = Double.NEGATIVE_INFINITY;
                double m = Double.POSITIVE_INFINITY;
                Iterator<Integer> iter = A.iterator();
                while (iter.hasNext()) {
                    int i = iter.next();
                    double G = this.y[i] * (this.w.dot(this.vecs[i]) + this.bias) - 1.0 + D[i] * this.alpha[i];
                    double PG = 0.0;
                    if (this.alpha[i] == 0.0) {
                        if (G > prevMax) {
                            // at the lower bound with a strongly positive gradient: shrink
                            iter.remove();
                            continue;
                        }
                        if (G < 0.0) {
                            PG = G;
                        }
                    } else if (this.alpha[i] == U[i]) {
                        if (G < prevMin) {
                            // at the upper bound with a strongly negative gradient: shrink
                            iter.remove();
                            continue;
                        }
                        if (G > 0.0) {
                            PG = G;
                        }
                    } else {
                        PG = G;
                    }
                    M = Math.max(M, PG);
                    m = Math.min(m, PG);
                    if (PG == 0.0) {
                        continue;
                    }
                    double alphaOld = this.alpha[i];
                    this.alpha[i] = Math.min(Math.max(alphaOld - G / Qhs[i], 0.0), U[i]);
                    double scale = (this.alpha[i] - alphaOld) * this.y[i];
                    this.w.mutableAdd(scale, this.vecs[i]);
                    if (this.useBias) {
                        this.bias += scale;
                    }
                }
                if (M - m < this.tolerance) {
                    if (A.size() == n) {
                        break;
                    }
                    // converged on the shrunk problem: re-check on all points
                    A.clear();
                    ListUtils.addRange(A, 0, n, 1);
                    prevMax = Double.POSITIVE_INFINITY;
                    prevMin = Double.NEGATIVE_INFINITY;
                    continue;
                }
                prevMax = M <= 0.0 ? Double.POSITIVE_INFINITY : M;
                prevMin = m >= 0.0 ? Double.NEGATIVE_INFINITY : m;
            }
        } finally {
            this.vecs = null;
            this.y = null;
        }
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return true;
    }

    @Override
    public DCDs clone() {
        DCDs clone = new DCDs(this.maxIterations, this.tolerance, this.C, this.useL1);
        clone.bias = this.bias;
        clone.useBias = this.useBias;
        clone.eps = this.eps;
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        if (this.alpha != null) {
            clone.alpha = Arrays.copyOf(this.alpha, this.alpha.length);
        }
        return clone;
    }

    @Override
    public List<Parameter> getParameters() {
        return this.params;
    }

    @Override
    public Parameter getParameter(String paramName) {
        return this.paramMap.get(paramName);
    }

    /**
     * Returns {@code w . x + bias}.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        requireTrained();
        return this.w.dot(data.getNumericalValues()) + this.bias;
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution, ExecutorService threadPool) {
        this.train(dataSet, warmSolution);
    }

    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, (Regressor) null);
    }

    /**
     * Trains a linear support vector regressor by dual coordinate descent with
     * shrinking.
     *
     * @param dataSet a non-empty regression data set
     * @param warmSolution a {@code DCDs} previously trained on the same data, or null
     * @throws FailedToFitException if the data set is empty or the warm solution
     *         is unusable
     */
    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        final int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("Can not train on an empty data set");
        }
        this.vecs = new Vec[n];
        this.alpha = new double[n];
        this.y = new double[n];
        this.bias = 0.0;
        try {
            double[] Qhs = new double[n];
            double[] U = new double[n];
            double[] lambda = new double[n];
            // total optimality violation at beta = 0, used as the relative stopping scale
            double v_0 = 0.0;
            for (int i = 0; i < n; ++i) {
                DataPoint dp = dataSet.getDataPoint(i);
                this.vecs[i] = dp.getNumericalValues();
                this.y[i] = dataSet.getTargetValue(i);
                U[i] = this.getU(dp.getWeight());
                lambda[i] = this.getD(dp.getWeight());
                Qhs[i] = this.vecs[i].dot(this.vecs[i]) + lambda[i];
                if (this.useBias) {
                    Qhs[i] += 1.0;
                }
                v_0 += Math.abs(DCDs.eq24(0.0, -this.y[i] - this.eps, -this.y[i] + this.eps, U[i]));
            }
            this.w = new DenseVector(this.vecs[0].length());
            if (warmSolution != null) {
                applyWarmStart(warmSolution, U, null);
            }

            IntList activeSet = new IntList(n);
            ListUtils.addRange(activeSet, 0, n, 1);
            XORWOW rand = new XORWOW();
            // largest violation of the previous pass; +infinity disables shrinking
            double M = Double.POSITIVE_INFINITY;
            for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
                double maxVk = Double.NEGATIVE_INFINITY;
                double vKSum = 0.0;
                Collections.shuffle(activeSet, rand);
                Iterator<Integer> iter = activeSet.iterator();
                while (iter.hasNext()) {
                    int i = iter.next();
                    double beta_i = this.alpha[i];
                    Vec x_i = this.vecs[i];
                    double wDotX = this.w.dot(x_i) + this.bias;
                    double g = -this.y[i] + wDotX + lambda[i] * beta_i;
                    double gP = g + this.eps;
                    double gN = g - this.eps;
                    double v_i = DCDs.eq24(beta_i, gN, gP, U[i]);
                    maxVk = Math.max(maxVk, v_i);
                    vKSum += Math.abs(v_i);

                    boolean shrink = (beta_i == 0.0 && gN < -M && -M < 0.0 && M < gP)
                            || (beta_i == U[i] && gP < -M)
                            || (beta_i == -U[i] && gN > M);
                    if (shrink) {
                        iter.remove();
                    }

                    // Newton step on the one-variable sub-problem
                    double Q_ii = Qhs[i];
                    double d;
                    if (gP < Q_ii * beta_i) {
                        d = -gP / Q_ii;
                    } else if (gN > Q_ii * beta_i) {
                        d = -gN / Q_ii;
                    } else {
                        d = -beta_i;
                    }
                    if (Math.abs(d) < 1.0E-14) {
                        continue;
                    }
                    double s = Math.max(-U[i], Math.min(U[i], beta_i + d));
                    double delta = s - beta_i;
                    this.w.mutableAdd(delta, x_i);
                    if (this.useBias) {
                        this.bias += delta;
                    }
                    this.alpha[i] = s;
                }
                // Written as a product so that v_0 == 0 (zero already optimal)
                // can still terminate instead of producing NaN.
                if (vKSum <= this.tolerance * v_0) {
                    if (activeSet.size() == n) {
                        break;
                    }
                    activeSet.clear();
                    ListUtils.addRange(activeSet, 0, n, 1);
                    M = Double.POSITIVE_INFINITY;
                    continue;
                }
                M = maxVk;
            }
        } finally {
            this.y = null;
            this.vecs = null;
        }
    }

    /**
     * Throws if no model has been trained yet.
     *
     * @throws UntrainedModelException if {@code w} is null
     */
    private void requireTrained() {
        if (this.w == null) {
            throw new UntrainedModelException("The model has not been trained");
        }
    }

    /**
     * Initializes {@code alpha}, {@code w}, and {@code bias} from a previously
     * trained model.
     * <p>
     * The other model's duals are rescaled by {@code this.C / other.C} and
     * clipped to the current box constraint. {@code w} and {@code bias} are then
     * rebuilt from those duals. This keeps the solver's invariant
     * {@code w = sum_i alpha_i * c_i * x_i} exact, where {@code c_i} is the label
     * for classification and 1 for regression. Requires {@code vecs}
     * populated, {@code w} freshly zeroed, and {@code bias == 0}.
     *
     * @param warmSolution the candidate warm-start model
     * @param U per-point upper bounds on the duals
     * @param labels +/-1 labels for classification (duals in [0, U]), or null
     *        for regression (duals in [-U, U])
     * @throws FailedToFitException if the warm solution is not a trained DCDs of
     *         matching size
     */
    private void applyWarmStart(Object warmSolution, double[] U, double[] labels) {
        if (!(warmSolution instanceof DCDs)) {
            throw new FailedToFitException("Warm solution can not be used for warm start");
        }
        DCDs other = (DCDs) warmSolution;
        if (other.alpha == null) {
            throw new FailedToFitException("Warm solution has not been trained");
        }
        if (other.alpha.length != this.alpha.length) {
            throw new FailedToFitException("Warm solution could not have been trained on the same data set");
        }
        double C_mul = this.C / other.C;
        for (int i = 0; i < this.alpha.length; i++) {
            double lower = labels == null ? -U[i] : 0.0;
            double a = Math.min(Math.max(other.alpha[i] * C_mul, lower), U[i]);
            this.alpha[i] = a;
            double coef = labels == null ? a : a * labels[i];
            if (coef == 0.0) {
                continue;
            }
            this.w.mutableAdd(coef, this.vecs[i]);
            if (this.useBias) {
                this.bias += coef;
            }
        }
    }

    /** Upper bound on a dual variable: {@code C * weight} for L1, unbounded otherwise. */
    private double getU(double w) {
        if (this.useL1) {
            return this.C * w;
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Diagonal term added to Q_ii: 0 for L1, {@code 1 / (2 * C * weight)} otherwise. */
    private double getD(double w) {
        if (this.useL1) {
            return 0.0;
        }
        return 1.0 / (2.0 * this.C * w);
    }

    /**
     * Measures how much a single regression dual variable violates the
     * optimality conditions; the result is always non-negative.
     *
     * @param beta_i current dual value
     * @param gN gradient minus eps
     * @param gP gradient plus eps
     * @param U upper bound on |beta_i|
     * @return the violation, 0 when the conditions hold
     */
    protected static double eq24(double beta_i, double gN, double gP, double U) {
        double vi = 0.0;
        if (beta_i == 0.0) {
            if (gN >= 0.0) {
                vi = gN;
            } else if (gP <= 0.0) {
                vi = -gP;
            }
        } else if (beta_i < 0.0) {
            if (beta_i > -U || beta_i == -U && gN <= 0.0) {
                vi = Math.abs(gN);
            }
        } else if (beta_i < U || beta_i == U && gP >= 0.0) {
            vi = Math.abs(gP);
        }
        return vi;
    }

    /** Delegates to {@link PlatSMO#guessC(DataSet)}. */
    public static Distribution guessC(DataSet d) {
        return PlatSMO.guessC(d);
    }
}

