/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IndexTable;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/**
 * Binary kernel classifier trained with a stochastic batch-perceptron style
 * procedure (presumably the kernelized Stochastic Batch Perceptron of Cotter,
 * Shalev-Shwartz and Srebro, 2012 - TODO confirm against project docs).
 * <p>
 * Training keeps one non-negative weight {@code alpha_i} per training point together
 * with the responses {@code C_i = y_i * <w, phi(x_i)>} of the current model, where
 * {@code w = sum_i alpha_i y_i phi(x_i)} and {@code y_i} is in {-1, +1}. Each iteration:
 * <ol>
 * <li>finds the threshold {@code gamma} solving {@code sum_j max(0, gamma - C_j) = n * nu};</li>
 * <li>samples a point whose response is at most {@code gamma};</li>
 * <li>adds {@code eta_0 / sqrt(t)} to that point's weight and updates all responses;</li>
 * <li>rescales the model back onto the unit ball if its squared norm exceeds 1.</li>
 * </ol>
 * The weights and responses of the iterations after the burn-in are averaged. The
 * averaged weights are divided by the threshold of the averaged responses, and only
 * non-zero weights are kept as support vectors.
 * <p>
 * Scoring delegates to {@code kEvalSum}, presumably {@code sum_i alpha_i K(x_i, x)}
 * over the support vectors. A score below zero predicts class 0, otherwise class 1.
 */
