/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.boosting.Bagging;
import jsat.classifiers.trees.DecisionStump;
import jsat.classifiers.trees.DecisionTree;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.RandomDecisionTree;
import jsat.classifiers.trees.TreePruner;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;

/**
 * An ensemble of {@link RandomDecisionTree}s usable for both classification and regression.
 * <p>
 * Each tree is a clone of a shared base learner. It is trained on a bootstrap-weighted copy of
 * the data set (see {@link #setExtraSamples(int)}) and given a randomly chosen subset of the
 * features (see {@link #setFeatureSamples(int)}). Classification is a majority vote over each
 * tree's most likely class. Regression is the mean of the trees' predictions.
 * <p>
 * Optionally, an out-of-bag error estimate is computed while training; see
 * {@link #setUseOutOfBagError(boolean)}.
 */
public class RandomForest implements Classifier, Regressor, Parameterized {
    private static final long serialVersionUID = 2725020584282958141L;
    /** Target variable of the last classification training run; {@code null} after regression training. */
    private CategoricalData predicting;
    /** Number of additional draws, beyond the data set size, used for each bootstrap sample. */
    private int extraSamples;
    /** Features handed to each tree, or {@code -1} to choose sqrt(#features) automatically. */
    private int featureSamples;
    private int maxForestSize;
    private boolean useOutOfBagError = false;
    private double outOfBagError;
    /** Template tree that is cloned for every member of the forest. */
    private RandomDecisionTree baseLearner;
    private List<DecisionTree> forest;

    /**
     * Creates a forest of 100 trees.
     */
    public RandomForest() {
        this(100);
    }

    /**
     * Creates a forest with the given number of trees.
     * <p>
     * The forest uses no extra samples and automatic feature sampling. The base tree uses no
     * pruning, {@code BINARY_BEST_GAIN} numeric handling and the Gini impurity measure.
     *
     * @param maxForestSize the number of trees to train; must be positive
     * @throws ArithmeticException if {@code maxForestSize <= 0}
     */
    public RandomForest(int maxForestSize) {
        this.setExtraSamples(0);
        this.setMaxForestSize(maxForestSize);
        this.autoFeatureSample();
        this.baseLearner = new RandomDecisionTree(1, Integer.MAX_VALUE, 3, TreePruner.PruningMethod.NONE, 1.0E-15);
        this.baseLearner.setNumericHandling(DecisionStump.NumericHandlingC.BINARY_BEST_GAIN);
        this.baseLearner.setGainMethod(ImpurityScore.ImpurityMeasure.GINI);
    }

    /**
     * Sets how many extra samples are drawn for each tree's bootstrap sample.
     * <p>
     * Each tree's bootstrap sample draws (data set size + {@code i}) points with replacement.
     *
     * @param i the number of extra draws
     */
    public void setExtraSamples(int i) {
        this.extraSamples = i;
    }

    /**
     * @return the number of extra draws added to each bootstrap sample
     */
    public int getExtraSamples() {
        return this.extraSamples;
    }

    /**
     * Sets a fixed number of randomly selected features to give each tree.
     * <p>
     * This disables automatic feature sampling. The count is capped at the number of features
     * in the data set during training.
     *
     * @param featureSamples the number of features per tree; must be positive
     * @throws ArithmeticException if {@code featureSamples <= 0}
     */
    public void setFeatureSamples(int featureSamples) {
        if (featureSamples <= 0) {
            throw new ArithmeticException("A positive number of features must be given");
        }
        this.featureSamples = featureSamples;
    }

    /**
     * Enables automatic feature sampling.
     * <p>
     * When enabled, each tree receives max(1, floor(sqrt(#features))) randomly chosen features.
     */
    public void autoFeatureSample() {
        this.featureSamples = -1;
    }

    /**
     * @return {@code true} if the number of features per tree is chosen automatically
     */
    public boolean isAutoFeatureSample() {
        return this.featureSamples == -1;
    }

    /**
     * Sets the number of trees trained by the next call to a training method.
     *
     * @param maxForestSize the number of trees; must be positive
     * @throws ArithmeticException if {@code maxForestSize <= 0}
     */
    public void setMaxForestSize(int maxForestSize) {
        if (maxForestSize <= 0) {
            throw new ArithmeticException("Must train a positive number of learners");
        }
        this.maxForestSize = maxForestSize;
    }

