/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.Iterator;
import java.util.List;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.Classifier;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.Regressor;

/**
 * Shared state and accessors for linear models that keep a single weight
 * vector {@link #w} and a scalar {@link #bias}, configurable through a number
 * of epochs, a regularization strength, a {@link Loss} function and optional
 * feature rescaling.
 *
 * <p>This class declares no training or prediction entry points itself; those
 * required by {@link Classifier} and {@link Regressor} are left to subclasses.
 * The fields {@link #w}, {@link #bias}, {@link #obvMin} and {@link #obvMax}
 * are never assigned here and are expected to be populated by subclasses.
 */
public abstract class StochasticSTLinearL1
        implements Classifier, Regressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -6761456665014802608L;

    /** Number of epochs to run; {@link #setEpochs(int)} enforces a value of at least 1. */
    protected int epochs;
    /** Regularization strength; {@link #setLambda(double)} enforces a finite, strictly positive value. */
    protected double lambda;
    /** Loss function in use. */
    protected Loss loss;
    /** Weight vector of the model; may be {@code null} (see {@link #getW()}). */
    protected Vec w;
    /** Bias term added to every dot product computed by {@link #wDot(Vec)}. */
    protected double bias;
    /**
     * Per-feature lower reference value used by {@link #wDot(Vec)} when rescaling.
     * NOTE(review): presumably the minimum observed per feature during training - confirm in subclasses.
     */
    protected double[] obvMin;
    /**
     * Per-feature upper reference value used by {@link #wDot(Vec)} when rescaling.
     * NOTE(review): presumably the maximum observed per feature during training - confirm in subclasses.
     */
    protected double[] obvMax;
    /** Whether {@link #wDot(Vec)} rescales feature values before multiplying by the weights. */
    protected boolean reScale;
    /** Lower end of the target range for rescaled features. */
    protected double minScaled = 0.0;
    /** Upper end of the target range for rescaled features. */
    protected double maxScaled = 1.0;

    /** Default number of epochs (not applied automatically by this class). */
    public static final int DEFAULT_EPOCHS = 1000;
    /** Default regularization value (not applied automatically by this class). */
    public static final double DEFAULT_REG = 1.0E-14;
    /** Default loss function (not applied automatically by this class). */
    public static final Loss DEFAULT_LOSS = Loss.SQUARED;

    @Override
    public abstract StochasticSTLinearL1 clone();

    /**
     * Sets the number of epochs.
     *
     * @param epochs the number of epochs, must be at least 1
     * @throws ArithmeticException if {@code epochs < 1}
     */
    public void setEpochs(int epochs) {
        if (epochs < 1) {
            throw new ArithmeticException("A positive amount of iterations must be performed");
        }
        this.epochs = epochs;
    }

    /**
     * Returns the configured number of epochs.
     *
     * @return the number of epochs (an {@code int} value widened to {@code double};
     *         the return type is kept for compatibility with existing callers)
     */
    public double getEpochs() {
        return this.epochs;
    }

    /**
     * Sets the upper end of the range that features are rescaled into.
     *
     * @param maxFeature the new maximum, must be at most 1 and strictly greater
     *                   than the current minimum
     * @throws ArithmeticException if the value is NaN, greater than 1, or not
     *                             greater than {@link #getMinScaled()}
     */
    public void setMaxScaled(double maxFeature) {
        if (Double.isNaN(maxFeature)) {
            throw new ArithmeticException("NaN is not a valid feature value");
        }
        if (maxFeature > 1.0) {
            throw new ArithmeticException("Maximum possible feature value is 1, can not use " + maxFeature);
        }
        if (maxFeature <= this.minScaled) {
            throw new ArithmeticException("Maximum feature value must be learger than the minimum");
        }
        this.maxScaled = maxFeature;
    }

    /**
     * @return the upper end of the rescaled feature range
     */
    public double getMaxScaled() {
        return this.maxScaled;
    }

    /**
     * Sets the lower end of the range that features are rescaled into.
     *
     * @param minFeature the new minimum, must be at least -1 and strictly
     *                   smaller than the current maximum
     * @throws ArithmeticException if the value is NaN, less than -1, or not
     *                             smaller than {@link #getMaxScaled()}
     */
    public void setMinScaled(double minFeature) {
        if (Double.isNaN(minFeature)) {
            throw new ArithmeticException("NaN is not a valid feature value");
        }
        if (minFeature < -1.0) {
            throw new ArithmeticException("Minimum possible feature value is -1, can not use " + minFeature);
        }
        if (minFeature >= this.maxScaled) {
            throw new ArithmeticException("Minimum feature value must be smaller than the maximum");
        }
        this.minScaled = minFeature;
    }

    /**
     * @return the lower end of the rescaled feature range
     */
    public double getMinScaled() {
        return this.minScaled;
    }

    /**
     * Sets the regularization strength.
     *
     * @param lambda the regularization value, must be finite and strictly positive
     * @throws ArithmeticException if {@code lambda} is infinite, NaN, or not positive
     */
    public void setLambda(double lambda) {
        if (Double.isInfinite(lambda) || Double.isNaN(lambda) || lambda <= 0.0) {
            throw new ArithmeticException("A positive amount of regularization must be performed");
        }
        this.lambda = lambda;
    }

    /**
     * @return the regularization strength
     */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the loss function. No validation is performed on the argument.
     *
     * @param loss the loss function to use
     */
    public void setLoss(Loss loss) {
        this.loss = loss;
    }

    /**
     * @return the loss function in use
     */
    public Loss getLoss() {
        return this.loss;
    }

    /**
     * Enables or disables feature rescaling inside {@link #wDot(Vec)}.
     *
     * @param reScale {@code true} to rescale feature values
     */
    public void setReScale(boolean reScale) {
        this.reScale = reScale;
    }

    /**
     * @return {@code true} if feature rescaling is enabled
     */
    public boolean isReScale() {
        return this.reScale;
    }

    /**
     * Computes the dot product of the weight vector with {@code x}, plus the bias.
     *
     * <p>Without rescaling this is simply {@code w.dot(x) + bias}. With rescaling
     * enabled, each visited value of {@code x} is first mapped through
     * {@link #scaleFeature(int, double)}. Only the entries produced by iterating
     * {@code x} (dense {@code w}) or the indices present in both non-zero
     * iterators (sparse {@code w}) contribute to the sum; entries that are not
     * visited are not rescaled and add nothing.
     *
     * @param x the input vector
     * @return the (optionally rescaled) dot product plus {@link #bias}
     */
    protected double wDot(Vec x) {
        if (!this.reScale) {
            return this.w.dot(x) + this.bias;
        }

        double a = this.bias;

        if (!this.w.isSparse()) {
            // Dense weights: walk the entries of x and look each weight up by index.
            for (IndexValue iv : x) {
                int j = iv.getIndex();
                a += this.w.get(j) * this.scaleFeature(j, iv.getValue());
            }
            return a;
        }

        // Sparse weights: merge-walk both non-zero iterators and accumulate on matching indices.
        Iterator<IndexValue> wIter = this.w.getNonZeroIterator();
        Iterator<IndexValue> xIter = x.getNonZeroIterator();
        if (!wIter.hasNext() || !xIter.hasNext()) {
            return a;
        }
        IndexValue wIV = wIter.next();
        IndexValue xIV = xIter.next();
        do {
            int wIndex = wIV.getIndex();
            int xIndex = xIV.getIndex();
            if (wIndex == xIndex) {
                a += wIV.getValue() * this.scaleFeature(xIndex, xIV.getValue());
                // Once either side is exhausted there can be no further matches.
                if (!wIter.hasNext() || !xIter.hasNext()) {
                    break;
                }
                wIV = wIter.next();
                xIV = xIter.next();
            } else if (wIndex < xIndex) {
                if (!wIter.hasNext()) {
                    break;
                }
                wIV = wIter.next();
            } else {
                if (!xIter.hasNext()) {
                    break;
                }
                xIV = xIter.next();
            }
        } while (wIV != null && xIV != null);
        return a;
    }

    /**
     * Maps a raw value of feature {@code j} from the interval
     * {@code [obvMin[j], obvMax[j]]} onto {@code [minScaled, maxScaled]}.
     * No guard is applied when {@code obvMax[j] == obvMin[j]}.
     */
    private double scaleFeature(int j, double value) {
        double scaled = value - this.obvMin[j];
        scaled *= (this.maxScaled - this.minScaled) / (this.obvMax[j] - this.obvMin[j]);
        scaled += this.minScaled;
        return scaled;
    }

    /**
     * Returns the internal weight vector itself, not a copy.
     *
     * @return the backing weight vector, possibly {@code null}
     */
    public Vec getWRaw() {
        return this.w;
    }

    /**
     * Returns a copy of the weight vector.
     *
     * @return a clone of the weight vector, or {@code null} if none is set
     */
    public Vec getW() {
        if (this.w == null) {
            return null;
        }
        return this.w.clone();
    }

    @Override
    public Vec getRawWeight() {
        return this.w;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the weight vector for the given index. This model has exactly one
     * weight vector, so only index 0 is valid.
     *
     * @param index the weight vector index, must be 0
     * @return the backing weight vector
     * @throws IndexOutOfBoundsException if {@code index} is not 0
     */
    @Override
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the bias for the given index. This model has exactly one weight
     * vector, so only index 0 is valid.
     *
     * @param index the weight vector index, must be 0
     * @return the bias term
     * @throws IndexOutOfBoundsException if {@code index} is not 0
     */
    @Override
    public double getBias(int index) {
        if (index == 0) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Loss functions available to the model. Each constant supplies the loss
     * value, its derivative with respect to the raw output {@code a}, and the
     * conversions from a raw output to a classification or regression result.
     */
    public enum Loss {
        /** Squared loss, {@code 0.5 * (a - y)^2}. */
        SQUARED {
            @Override
            public double loss(double a, double y) {
                return 0.5 * Math.pow(a - y, 2.0);
            }

            @Override
            public double deriv(double a, double y) {
                return a - y;
            }

            @Override
            public double beta() {
                return 1.0;
            }

            /**
             * Maps the raw output linearly from {@code [-1, 1]} onto {@code [0, 1]},
             * clamps it to that range, and uses it as the probability of class 1.
             */
            @Override
            public CategoricalResults classify(double a) {
                CategoricalResults cr = new CategoricalResults(2);
                a = (a + 1.0) / 2.0;
                if (a > 1.0) {
                    a = 1.0;
                } else if (a < 0.0) {
                    a = 0.0;
                }
                cr.setProb(1, a);
                cr.setProb(0, 1.0 - a);
                return cr;
            }

            @Override
            public double regress(double a) {
                return a;
            }
        },

        /**
         * Logistic loss, {@code log(1 + exp(-y * a))}.
         * NOTE(review): the formulas presume labels {@code y} in {-1, +1} - confirm against callers.
         */
        LOG {
            @Override
            public double loss(double a, double y) {
                // log(1 + e^z) with z = -y*a; this is the function whose derivative
                // with respect to a is what deriv() returns. For z > 0 the identity
                // log(1 + e^z) = z + log(1 + e^-z) avoids overflow in exp().
                double z = -y * a;
                if (z > 0.0) {
                    return z + Math.log1p(Math.exp(-z));
                }
                return Math.log1p(Math.exp(z));
            }

            @Override
            public double deriv(double a, double y) {
                return -y / (1.0 + Math.exp(a * y));
            }

            @Override
            public double beta() {
                return 0.25;
            }

            /** Uses the logistic sigmoid of the raw output as the probability of class 1. */
            @Override
            public CategoricalResults classify(double a) {
                CategoricalResults cr = new CategoricalResults(2);
                cr.setProb(1, this.regress(a));
                cr.setProb(0, 1.0 - cr.getProb(1));
                return cr;
            }

            /** Returns the logistic sigmoid {@code 1 / (1 + exp(-a))}. */
            @Override
            public double regress(double a) {
                return 1.0 / (1.0 + Math.exp(-a));
            }
        };

        /**
         * Evaluates the loss.
         *
         * @param a the raw model output
         * @param y the target value
         * @return the loss incurred by predicting {@code a} for target {@code y}
         */
        public abstract double loss(double a, double y);

        /**
         * Evaluates the derivative of the loss with respect to {@code a}.
         *
         * @param a the raw model output
         * @param y the target value
         * @return the derivative of the loss at {@code a}
         */
        public abstract double deriv(double a, double y);

        /**
         * Returns a constant tied to this loss (1 for squared, 0.25 for logistic).
         * NOTE(review): presumably an upper bound on the second derivative of the
         * loss, used by the training algorithm - confirm in subclasses.
         *
         * @return the loss-specific constant
         */
        public abstract double beta();

        /**
         * Converts a raw model output into a two-class result.
         *
         * @param a the raw model output
         * @return the probabilities for class 0 and class 1
         */
        public abstract CategoricalResults classify(double a);

        /**
         * Converts a raw model output into a regression value.
         *
         * @param a the raw model output
         * @return the regression result
         */
        public abstract double regress(double a);
    }
}