public class SBP extends SupportVectorLearner implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = 6112916782260792833L;
    /** The nu parameter, in (0, 1); the threshold equation uses {@code n * nu}. */
    private double nu = 0.1;
    /** Number of training iterations; at least 1. */
    private int iterations;
    /** Fraction of the iterations, in [0, 1), excluded from the final average. */
    private double burnIn = 0.2;
    /** Sort cache reused by {@link #findGamma}; only non-null while training. */
    private IndexTable it;

    /**
     * Creates a new SBP classifier.
     *
     * @param kernel the kernel to use
     * @param cacheMode the kernel cache mode to use while training
     * @param iterations the number of training iterations; must be at least 1
     * @param v the nu parameter; must be in (0, 1)
     * @throws IllegalArgumentException if {@code iterations} or {@code v} is out of range
     */
    public SBP(KernelTrick kernel, SupportVectorLearner.CacheMode cacheMode, int iterations, double v) {
        super(kernel, cacheMode);
        this.setIterations(iterations);
        this.setNu(v);
    }

    /**
     * Copy constructor.
     * <p>
     * Copies every hyper-parameter, including the burn-in fraction. If {@code other}
     * has been trained, it also copies the support vectors and weights so the copy can
     * classify.
     *
     * @param other the classifier to copy
     */
    protected SBP(SBP other) {
        this(other.getKernel().clone(), other.getCacheMode(), other.iterations, other.nu);
        this.burnIn = other.burnIn;
        if (other.vecs != null) {
            // Shallow copy of the list: the vector objects themselves are shared.
            this.vecs = new ArrayList<>(other.vecs);
        }
        if (other.alphas != null) {
            this.alphas = Arrays.copyOf(other.alphas, other.alphas.length);
            if (this.vecs != null) {
                // Finish the same way trainC does once vecs and alphas are in place.
                this.setAlphas(this.alphas);
            }
        }
    }

    @Override
    public SBP clone() {
        return new SBP(this);
    }

    /**
     * Sets the number of training iterations.
     *
     * @param iterations the number of iterations; must be at least 1
     * @throws IllegalArgumentException if {@code iterations} is less than 1
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("iterations must be positive, not " + iterations);
        }
        this.iterations = iterations;
    }

    /**
     * Returns the number of training iterations.
     *
     * @return the number of training iterations
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the nu parameter.
     *
     * @param nu a value in the open range (0, 1)
     * @throws IllegalArgumentException if {@code nu} is NaN or outside (0, 1)
     */
    public void setNu(double nu) {
        if (Double.isNaN(nu) || nu <= 0.0 || nu >= 1.0) {
            throw new IllegalArgumentException("nu must be in the range (0, 1)");
        }
        this.nu = nu;
    }

    /**
     * Returns the nu parameter.
     *
     * @return the nu parameter
     */
    public double getNu() {
        return this.nu;
    }

    /**
     * Sets the fraction of iterations excluded from the final averaged solution.
     *
     * @param burnIn a value in [0, 1)
     * @throws IllegalArgumentException if {@code burnIn} is NaN or outside [0, 1)
     */
    public void setBurnIn(double burnIn) {
        if (Double.isNaN(burnIn) || burnIn < 0.0 || burnIn >= 1.0) {
            throw new IllegalArgumentException("BurnInFraction must be in [0, 1), not " + burnIn);
        }
        this.burnIn = burnIn;
    }

    /**
     * Returns the burn-in fraction.
     *
     * @return the fraction of iterations excluded from the final average
     */
    public double getBurnIn() {
        return this.burnIn;
    }

    /**
     * Classifies a point with a hard decision.
     *
     * @param data the point to classify
     * @return class 0 with probability 1 if the score is negative, otherwise class 1
     * @throws UntrainedModelException if the classifier has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        CategoricalResults cr = new CategoricalResults(2);
        double sum = this.getScore(data);
        if (sum < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    /**
     * Returns the raw decision score of a point.
     *
     * @param dp the point to score
     * @return the kernel-weighted sum over the support vectors
     * @throws UntrainedModelException if the classifier has not been trained
     */
    @Override
    public double getScore(DataPoint dp) {
        if (this.vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        return this.kEvalSum(dp.getNumericalValues());
    }

    /**
     * Trains the classifier. Training is sequential, so {@code threadPool} is ignored.
     *
     * @param dataSet the binary training data
     * @param threadPool ignored
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.trainC(dataSet);
    }

    /**
     * Trains the classifier on a binary data set.
     *
     * @param dataSet the training data; must have exactly two classes and at least one point
     * @throws FailedToFitException if the data set is not binary or is empty
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SBP supports only binary classification");
        }
        final int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("SBP requires at least one training point");
        }
        // Iterations 1..burnInIters are skipped when averaging.
        // Capping at iterations-1 guarantees at least one averaged iteration.
        final int burnInIters = (int) Math.min(this.burnIn * this.iterations, this.iterations - 1);
        final int averagedIters = this.iterations - burnInIters;
        final double d = n * this.nu;

        // Discard any sort cache left behind by an earlier run that aborted.
        this.it = null;
        double[] C = new double[n];
        double[] CSum = new double[n];
        this.alphas = new double[n];
        double[] alphasSum = new double[n];
        double[] y = new double[n];
        this.vecs = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            // Map class labels {0, 1} to {-1, +1}.
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            this.vecs.add(dataSet.getDataPoint(i).getNumericalValues());
        }
        // Re-apply the cache mode now that vecs holds the training points.
        this.setCacheMode(this.getCacheMode());
        XORWOW rand = new XORWOW();

        double maxKii = 0.0;
        for (int i = 0; i < n; ++i) {
            maxKii = Math.max(maxKii, this.kEval(i, i));
        }
        double eta_0 = 1.0 / Math.sqrt(maxKii);

        // Squared norm of the current model w.
        double rSqrd = 0.0;
        for (int t = 1; t <= this.iterations; ++t) {
            double eta = eta_0 / Math.sqrt(t);
            double gamma = this.findGamma(C, d);
            int i = this.sampleC(rand, n, C, gamma);
            this.alphas[i] += eta;
            rSqrd = this.updateLoop(rSqrd, eta, C, i, y, n);
            rSqrd = this.projectionStep(rSqrd, n, C);
            if (t <= burnInIters) {
                continue;
            }
            for (int j = 0; j < n; ++j) {
                alphasSum[j] += this.alphas[j];
                CSum[j] += C[j];
            }
        }

        for (int j = 0; j < n; ++j) {
            this.alphas[j] = alphasSum[j] / averagedIters;
            C[j] = CSum[j] / averagedIters;
        }
        // NOTE(review): gamma is assumed positive here; a non-positive gamma would
        // flip the sign of (or blow up) every score.
        double gamma = this.findGamma(C, d);
        for (int j = 0; j < n; ++j) {
            this.alphas[j] /= gamma;
        }

        // Compact the non-zero weights (signed by label) and their vectors to the front.
        int supportVectorCount = 0;
        for (int i = 0; i < n; ++i) {
            if (this.alphas[i] == 0.0) {
                continue;
            }
            ListUtils.swap(this.vecs, supportVectorCount, i);
            this.alphas[supportVectorCount++] = this.alphas[i] * y[i];
        }
        this.vecs = new ArrayList<>(this.vecs.subList(0, supportVectorCount));
        this.alphas = Arrays.copyOfRange(this.alphas, 0, supportVectorCount);
        this.it = null;
        this.setCacheMode(null);
        this.setAlphas(this.alphas);
    }

    /**
     * Rescales the model onto the unit ball if its squared norm exceeds 1.
     *
     * @param rSqrd the current squared norm of the model
     * @param n the number of training points
     * @param C the responses; scaled in place together with the weights
     * @return the squared norm after the projection
     */
    private double projectionStep(double rSqrd, int n, double[] C) {
        if (rSqrd > 1.0) {
            double rInv = 1.0 / Math.sqrt(rSqrd);
            for (int j = 0; j < n; ++j) {
                C[j] *= rInv;
                this.alphas[j] *= rInv;
            }
            rSqrd = 1.0;
        }
        return rSqrd;
    }

    /**
     * Samples, uniformly at random, the index of a point with {@code C[i] <= gamma}.
     * <p>
     * It first tries rejection sampling a few times. If that fails, it enumerates the
     * candidates and picks one uniformly.
     *
     * @param rand the source of randomness
     * @param n the number of points (length of {@code C})
     * @param C the current responses
     * @param gamma the current threshold
     * @return the index of a point whose response is at most {@code gamma}
     * @throws FailedToFitException if no response is at most {@code gamma}
     */
    private int sampleC(Random rand, int n, double[] C, double gamma) throws FailedToFitException {
        for (int attempt = 0; attempt < 5; ++attempt) {
            int i = rand.nextInt(n);
            if (C[i] <= gamma) {
                return i;
            }
        }
        int candidates = 0;
        for (int j = 0; j < n; ++j) {
            if (C[j] <= gamma) {
                ++candidates;
            }
        }
        if (candidates == 0) {
            throw new FailedToFitException("BUG: please report");
        }
        // Return the index of the target-th candidate (0-based).
        int target = rand.nextInt(candidates);
        for (int j = 0; j < n; ++j) {
            if (C[j] <= gamma) {
                if (target == 0) {
                    return j;
                }
                --target;
            }
        }
        // Unreachable: target < candidates, so the loop always returns.
        throw new FailedToFitException("BUG: please report");
    }

    /**
     * Updates the responses and squared model norm after adding {@code eta} to point
     * {@code i}'s weight. This is the step {@code w += eta * y_i * phi(x_i)}.
     *
     * @param rSqrd the squared norm before the step
     * @param eta the step size
     * @param C the responses; updated in place
     * @param i the index of the sampled point
     * @param y the labels in {-1, +1}
     * @param n the number of training points
     * @return the squared norm after the step
     */
    private double updateLoop(double rSqrd, double eta, double[] C, int i, double[] y, int n) {
        // ||w + eta*y_i*phi_i||^2 = ||w||^2 + 2*eta*C_i + eta^2*K(i,i), since C_i = y_i*<w, phi_i>.
        rSqrd += 2.0 * eta * C[i] + eta * eta * this.kEval(i, i);
        double y_i = y[i];
        for (int j = 0; j < n; ++j) {
            C[j] += eta * y_i * y[j] * this.kEval(i, j);
        }
        return rSqrd;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Finds the threshold {@code gamma} satisfying
     * {@code sum_j max(0, gamma - C[j]) = d}.
     * <p>
     * The left-hand side is continuous and strictly increasing wherever it is
     * positive, so for {@code d > 0} the root is unique. Sort the responses ascending as
     * {@code c_(0) <= c_(1) <= ...}. If exactly the {@code k} smallest responses lie
     * below {@code gamma}, the equation becomes {@code k*gamma - (c_(0)+...+c_(k-1)) = d},
     * so {@code gamma = (d + c_(0)+...+c_(k-1)) / k}. The first {@code k} whose candidate
     * does not exceed {@code c_(k)} (or {@code k = n}) gives the root.
     *
     * @param C the responses
     * @param d the target value of the sum; expected to be positive
     * @return the threshold gamma, or 0 if {@code C} is empty
     */
    private double findGamma(double[] C, double d) {
        // The IndexTable is assumed to list indices in ascending order of C.
        if (this.it == null || this.it.length() != C.length) {
            this.it = new IndexTable(C);
        } else {
            this.it.sort(C);
        }
        final int n = this.it.length();
        double prefixSum = 0.0;
        double gamma = 0.0;
        for (int k = 1; k <= n; ++k) {
            prefixSum += C[this.it.index(k - 1)];
            gamma = (d + prefixSum) / k;
            if (k == n || gamma <= C[this.it.index(k)]) {
                break;
            }
        }
        return gamma;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }
}

