/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.List;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

public class NHERD
extends BaseUpdateableClassifier
implements BinaryScoreClassifier,
Parameterized,
SingleWeightVectorModel {
    private static final long serialVersionUID = -1186002893766449917L;
    private Vec w;
    private Matrix sigmaM;
    private Vec sigmaV;
    private CovMode covMode;
    private double C;
    private Vec Sigma_xt;

    public NHERD(double C, CovMode covMode) {
        this.setC(C);
        this.setCovMode(covMode);
    }

    protected NHERD(NHERD other) {
        this.C = other.C;
        this.covMode = other.covMode;
        if (other.w != null) {
            this.w = other.w.clone();
        }
        if (other.sigmaM != null) {
            this.sigmaM = other.sigmaM.clone();
        }
        if (other.sigmaV != null) {
            this.sigmaV = other.sigmaV.clone();
        }
        if (other.Sigma_xt != null) {
            this.Sigma_xt = other.Sigma_xt.clone();
        }
    }

    public void setC(double C) {
        if (Double.isNaN(C) || Double.isInfinite(C) || C <= 0.0) {
            throw new IllegalArgumentException("C must be a postive constant, not " + C);
        }
        this.C = C;
    }

    public double getC() {
        return this.C;
    }

    public void setCovMode(CovMode covMode) {
        this.covMode = covMode;
    }

    public CovMode getCovMode() {
        return this.covMode;
    }

    public Vec getWeightVec() {
        return this.w;
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return 0.0;
    }

    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    @Override
    public NHERD clone() {
        return new NHERD(this);
    }

    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (numericAttributes <= 0) {
            throw new FailedToFitException("AROW requires numeric attributes to perform classification");
        }
        if (predicting.getNumOfCategories() != 2) {
            throw new FailedToFitException("AROW is a binary classifier");
        }
        this.w = new DenseVector(numericAttributes);
        this.Sigma_xt = new DenseVector(numericAttributes);
        if (this.covMode != CovMode.FULL) {
            this.sigmaV = new DenseVector(numericAttributes);
            this.sigmaV.mutableAdd(1.0);
        } else {
            this.sigmaM = Matrix.eye(numericAttributes);
        }
    }

    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        double alpha;
        double y_t = targetClass * 2 - 1;
        Vec x_t = dataPoint.getNumericalValues();
        double pred = x_t.dot(this.w);
        if (y_t * pred > 1.0) {
            return;
        }
        if (this.covMode != CovMode.FULL) {
            alpha = 0.0;
            for (IndexValue iv : x_t) {
                double x_ti = iv.getValue();
                alpha += x_ti * x_ti * this.sigmaV.get(iv.getIndex());
            }
        } else {
            this.sigmaM.multiply(x_t, 1.0, this.Sigma_xt);
            alpha = x_t.dot(this.Sigma_xt);
        }
        double loss = Math.max(0.0, 1.0 - y_t * pred);
        double w_c = y_t * loss / (alpha + 1.0 / this.C);
        if (this.covMode == CovMode.FULL) {
            this.w.mutableAdd(w_c, this.Sigma_xt);
        } else {
            for (IndexValue iv : x_t) {
                this.w.increment(iv.getIndex(), w_c * iv.getValue() * this.sigmaV.get(iv.getIndex()));
            }
        }
        double numer = this.C * (this.C * alpha + 2.0);
        double denom = (1.0 + this.C * alpha) * (1.0 + this.C * alpha);
        switch (this.covMode) {
            case FULL: {
                Matrix.OuterProductUpdate(this.sigmaM, this.Sigma_xt, this.Sigma_xt, -numer / denom);
                break;
            }
            case DROP: {
                double c = -numer / denom;
                for (IndexValue iv : x_t) {
                    int idx = iv.getIndex();
                    double x_ti = iv.getValue() * this.sigmaV.get(idx);
                    this.sigmaV.increment(idx, c * x_ti * x_ti);
                }
                break;
            }
            case PROJECT: {
                for (IndexValue iv : x_t) {
                    int idx = iv.getIndex();
                    double x_r = iv.getValue();
                    double S_rr = this.sigmaV.get(idx);
                    this.sigmaV.set(idx, 1.0 / (1.0 / S_rr + numer * x_r * x_r));
                }
                break;
            }
            case EXACT: {
                for (IndexValue iv : x_t) {
                    int idx = iv.getIndex();
                    double x_r = iv.getValue();
                    double S_rr = this.sigmaV.get(idx);
                    this.sigmaV.set(idx, S_rr / Math.pow(S_rr * x_r * x_r * this.C + 1.0, 2.0));
                }
                break;
            }
        }
        if (this.covMode == CovMode.FULL) {
            this.Sigma_xt.zeroOut();
        }
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not yet ben trained");
        }
        CategoricalResults cr = new CategoricalResults(2);
        double score = this.getScore(data);
        if (score < 0.0) {
            cr.setProb(0, 1.0);
        } else {
            cr.setProb(1, 1.0);
        }
        return cr;
    }

    @Override
    public double getScore(DataPoint dp) {
        return this.w.dot(dp.getNumericalValues());
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    public static Distribution guessC(DataSet d) {
        return new LogUniform(Math.pow(2.0, -4.0), Math.pow(2.0, 4.0));
    }

    public static enum CovMode {
        FULL,
        DROP,
        PROJECT,
        EXACT;

    }
}

