/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.PairedReturn;

public class Perceptron
implements BinaryScoreClassifier,
SingleWeightVectorModel {
    private static final long serialVersionUID = -3605237847981632021L;
    private double learningRate;
    private double bias;
    private Vec weights;
    private int iteratinLimit;

    public Perceptron() {
        this(0.1, 400);
    }

    public Perceptron(double learningRate, int iteratinLimit) {
        if (learningRate <= 0.0 || learningRate > 1.0) {
            throw new RuntimeException("Preceptron learning rate must be in the range (0,1]");
        }
        this.learningRate = learningRate;
        this.iteratinLimit = iteratinLimit;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        cr.setProb(this.output(data), 1.0);
        return cr;
    }

    @Override
    public double getScore(DataPoint dp) {
        return this.weights.dot(dp.getNumericalValues()) + this.bias;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        double globalError;
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Preceptron only supports binary calssification");
        }
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Preceptron only supports vector classification");
        }
        List<DataPointPair<Integer>> dataPoints = dataSet.getAsDPPList();
        Collections.shuffle(dataPoints);
        int partions = Runtime.getRuntime().availableProcessors();
        Random r = new Random();
        int numerVars = dataSet.getNumNumericalVars();
        this.weights = new DenseVector(numerVars);
        for (int i = 0; i < this.weights.length(); ++i) {
            this.weights.set(i, r.nextDouble());
        }
        Vec bestWeightsSoFar = null;
        double lowestErrorSoFar = Double.MAX_VALUE;
        int iterations = 0;
        this.bias = 0.0;
        do {
            globalError = 0.0;
            DenseVector sumedErrors = new DenseVector(this.weights.length());
            double biasChange = 0.0;
            ArrayList<Future<PairedReturn<Vec, Double[]>>> futures = new ArrayList<Future<PairedReturn<Vec, Double[]>>>(partions);
            int blockSize = dataPoints.size() / partions;
            for (int i = 0; i < partions; ++i) {
                void var19_21;
                if (i == partions - 1) {
                    List<DataPointPair<Integer>> list = dataPoints.subList(i * blockSize, dataPoints.size());
                } else {
                    List<DataPointPair<Integer>> list = dataPoints.subList(i * blockSize, (i + 1) * blockSize);
                }
                futures.add(threadPool.submit(new BatchTrainingUnit((List<DataPointPair<Integer>>)var19_21)));
            }
            for (Future future : futures) {
                try {
                    PairedReturn partialResult = (PairedReturn)future.get();
                    sumedErrors.mutableAdd((Vec)partialResult.getFirstItem());
                    biasChange += ((Double[])partialResult.getSecondItem())[0].doubleValue();
                    globalError += ((Double[])partialResult.getSecondItem())[1].doubleValue();
                }
                catch (InterruptedException ex) {
                }
                catch (ExecutionException ex) {}
            }
            if (globalError < lowestErrorSoFar) {
                bestWeightsSoFar = this.weights;
                lowestErrorSoFar = globalError;
            }
            this.bias += biasChange;
            this.weights.mutableAdd(sumedErrors);
        } while (globalError > 0.0 && ++iterations < this.iteratinLimit);
        this.weights = bestWeightsSoFar;
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainCOnline(dataSet);
    }

    public void trainCOnline(ClassificationDataSet dataSet) {
        double globalError;
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("Preceptron only supports binary calssification");
        }
        if (dataSet.getNumCategoricalVars() != 0) {
            throw new FailedToFitException("Preceptron only supports vector classification");
        }
        List<DataPointPair<Integer>> dataPoints = dataSet.getAsDPPList();
        Collections.shuffle(dataPoints);
        Random r = new Random();
        int numerVars = dataSet.getNumNumericalVars();
        this.weights = new DenseVector(numerVars);
        for (int i = 0; i < this.weights.length(); ++i) {
            this.weights.set(i, r.nextDouble());
        }
        Vec bestWeightsSoFar = null;
        double lowestErrorSoFar = Double.MAX_VALUE;
        int iterations = 0;
        do {
            globalError = 0.0;
            for (DataPointPair<Integer> dpp : dataPoints) {
                int output = this.output(dpp.getDataPoint());
                double localError = dpp.getPair() - output;
                if (localError == 0.0) continue;
                double extraWeight = dpp.getDataPoint().getWeight();
                double magnitude = this.learningRate * localError * extraWeight;
                this.weights.mutableAdd(magnitude, dpp.getVector());
                this.bias += magnitude;
                globalError += Math.abs(localError) * extraWeight;
            }
            if (!(globalError < lowestErrorSoFar)) continue;
            bestWeightsSoFar = this.weights;
            lowestErrorSoFar = globalError;
        } while (globalError > 0.0 && ++iterations < this.iteratinLimit);
        this.weights = bestWeightsSoFar;
    }

    private int output(DataPoint input) {
        double dot = this.getScore(input);
        return dot >= 0.0 ? 1 : 0;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    @Override
    public Vec getRawWeight() {
        return this.weights;
    }

    @Override
    public double getBias() {
        return this.bias;
    }

    @Override
    public Vec getRawWeight(int index) {
        if (index < 1) {
            return this.getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public double getBias(int index) {
        if (index < 1) {
            return this.getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override
    public int numWeightsVecs() {
        return 1;
    }

    @Override
    public Perceptron clone() {
        Perceptron copy = new Perceptron(this.learningRate, this.iteratinLimit);
        if (this.weights != null) {
            copy.weights = this.weights.clone();
        }
        copy.bias = this.bias;
        return copy;
    }

    private class BatchTrainingUnit
    implements Callable<PairedReturn<Vec, Double[]>> {
        private Vec tmpSummedErrors;
        private double biasChange;
        private double globalError;
        List<DataPointPair<Integer>> dataPoints;

        public BatchTrainingUnit(List<DataPointPair<Integer>> toOperateOn) {
            this.tmpSummedErrors = new DenseVector(Perceptron.this.weights.length());
            this.dataPoints = toOperateOn;
            this.globalError = 0.0;
            this.biasChange = 0.0;
        }

        @Override
        public PairedReturn<Vec, Double[]> call() throws Exception {
            for (DataPointPair<Integer> dpp : this.dataPoints) {
                int output = Perceptron.this.output(dpp.getDataPoint());
                double localError = dpp.getPair() - output;
                if (localError == 0.0) continue;
                double extraWeight = dpp.getDataPoint().getWeight();
                double magnitude = Perceptron.this.learningRate * localError * extraWeight;
                this.tmpSummedErrors.mutableAdd(magnitude, dpp.getVector());
                this.biasChange += magnitude;
                this.globalError += Math.abs(localError) * extraWeight;
            }
            return new PairedReturn<Vec, Double[]>(this.tmpSummedErrors, new Double[]{this.biasChange, this.globalError});
        }
    }
}