    /**
     * @return the number of trees the forest trains
     */
    public int getMaxForestSize() {
        return this.maxForestSize;
    }

    /**
     * Controls whether an out-of-bag error estimate is computed during training.
     * <p>
     * Enabling this costs one extra prediction per tree for every point that tree did not
     * sample, plus memory to record those predictions.
     *
     * @param useOutOfBagError {@code true} to compute the estimate
     */
    public void setUseOutOfBagError(boolean useOutOfBagError) {
        this.useOutOfBagError = useOutOfBagError;
    }

    /**
     * @return {@code true} if training computes an out-of-bag error estimate
     */
    public boolean isUseOutOfBagError() {
        return this.useOutOfBagError;
    }

    /**
     * Returns the out-of-bag error of the last training run.
     * <p>
     * Only points that at least one tree did not sample are evaluated. For classification, the
     * value is the fraction of those points whose out-of-bag majority vote is wrong. For
     * regression, it is the mean squared error of their averaged out-of-bag predictions. It is
     * 0 if the estimate was not requested, and NaN if no point was ever out of bag.
     *
     * @return the out-of-bag error estimate
     */
    public double getOutOfBagError() {
        return this.outOfBagError;
    }

    /**
     * Classifies a point by majority vote.
     * <p>
     * Every tree casts one vote for its most likely class. The returned distribution is the
     * normalized vote count.
     *
     * @param data the point to classify
     * @return the normalized vote distribution over the categories
     * @throws RuntimeException if the forest is untrained or was trained for regression
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.forest == null || this.forest.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        if (this.predicting == null) {
            throw new RuntimeException("Classifier has been trained for regression");
        }
        CategoricalResults totalResult = new CategoricalResults(this.predicting.getNumOfCategories());
        for (DecisionTree tree : this.forest) {
            totalResult.incProb(tree.classify(data).mostLikely(), 1.0);
        }
        totalResult.normalize();
        return totalResult;
    }

    /**
     * Trains the forest for classification, replacing any previously trained trees.
     *
     * @param dataSet the training data
     * @param threadPool executor for the per-tree work; {@code null} trains on the calling thread
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.predicting = dataSet.getPredicting();
        this.forest = new ArrayList<DecisionTree>(this.maxForestSize);
        this.trainStep(dataSet, threadPool);
    }

    /**
     * Trains the forest for classification on the calling thread.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Predicts a value as the mean of the individual trees' predictions.
     *
     * @param data the point to predict
     * @return the mean tree prediction
     * @throws RuntimeException if the forest is untrained or was trained for classification
     */
    @Override
    public double regress(DataPoint data) {
        if (this.forest == null || this.forest.isEmpty()) {
            throw new RuntimeException("Classifier has not yet been trained");
        }
        if (this.predicting != null) {
            throw new RuntimeException("Classifier has been trained for classification");
        }
        OnLineStatistics stats = new OnLineStatistics();
        for (DecisionTree tree : this.forest) {
            stats.add(tree.regress(data));
        }
        return stats.getMean();
    }

