/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.ConcatenatedVec;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SubVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.lossfunctions.SoftmaxLoss;
import jsat.math.FunctionP;
import jsat.math.FunctionVec;
import jsat.math.optimization.LBFGS;
import jsat.math.optimization.Optimizer2;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

public class LinearBatch
implements Classifier,
Regressor,
Parameterized,
SimpleWeightVectorModel,
WarmClassifier,
WarmRegressor {
    private static final long serialVersionUID = -446156124954287580L;
    private Vec[] ws;
    private double[] bs;
    private LossFunc loss;
    private double lambda0;
    private Optimizer2 optimizer;
    private double tolerance;
    private boolean useBiasTerm = true;

    public LinearBatch() {
        this(new SoftmaxLoss(), 1.0E-6);
    }

    public LinearBatch(LossFunc loss, double lambda0) {
        this(loss, lambda0, 0.001);
    }

    public LinearBatch(LossFunc loss, double lambda0, double tolerance) {
        this(loss, lambda0, tolerance, null);
    }

    public LinearBatch(LossFunc loss, double lambda0, double tolerance, Optimizer2 optimizer) {
        this.setLoss(loss);
        this.setLambda0(lambda0);
        this.setOptimizer(optimizer);
        this.setTolerance(tolerance);
    }

    public LinearBatch(LinearBatch toCopy) {
        this(toCopy.loss.clone(), toCopy.lambda0, toCopy.tolerance, toCopy.optimizer == null ? null : toCopy.optimizer.clone());
        if (toCopy.ws != null) {
            this.ws = new Vec[toCopy.ws.length];
            for (int i = 0; i < toCopy.ws.length; ++i) {
                this.ws[i] = toCopy.ws[i].clone();
            }
        }
        if (toCopy.bs != null) {
            this.bs = Arrays.copyOf(toCopy.bs, toCopy.bs.length);
        }
    }

    public void setUseBiasTerm(boolean useBiasTerm) {
        this.useBiasTerm = useBiasTerm;
    }

    public boolean isUseBiasTerm() {
        return this.useBiasTerm;
    }

    public void setLambda0(double lambda0) {
        if (lambda0 < 0.0 || Double.isNaN(lambda0) || Double.isInfinite(lambda0)) {
            throw new IllegalArgumentException("Lambda0 must be non-negative, not " + lambda0);
        }
        this.lambda0 = lambda0;
    }

    public double getLambda0() {
        return this.lambda0;
    }

    public void setLoss(LossFunc loss) {
        this.loss = loss;
    }

    public LossFunc getLoss() {
        return this.loss;
    }

    public void setOptimizer(Optimizer2 optimizer) {
        this.optimizer = optimizer;
    }

    public Optimizer2 getOptimizer() {
        return this.optimizer;
    }

    public void setTolerance(double tolerance) {
        if (tolerance < 0.0 || Double.isNaN(tolerance) || Double.isInfinite(tolerance)) {
            throw new IllegalArgumentException("Tolerance must be a non-negative constant, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    public double getTolerance() {
        return this.tolerance;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        if (this.ws.length == 1) {
            return ((LossC)this.loss).getClassification(this.ws[0].dot(x) + this.bs[0]);
        }
        DenseVector pred = new DenseVector(this.ws.length);
        for (int i = 0; i < this.ws.length; ++i) {
            ((Vec)pred).set(i, this.ws[i].dot(x) + this.bs[i]);
        }
        ((LossMC)this.loss).process(pred, pred);
        return ((LossMC)this.loss).getClassification(pred);
    }

    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();
        return ((LossR)this.loss).getRegression(this.ws[0].dot(x) + this.bs[0]);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution) {
        this.trainC(dataSet, warmSolution, null);
    }

    @Override
    public void trainC(ClassificationDataSet D, ExecutorService threadPool) {
        this.trainC(D, null, threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet D, Classifier warmSolution, ExecutorService threadPool) {
        if (D.getNumNumericalVars() <= 0) {
            throw new FailedToFitException("LinearBath requires numeric features to work");
        }
        if (!(this.loss instanceof LossC)) {
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support classification");
        }
        if (D.getClassSize() > 2) {
            if (!(this.loss instanceof LossMC)) {
                throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not support multi-class classification");
            }
            this.ws = new Vec[D.getClassSize()];
            this.bs = new double[this.ws.length];
        } else {
            this.ws = new Vec[1];
            this.bs = new double[1];
        }
        for (int i = 0; i < this.ws.length; ++i) {
            this.ws[i] = new DenseVector(D.getNumNumericalVars());
        }
        Optimizer2 optimizerToUse = this.optimizer == null ? new LBFGS(10) : this.optimizer.clone();
        this.doWarmStartIfNotNull(warmSolution);
        if (this.ws.length == 1) {
            if (this.useBiasTerm) {
                VecWithBias w_tmp = new VecWithBias(this.ws[0], this.bs);
                optimizerToUse.optimize(this.tolerance, w_tmp, w_tmp, new LossFunction(D, this.loss), new GradFunction(D, this.loss), null, threadPool);
            } else {
                optimizerToUse.optimize(this.tolerance, this.ws[0], this.ws[0], new LossFunction(D, this.loss), new GradFunction(D, this.loss), null, threadPool);
            }
        } else {
            ConcatenatedVec wAll;
            LossMC lossMC = (LossMC)this.loss;
            if (this.useBiasTerm) {
                ArrayList<Vec> vecs = new ArrayList<Vec>(Arrays.asList(this.ws));
                vecs.add(DenseVector.toDenseVec(this.bs));
                wAll = new ConcatenatedVec(vecs);
            } else {
                wAll = new ConcatenatedVec(Arrays.asList(this.ws));
            }
            optimizerToUse.optimize(this.tolerance, wAll, new DenseVector(wAll), new LossMCFunction(D, lossMC), new GradMCFunction(D, lossMC), null, threadPool);
        }
    }

    private void doWarmStartIfNotNull(Object warmSolution) throws FailedToFitException {
        if (warmSolution != null) {
            if (warmSolution instanceof SimpleWeightVectorModel) {
                SimpleWeightVectorModel warm = (SimpleWeightVectorModel)warmSolution;
                if (warm.numWeightsVecs() != this.ws.length) {
                    throw new FailedToFitException("Warm solution has " + warm.numWeightsVecs() + " weight vectors instead of " + this.ws.length);
                }
                for (int i = 0; i < this.ws.length; ++i) {
                    warm.getRawWeight(i).copyTo(this.ws[i]);
                    if (!this.useBiasTerm) continue;
                    this.bs[i] = warm.getBias(i);
                }
            } else {
                throw new FailedToFitException("Can not warm warm from " + warmSolution.getClass().getCanonicalName());
            }
        }
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, (ExecutorService)null);
    }

    @Override
    public void train(RegressionDataSet D, ExecutorService threadPool) {
        this.train(D, null, threadPool);
    }

    @Override
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        this.train(dataSet, warmSolution, null);
    }

    @Override
    public void train(RegressionDataSet D, Regressor warmSolution, ExecutorService threadPool) {
        if (D.getNumNumericalVars() <= 0) {
            throw new FailedToFitException("LinearBath requires numeric features to work");
        }
        if (!(this.loss instanceof LossR)) {
            throw new FailedToFitException("Loss function " + this.loss.getClass().getSimpleName() + " does not regression");
        }
        this.ws = new Vec[]{new DenseVector(D.getNumNumericalVars())};
        this.bs = new double[1];
        Optimizer2 optimizerToUse = this.optimizer == null ? new LBFGS(10) : this.optimizer.clone();
        this.doWarmStartIfNotNull(warmSolution);
        if (this.useBiasTerm) {
            VecWithBias w_tmp = new VecWithBias(this.ws[0], this.bs);
            optimizerToUse.optimize(this.tolerance, w_tmp, w_tmp, new LossFunction(D, this.loss), new GradFunction(D, this.loss), null, threadPool);
        } else {
            optimizerToUse.optimize(this.tolerance, this.ws[0], this.ws[0], new LossFunction(D, this.loss), new GradFunction(D, this.loss), null, threadPool);
        }
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, (ExecutorService)null);
    }

    private static double getTargetY(DataSet D, int i) {
        double y = D instanceof ClassificationDataSet ? (double)(((ClassificationDataSet)D).getDataPointCategory(i) * 2 - 1) : ((RegressionDataSet)D).getTargetValue(i);
        return y;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    @Override
    public boolean warmFromSameDataOnly() {
        return false;
    }

    @Override
    public Vec getRawWeight(int index) {
        return this.ws[index];
    }

    @Override
    public double getBias(int index) {
        return this.bs[index];
    }

    @Override
    public int numWeightsVecs() {
        return this.ws.length;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public LinearBatch clone() {
        return new LinearBatch(this);
    }

    public static Distribution guessLambda0(DataSet d) {
        return new LogUniform(1.0E-7, 0.01);
    }

    private class GradMCFunction
    implements FunctionVec {
        private final ClassificationDataSet D;
        private final LossMC loss;
        private ThreadLocal<Vec> tempVecs;

        public GradMCFunction(ClassificationDataSet D, LossMC loss) {
            this.D = D;
            this.loss = loss;
        }

        @Override
        public Vec f(double ... x) {
            return this.f(DenseVector.toDenseVec(x));
        }

        @Override
        public Vec f(Vec w) {
            Vec s = w.clone();
            this.f(w, s);
            return s;
        }

        @Override
        public Vec f(Vec w, Vec s) {
            if (s == null) {
                s = w.clone();
            }
            s.zeroOut();
            DenseVector pred = new DenseVector(this.D.getClassSize());
            int subWSize = (w.length() - (LinearBatch.this.useBiasTerm ? LinearBatch.this.bs.length : 0)) / this.D.getClassSize();
            double weightSum = 0.0;
            for (int i = 0; i < this.D.getSampleSize(); ++i) {
                DataPoint dp = this.D.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                for (int k = 0; k < ((Vec)pred).length(); ++k) {
                    ((Vec)pred).set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                }
                if (LinearBatch.this.useBiasTerm) {
                    pred.mutableAdd(new SubVector(w.length() - LinearBatch.this.bs.length, LinearBatch.this.bs.length, w));
                }
                this.loss.process(pred, pred);
                int y = this.D.getDataPointCategory(i);
                this.loss.deriv(pred, pred, y);
                for (int k = 0; k < ((Vec)pred).length(); ++k) {
                    new SubVector(k * subWSize, subWSize, s).mutableAdd(((Vec)pred).get(k) * dp.getWeight(), x);
                }
                weightSum += dp.getWeight();
            }
            s.mutableDivide(weightSum);
            if (LinearBatch.this.lambda0 > 0.0) {
                s.mutableSubtract(LinearBatch.this.lambda0, w);
            }
            return s;
        }

        @Override
        public Vec f(final Vec w, Vec s, ExecutorService ex) {
            if (s == null) {
                s = w.clone();
            }
            s.zeroOut();
            if (this.tempVecs == null) {
                this.tempVecs = new ThreadLocal<Vec>(){

                    @Override
                    protected Vec initialValue() {
                        return w.clone();
                    }
                };
            }
            final Vec store = s;
            final int N = this.D.getSampleSize();
            final int P = SystemInfo.LogicalCores;
            final int subWSize = (w.length() - (LinearBatch.this.useBiasTerm ? LinearBatch.this.bs.length : 0)) / this.D.getClassSize();
            final CountDownLatch latch = new CountDownLatch(P);
            final double[] weightSums = new double[P];
            int p = 0;
            while (p < SystemInfo.LogicalCores) {
                final int ID = p++;
                ex.submit(new Runnable(){

                    /*
                     * WARNING - Removed try catching itself - possible behaviour change.
                     */
                    @Override
                    public void run() {
                        Vec temp = (Vec)GradMCFunction.this.tempVecs.get();
                        temp.zeroOut();
                        DenseVector pred = new DenseVector(GradMCFunction.this.D.getClassSize());
                        double weightSum = 0.0;
                        for (int i = ParallelUtils.getStartBlock(N, ID, P); i < ParallelUtils.getEndBlock(N, ID, P); ++i) {
                            DataPoint dp = GradMCFunction.this.D.getDataPoint(i);
                            Vec x = dp.getNumericalValues();
                            for (int k = 0; k < ((Vec)pred).length(); ++k) {
                                ((Vec)pred).set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                            }
                            if (LinearBatch.this.useBiasTerm) {
                                pred.mutableAdd(new SubVector(w.length() - LinearBatch.this.bs.length, LinearBatch.this.bs.length, w));
                            }
                            GradMCFunction.this.loss.process(pred, pred);
                            int y = GradMCFunction.this.D.getDataPointCategory(i);
                            GradMCFunction.this.loss.deriv(pred, pred, y);
                            for (IndexValue iv : pred) {
                                new SubVector(iv.getIndex() * subWSize, subWSize, temp).mutableAdd(iv.getValue() * dp.getWeight(), x);
                            }
                            weightSum += dp.getWeight();
                        }
                        Vec vec = store;
                        synchronized (vec) {
                            store.mutableAdd(temp);
                        }
                        weightSums[ID] = weightSum;
                        latch.countDown();
                    }
                });
            }
            try {
                latch.await();
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double weightSum = 0.0;
            for (double ws : weightSums) {
                weightSum += ws;
            }
            s.mutableDivide(weightSum);
            if (LinearBatch.this.lambda0 > 0.0) {
                s.mutableSubtract(LinearBatch.this.lambda0, w);
            }
            return s;
        }
    }

    public class LossMCFunction
    implements FunctionP {
        private static final long serialVersionUID = -861700500356609563L;
        private final ClassificationDataSet D;
        private final LossMC loss;

        public LossMCFunction(ClassificationDataSet D, LossMC loss) {
            this.D = D;
            this.loss = loss;
        }

        @Override
        public double f(Vec w) {
            double sum = 0.0;
            DenseVector pred = new DenseVector(this.D.getClassSize());
            int subWSize = (w.length() - (LinearBatch.this.useBiasTerm ? LinearBatch.this.bs.length : 0)) / this.D.getClassSize();
            double weightSum = 0.0;
            for (int i = 0; i < this.D.getSampleSize(); ++i) {
                DataPoint dp = this.D.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                for (int k = 0; k < ((Vec)pred).length(); ++k) {
                    ((Vec)pred).set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                }
                if (LinearBatch.this.useBiasTerm) {
                    pred.mutableAdd(new SubVector(w.length() - LinearBatch.this.bs.length, LinearBatch.this.bs.length, w));
                }
                this.loss.process(pred, pred);
                int y = this.D.getDataPointCategory(i);
                sum += this.loss.getLoss(pred, y) * dp.getWeight();
                weightSum += dp.getWeight();
            }
            if (LinearBatch.this.lambda0 > 0.0) {
                return sum / weightSum + LinearBatch.this.lambda0 * w.dot(w);
            }
            return sum;
        }

        @Override
        public double f(final Vec w, ExecutorService ex) {
            final int N = this.D.getSampleSize();
            final int P = SystemInfo.LogicalCores;
            final int subWSize = (w.length() - (LinearBatch.this.useBiasTerm ? LinearBatch.this.bs.length : 0)) / this.D.getClassSize();
            ArrayList partialSums = new ArrayList(P);
            final double[] weightSums = new double[P];
            int p = 0;
            while (p < SystemInfo.LogicalCores) {
                final int ID = p++;
                partialSums.add(ex.submit(new Callable<Double>(){

                    @Override
                    public Double call() throws Exception {
                        double sum = 0.0;
                        DenseVector pred = new DenseVector(LossMCFunction.this.D.getClassSize());
                        double weightSum = 0.0;
                        for (int i = ParallelUtils.getStartBlock(N, ID, P); i < ParallelUtils.getEndBlock(N, ID, P); ++i) {
                            DataPoint dp = LossMCFunction.this.D.getDataPoint(i);
                            Vec x = dp.getNumericalValues();
                            for (int k = 0; k < ((Vec)pred).length(); ++k) {
                                ((Vec)pred).set(k, new SubVector(k * subWSize, subWSize, w).dot(x));
                            }
                            if (LinearBatch.this.useBiasTerm) {
                                pred.mutableAdd(new SubVector(w.length() - LinearBatch.this.bs.length, LinearBatch.this.bs.length, w));
                            }
                            LossMCFunction.this.loss.process(pred, pred);
                            int y = LossMCFunction.this.D.getDataPointCategory(i);
                            sum += LossMCFunction.this.loss.getLoss(pred, y) * dp.getWeight();
                            weightSum += dp.getWeight();
                        }
                        weightSums[ID] = weightSum;
                        return sum;
                    }
                }));
            }
            double sum = 0.0;
            try {
                for (Double partial : ListUtils.collectFutures(partialSums)) {
                    sum += partial.doubleValue();
                }
            }
            catch (ExecutionException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double weightSum = 0.0;
            for (double ws : weightSums) {
                weightSum += ws;
            }
            return sum / weightSum + LinearBatch.this.lambda0 * w.dot(w);
        }

        @Override
        public double f(double ... x) {
            return this.f(DenseVector.toDenseVec(x));
        }
    }

    public class GradFunction
    implements FunctionVec {
        private final DataSet D;
        private final LossFunc loss;
        private ThreadLocal<Vec> tempVecs;

        public GradFunction(DataSet D, LossFunc loss) {
            this.D = D;
            this.loss = loss;
        }

        @Override
        public Vec f(double ... x) {
            return this.f(DenseVector.toDenseVec(x));
        }

        @Override
        public Vec f(Vec w) {
            Vec s = w.clone();
            this.f(w, s);
            return s;
        }

        @Override
        public Vec f(Vec w, Vec s) {
            if (s == null) {
                s = w.clone();
            }
            s.zeroOut();
            double weightSum = 0.0;
            for (int i = 0; i < this.D.getSampleSize(); ++i) {
                DataPoint dp = this.D.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                double y = LinearBatch.getTargetY(this.D, i);
                s.mutableAdd(this.loss.getDeriv(w.dot(x), y) * dp.getWeight(), x);
                weightSum += dp.getWeight();
            }
            s.mutableDivide(weightSum);
            if (LinearBatch.this.lambda0 > 0.0) {
                s.mutableSubtract(LinearBatch.this.lambda0, w);
            }
            return s;
        }

        @Override
        public Vec f(final Vec w, Vec s, ExecutorService ex) {
            if (s == null) {
                s = w.clone();
            }
            s.zeroOut();
            if (this.tempVecs == null) {
                this.tempVecs = new ThreadLocal<Vec>(){

                    @Override
                    protected Vec initialValue() {
                        return w.clone();
                    }
                };
            }
            final Vec store = s;
            final int N = this.D.getSampleSize();
            final int P = SystemInfo.LogicalCores;
            final CountDownLatch latch = new CountDownLatch(P);
            final double[] weightSums = new double[P];
            int p = 0;
            while (p < SystemInfo.LogicalCores) {
                final int ID = p++;
                ex.submit(new Runnable(){

                    /*
                     * WARNING - Removed try catching itself - possible behaviour change.
                     */
                    @Override
                    public void run() {
                        Vec temp = (Vec)GradFunction.this.tempVecs.get();
                        temp.zeroOut();
                        double weightSum = 0.0;
                        for (int i = ParallelUtils.getStartBlock(N, ID, P); i < ParallelUtils.getEndBlock(N, ID, P); ++i) {
                            DataPoint dp = GradFunction.this.D.getDataPoint(i);
                            Vec x = dp.getNumericalValues();
                            double y = LinearBatch.getTargetY(GradFunction.this.D, i);
                            temp.mutableAdd(GradFunction.this.loss.getDeriv(w.dot(x), y) * dp.getWeight(), x);
                            weightSum += dp.getWeight();
                        }
                        Vec vec = store;
                        synchronized (vec) {
                            store.mutableAdd(temp);
                        }
                        weightSums[ID] = weightSum;
                        latch.countDown();
                    }
                });
            }
            try {
                latch.await();
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double weightSum = 0.0;
            for (double ws : weightSums) {
                weightSum += ws;
            }
            s.mutableDivide(weightSum);
            if (LinearBatch.this.lambda0 > 0.0) {
                s.mutableSubtract(LinearBatch.this.lambda0, w);
            }
            return s;
        }
    }

    public class LossFunction
    implements FunctionP {
        private static final long serialVersionUID = -576682206943283356L;
        private final DataSet D;
        private final LossFunc loss;

        public LossFunction(DataSet D, LossFunc loss) {
            this.D = D;
            this.loss = loss;
        }

        @Override
        public double f(Vec w) {
            double sum = 0.0;
            double weightSum = 0.0;
            for (int i = 0; i < this.D.getSampleSize(); ++i) {
                DataPoint dp = this.D.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                double y = LinearBatch.getTargetY(this.D, i);
                sum += this.loss.getLoss(w.dot(x), y) * dp.getWeight();
                weightSum += dp.getWeight();
            }
            if (LinearBatch.this.lambda0 > 0.0) {
                return sum / weightSum + LinearBatch.this.lambda0 * w.dot(w);
            }
            return sum / weightSum;
        }

        @Override
        public double f(final Vec w, ExecutorService ex) {
            final int N = this.D.getSampleSize();
            final int P = SystemInfo.LogicalCores;
            final double[] weightSums = new double[P];
            ArrayList partialSums = new ArrayList(P);
            int p = 0;
            while (p < SystemInfo.LogicalCores) {
                final int ID = p++;
                partialSums.add(ex.submit(new Callable<Double>(){

                    @Override
                    public Double call() throws Exception {
                        double sum = 0.0;
                        double weightSum = 0.0;
                        for (int i = ParallelUtils.getStartBlock(N, ID, P); i < ParallelUtils.getEndBlock(N, ID, P); ++i) {
                            DataPoint dp = LossFunction.this.D.getDataPoint(i);
                            Vec x = dp.getNumericalValues();
                            double y = LinearBatch.getTargetY(LossFunction.this.D, i);
                            sum += LossFunction.this.loss.getLoss(w.dot(x), y) * dp.getWeight();
                            weightSum += dp.getWeight();
                        }
                        weightSums[ID] = weightSum;
                        return sum;
                    }
                }));
            }
            double sum = 0.0;
            try {
                for (Double partial : ListUtils.collectFutures(partialSums)) {
                    sum += partial.doubleValue();
                }
            }
            catch (ExecutionException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(LinearBatch.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double weightSum = 0.0;
            for (double ws : weightSums) {
                weightSum += ws;
            }
            if (LinearBatch.this.lambda0 > 0.0) {
                return sum / weightSum + LinearBatch.this.lambda0 * w.dot(w);
            }
            return sum / weightSum;
        }

        @Override
        public double f(double ... x) {
            return this.f(DenseVector.toDenseVec(x));
        }
    }

    private class VecWithBias
    extends Vec {
        public Vec w;
        public double[] b;

        public VecWithBias(Vec w, double[] b) {
            this.w = w;
            this.b = b;
        }

        @Override
        public double dot(Vec v) {
            if (v.length() == this.w.length()) {
                return this.w.dot(v) + this.b[0];
            }
            return super.dot(v);
        }

        @Override
        public void mutableAdd(double c, Vec b) {
            if (b.length() == this.w.length()) {
                this.w.mutableAdd(c, b);
                this.b[0] = this.b[0] + c;
            } else {
                super.mutableAdd(c, b);
            }
        }

        @Override
        public int length() {
            return this.w.length() + 1;
        }

        @Override
        public double get(int index) {
            if (index < this.w.length()) {
                return this.w.get(index);
            }
            if (index == this.w.length()) {
                return this.b[0];
            }
            throw new IndexOutOfBoundsException();
        }

        @Override
        public void set(int index, double val) {
            if (index < this.w.length()) {
                this.w.set(index, val);
            } else if (index == this.w.length()) {
                this.b[0] = val;
            } else {
                throw new IndexOutOfBoundsException();
            }
        }

        @Override
        public boolean isSparse() {
            return this.w.isSparse();
        }

        @Override
        public Vec clone() {
            return new VecWithBias(this.w.clone(), Arrays.copyOf(this.b, this.b.length));
        }
    }
}

