/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.PassiveAggressive;
import jsat.distributions.Distribution;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IndexTable;

/**
 * Multi-class linear classifier trained online with a passive-aggressive style update that
 * may adjust several competing class weight vectors in a single step.
 *
 * <p>One weight vector (plus an optional bias term) is kept per class. Prediction returns the
 * class with the largest score {@code w_c . x + b_c}. For each training point the hinge loss
 * {@code max(0, 1 - (score_y - score_v))} of every wrong class {@code v} against the true
 * class {@code y} is computed. The largest losses are then greedily collected into a
 * "support class" set. Each support class is pushed away from {@code x} by a step
 * {@code tau}, and the true class is pushed toward {@code x} by the same {@code tau}. The
 * {@link PassiveAggressive.Mode} selects how {@code tau} is limited, using the
 * aggressiveness constant {@code C} for PA1/PA2.
 */
public class SPA
extends BaseUpdateableClassifier
implements Parameterized,
SimpleWeightVectorModel {
    private static final long serialVersionUID = 3613279663279244169L;
    /** One weight vector per target class; allocated in {@link #setUp}. */
    private Vec[] w;
    /** One bias term per target class; only modified when {@link #useBias} is set. */
    private double[] bias;
    /** Aggressiveness constant used by the PA1 and PA2 modes. */
    private double C = 1.0;
    /** Whether an implicit constant feature of 1 (a per-class bias) is learned. */
    private boolean useBias = false;
    /** Which passive-aggressive variant determines the step size. */
    private PassiveAggressive.Mode mode;
    /** Scratch buffer holding the per-class loss of the current update. */
    private double[] loss;
    /** Scratch index table used to rank {@link #loss} without reordering it. */
    private IndexTable it;

    /**
     * Creates a learner that uses 10 epochs and {@link PassiveAggressive.Mode#PA2}.
     */
    public SPA() {
        this(10, PassiveAggressive.Mode.PA2);
    }

    /**
     * Creates a learner.
     *
     * @param epochs number of training epochs, forwarded to {@code setEpochs}
     * @param mode passive-aggressive variant used to compute step sizes
     */
    public SPA(int epochs, PassiveAggressive.Mode mode) {
        this.setEpochs(epochs);
        this.setMode(mode);
    }

    /**
     * Sets whether a per-class bias term is learned.
     *
     * @param useBias {@code true} to learn a bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * @return {@code true} if a per-class bias term is learned
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Sets the aggressiveness constant used by the PA1 and PA2 modes.
     *
     * @param C a finite, strictly positive value
     * @throws ArithmeticException if {@code C} is NaN, infinite, or not positive
     */
    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0) {
            throw new ArithmeticException("Aggressiveness must be a positive constant");
        }
        this.C = C;
    }

    /**
     * @return the aggressiveness constant used by the PA1 and PA2 modes
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the passive-aggressive variant used to compute step sizes.
     *
     * @param mode the variant; any value other than PA1/PA2 uses the unbounded PA step
     */
    public void setMode(PassiveAggressive.Mode mode) {
        this.mode = mode;
    }

    /**
     * @return the passive-aggressive variant used to compute step sizes
     */
    public PassiveAggressive.Mode getMode() {
        return this.mode;
    }

    @Override
    public Vec getRawWeight(int index) {
        return this.w[index];
    }

    @Override
    public double getBias(int index) {
        return this.bias[index];
    }

    @Override
    public int numWeightsVecs() {
        return this.w.length;
    }

    /**
     * Returns a deep copy of this learner, including any trained weights and its epoch count.
     */
    @Override
    public SPA clone() {
        // Pass the current epochs and mode through; the no-arg constructor would silently
        // reset the epoch count to its default.
        SPA copy = new SPA(this.getEpochs(), this.mode);
        copy.C = this.C;
        copy.useBias = this.useBias;
        if (this.w != null) {
            copy.w = new Vec[this.w.length];
            for (int i = 0; i < this.w.length; ++i) {
                copy.w[i] = this.w[i].clone();
            }
        }
        if (this.bias != null) {
            copy.bias = Arrays.copyOf(this.bias, this.bias.length);
        }
        if (this.loss != null) {
            copy.loss = Arrays.copyOf(this.loss, this.loss.length);
        }
        if (this.it != null) {
            // Only scratch space for ranking losses, so a fresh table of the same size suffices.
            copy.it = new IndexTable(this.it.length());
        }
        return copy;
    }

    /**
     * Allocates one zero-initialised dense weight vector and bias per target class.
     * Categorical attributes are ignored; only the numeric features are used.
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        this.w = new Vec[predicting.getNumOfCategories()];
        for (int i = 0; i < this.w.length; ++i) {
            this.w[i] = new DenseVector(numericAttributes);
        }
        this.bias = new double[this.w.length];
        this.loss = new double[this.w.length];
        this.it = new IndexTable(this.w.length);
    }

    /**
     * Threshold used when growing the support set. The class at sorted position {@code k}
     * joins the set only while the summed loss of the classes already in the set is below
     * this value.
     *
     * @param xNorm squared norm of the (possibly bias-augmented) input
     * @param k sorted position of the candidate class (1-based among the wrong classes)
     * @param loss_k loss of the candidate class
     */
    private double getSupportClassGoal(double xNorm, int k, double loss_k) {
        if (this.mode == PassiveAggressive.Mode.PA1) {
            return Math.min((double)(k - 1) * loss_k + this.C * xNorm, (double)k * loss_k);
        }
        if (this.mode == PassiveAggressive.Mode.PA2) {
            return ((double)k * xNorm + (double)(k - 1) / (2.0 * this.C)) / (xNorm + 1.0 / (2.0 * this.C)) * loss_k;
        }
        return (double)k * loss_k;
    }

    /**
     * Step size applied to one support class.
     *
     * @param loss_cur loss of the support class being updated
     * @param xNorm squared norm of the (possibly bias-augmented) input; must be non-zero
     * @param k one more than the number of support classes
     * @param supLossSum summed loss over all support classes
     * @return a non-negative step size
     */
    private double getStepSize(double loss_cur, double xNorm, int k, double supLossSum) {
        if (this.mode == PassiveAggressive.Mode.PA1) {
            return Math.max(0.0, loss_cur - Math.max(supLossSum / (double)(k - 1) - this.C / (double)(k - 1) * xNorm, supLossSum / (double)k)) / xNorm;
        }
        if (this.mode == PassiveAggressive.Mode.PA2) {
            return Math.max(0.0, loss_cur - (xNorm + 1.0 / (2.0 * this.C)) / ((double)k * xNorm + (double)(k - 1) / (2.0 * this.C)) * supLossSum) / xNorm;
        }
        return Math.max(0.0, loss_cur - supLossSum / (double)k) / xNorm;
    }

    /**
     * Performs one online update with the given labelled point.
     *
     * @param dataPoint the training point; only its numeric values are used
     * @param targetClass index of the true class
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x = dataPoint.getNumericalValues();

        // Squared norm of the input augmented with a constant 1 when a bias is learned:
        // ||[x, 1]||^2 = ||x||^2 + 1 (not (||x|| + 1)^2).
        double xL2 = x.pNorm(2.0);
        double xNorm = xL2 * xL2 + (this.useBias ? 1.0 : 0.0);
        if (xNorm == 0.0) {
            // Zero input and no bias feature: no step can change any score, and dividing by
            // xNorm would yield an infinite/NaN step that could poison the weights.
            return;
        }

        double targetScore = this.w[targetClass].dot(x) + this.bias[targetClass];
        for (int v = 0; v < this.w.length; ++v) {
            // The true class gets +Inf so that it ranks first and is skipped below.
            this.loss[v] = v == targetClass
                    ? Double.POSITIVE_INFINITY
                    : Math.max(0.0, 1.0 - (targetScore - this.w[v].dot(x) - this.bias[v]));
        }

        // Rank classes by loss; position 0 holds the true class, hence indices start at 1.
        this.it.sortR(this.loss);

        // Greedily grow the support set with the largest losses while the goal allows it.
        int k = 1;
        double supportLossSum = 0.0;
        while (k < this.loss.length && supportLossSum < this.getSupportClassGoal(xNorm, k, this.loss[this.it.index(k)])) {
            supportLossSum += this.loss[this.it.index(k)];
            k++;
        }

        for (int j = 1; j < k; ++j) {
            int v = this.it.index(j);
            double tau = this.getStepSize(this.loss[v], xNorm, k, supportLossSum);
            this.w[targetClass].mutableAdd(tau, x);
            this.w[v].mutableSubtract(tau, x);
            if (this.useBias) {
                this.bias[targetClass] += tau;
                this.bias[v] -= tau;
            }
        }
    }

    /**
     * Predicts the class with the highest linear score. On ties, the lowest class index wins.
     *
     * @return a result with probability 1.0 on the winning class
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        int best = 0;
        double bestScore = this.w[0].dot(x) + this.bias[0];
        for (int c = 1; c < this.w.length; ++c) {
            double score = this.w[c].dot(x) + this.bias[c];
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        CategoricalResults cr = new CategoricalResults(this.w.length);
        cr.setProb(best, 1.0);
        return cr;
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Suggests a distribution of values to search over for {@code C}; delegates to
     * {@link PassiveAggressive#guessC(DataSet)}.
     */
    public static Distribution guessC(DataSet d) {
        return PassiveAggressive.guessC(d);
    }
}