    /**
     * Trains the forest for regression, replacing any previously trained trees.
     *
     * @param dataSet the training data
     * @param threadPool executor for the per-tree work; {@code null} trains on the calling thread
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        this.predicting = null;
        this.forest = new ArrayList<DecisionTree>(this.maxForestSize);
        this.trainStep(dataSet, threadPool);
    }

    /**
     * Trains the forest for regression on the calling thread.
     *
     * @param dataSet the training data
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    /**
     * Shared training logic for classification and regression.
     * <p>
     * Splits {@code maxForestSize} trees across workers, gathers their trees into
     * {@link #forest}, and optionally computes the out-of-bag error.
     */
    private void trainStep(DataSet dataSet, ExecutorService threadPool) {
        if (this.isAutoFeatureSample()) {
            this.baseLearner.setRandomFeatureCount(Math.max((int)Math.sqrt(dataSet.getNumFeatures()), 1));
        } else {
            this.baseLearner.setRandomFeatureCount(this.featureSamples);
        }
        if (threadPool == null) {
            // The original code tested for null here but then dereferenced it anyway.
            threadPool = new FakeExecutor();
        }

        int roundsToDistribute = this.maxForestSize;
        int roundShare;
        int extraRounds;
        if (threadPool instanceof FakeExecutor) {
            // A single worker builds every tree. extraRounds must be 0 here, or the worker
            // would be asked for maxForestSize + 1 trees.
            roundShare = roundsToDistribute;
            extraRounds = 0;
        } else {
            roundShare = roundsToDistribute / SystemInfo.LogicalCores;
            extraRounds = roundsToDistribute % SystemInfo.LogicalCores;
        }

        // Out-of-bag bookkeeping is only allocated when requested. A null array tells the
        // workers to skip it.
        boolean computeOob = this.useOutOfBagError;
        int[][] oobCounts = null;
        AtomicDoubleArray oobPredictionSums = null;
        if (computeOob) {
            if (dataSet instanceof RegressionDataSet) {
                oobPredictionSums = new AtomicDoubleArray(dataSet.getSampleSize());
                oobCounts = new int[dataSet.getSampleSize()][1];
            } else {
                oobCounts = new int[dataSet.getSampleSize()][((ClassificationDataSet)dataSet).getClassSize()];
            }
        }

        Random rand = new Random();
        List<Future<LearningWorker>> futures = new ArrayList<Future<LearningWorker>>(SystemInfo.LogicalCores);
        while (roundsToDistribute > 0) {
            // Spread the remainder one extra tree per worker until it is used up.
            int toLearn = roundShare + (extraRounds-- > 0 ? 1 : 0);
            futures.add(threadPool.submit(new LearningWorker(dataSet, toLearn, new Random(rand.nextInt()), oobCounts, oobPredictionSums)));
            roundsToDistribute -= toLearn;
        }

        this.outOfBagError = 0.0;
        try {
            // Trees are committed only if every worker succeeds, so a failed run leaves the
            // forest empty rather than silently undersized.
            List<DecisionTree> trained = new ArrayList<DecisionTree>(this.maxForestSize);
            for (Future<LearningWorker> future : futures) {
                trained.addAll(future.get().learned);
            }
            this.forest.addAll(trained);
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            Logger.getLogger(RandomForest.class.getName()).log(Level.SEVERE, null, ex);
        }
        catch (Exception ex) {
            Logger.getLogger(RandomForest.class.getName()).log(Level.SEVERE, null, ex);
        }

        if (computeOob) {
            this.computeOutOfBagError(dataSet, oobCounts, oobPredictionSums);
        }
    }

    /**
     * Turns the workers' out-of-bag tallies into {@link #outOfBagError}.
     * <p>
     * Points that no tree left out have no out-of-bag prediction and are skipped. Without
     * this, regression would divide by zero (NaN) and classification would score them as
     * class 0.
     */
    private void computeOutOfBagError(DataSet dataSet, int[][] oobCounts, AtomicDoubleArray oobPredictionSums) {
        double errorSum = 0.0;
        int evaluated = 0;
        if (dataSet instanceof ClassificationDataSet) {
            ClassificationDataSet cds = (ClassificationDataSet)dataSet;
            for (int i = 0; i < oobCounts.length; ++i) {
                int[] votes = oobCounts[i];
                int max = 0;
                int total = 0;
                for (int j = 0; j < votes.length; ++j) {
                    total += votes[j];
                    if (votes[j] > votes[max]) {
                        max = j;
                    }
                }
                if (total == 0) {
                    continue;
                }
                ++evaluated;
                if (max != cds.getDataPointCategory(i)) {
                    errorSum += 1.0;
                }
            }
        } else {
            RegressionDataSet rds = (RegressionDataSet)dataSet;
            for (int i = 0; i < oobCounts.length; ++i) {
                int oobTrees = oobCounts[i][0];
                if (oobTrees == 0) {
                    continue;
                }
                ++evaluated;
                double diff = oobPredictionSums.get(i) / (double)oobTrees - rds.getTargetValue(i);
                errorSum += diff * diff;
            }
        }
        this.outOfBagError = evaluated > 0 ? errorSum / (double)evaluated : Double.NaN;
    }

    /**
     * Returns a deep copy of this forest.
     * <p>
     * The copy includes its settings, the out-of-bag configuration and result, and copies of
     * all trained trees.
     */
    @Override
    public RandomForest clone() {
        RandomForest clone = new RandomForest(this.maxForestSize);
        clone.extraSamples = this.extraSamples;
        clone.featureSamples = this.featureSamples;
        clone.useOutOfBagError = this.useOutOfBagError;
        clone.outOfBagError = this.outOfBagError;
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.forest != null) {
            clone.forest = new ArrayList<DecisionTree>(this.forest.size());
            for (DecisionTree tree : this.forest) {
                clone.forest.add(tree.clone());
            }
        }
        clone.baseLearner = this.baseLearner.clone();
        return clone;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Trains a batch of {@code toLearn} trees.
     * <p>
     * When {@code oobCounts} is non-null, the worker also records out-of-bag predictions for
     * each tree. For every point the tree did not sample, it increments the vote for the
     * predicted class (classification). For regression, it adds the prediction to
     * {@code oobPredictionSums} and increments the point's count. The arrays are shared
     * between workers, so per-point updates are synchronized on the point's row.
     */
    private class LearningWorker
    implements Callable<LearningWorker> {
        int toLearn;
        List<DecisionTree> learned;
        DataSet dataSet;
        Random random;
        private AtomicDoubleArray oobPredictionSums;
        private int[][] oobCounts;

        public LearningWorker(DataSet dataSet, int toLearn, Random random, int[][] oobCounts, AtomicDoubleArray oobPredictionSums) {
            this.dataSet = dataSet;
            this.toLearn = toLearn;
            this.random = random;
            this.learned = new ArrayList<DecisionTree>(toLearn);
            this.oobCounts = oobCounts;
            this.oobPredictionSums = oobPredictionSums;
        }

        @Override
        public LearningWorker call() throws Exception {
            int numFeatures = this.dataSet.getNumFeatures();
            int featuresPerTree = Math.min(RandomForest.this.baseLearner.getRandomFeatureCount(), numFeatures);
            boolean classification = this.dataSet instanceof ClassificationDataSet;
            IntSet features = new IntSet(RandomForest.this.baseLearner.getRandomFeatureCount());
            int[] sampleCounts = new int[this.dataSet.getSampleSize()];
            for (int i = 0; i < this.toLearn; ++i) {
                // sampleCounts[j] becomes the number of times point j was drawn for this tree.
                Bagging.sampleWithReplacement(sampleCounts, sampleCounts.length + RandomForest.this.extraSamples, this.random);
                features.clear();
                while (features.size() < featuresPerTree) {
                    features.add(this.random.nextInt(numFeatures));
                }
                RandomDecisionTree learner = RandomForest.this.baseLearner.clone();
                if (classification) {
                    learner.trainC(Bagging.getWeightSampledDataSet((ClassificationDataSet)this.dataSet, sampleCounts), features);
                } else {
                    learner.train(Bagging.getWeightSampledDataSet((RegressionDataSet)this.dataSet, sampleCounts), features);
                }
                this.learned.add(learner);
                if (this.oobCounts != null) {
                    this.recordOutOfBag(learner, sampleCounts, classification);
                }
            }
            return this;
        }

        /** Records {@code learner}'s predictions for every point it did not sample. */
        private void recordOutOfBag(RandomDecisionTree learner, int[] sampleCounts, boolean classification) {
            for (int j = 0; j < sampleCounts.length; ++j) {
                if (sampleCounts[j] != 0) {
                    continue;
                }
                DataPoint dp = this.dataSet.getDataPoint(j);
                int[] pointCounts = this.oobCounts[j];
                if (classification) {
                    int predicted = learner.classify(dp).mostLikely();
                    synchronized (pointCounts) {
                        pointCounts[predicted]++;
                    }
                } else {
                    this.oobPredictionSums.getAndAdd(j, learner.regress(dp));
                    synchronized (pointCounts) {
                        pointCounts[0]++;
                    }
                }
            }
        }
    }
}

