/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.trees.ExtraTree;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.exceptions.FailedToFitException;
import jsat.regression.RegressionDataSet;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;

/**
 * An ensemble of {@link ExtraTree} models usable for both classification and
 * regression.
 * <p>
 * Every member of the forest is a clone of an internal base tree. By default
 * the base tree's selection count and stop size are chosen at training time
 * (see {@link #setUseDefaultSelectionCount(boolean)} and
 * {@link #setUseDefaultStopSize(boolean)}). Classification is a majority vote
 * of the members; regression is the mean of the members' outputs. Training
 * splits the forest into contiguous chunks that are trained as separate tasks
 * on the supplied {@link ExecutorService}.
 */
public class ERTrees extends ExtraTree {
    private static final long serialVersionUID = 7139392253403373132L;

    /** Template that every tree in the forest is cloned from before it is trained. */
    private ExtraTree baseTree = new ExtraTree();
    /** When true, the base tree's selection count is derived from the data at training time. */
    private boolean useDefaultSelectionCount = true;
    /** When true, the base tree's stop size is set to a fixed default at training time. */
    private boolean useDefaultStopSize = true;
    /** Target variable of the classification data set; {@code null} until {@code trainC} has run. */
    private CategoricalData predicting;
    /** The trained trees; {@code null} until one of the training methods has run. */
    private ExtraTree[] forrest;
    /** Number of trees to build; always at least 1 when set through the public API. */
    private int forrestSize;

    /**
     * Creates a forest of 100 trees.
     */
    public ERTrees() {
        this(100);
    }

    /**
     * Creates a forest with the given number of trees.
     *
     * @param forrestSize the number of trees to build, must be at least 1
     * @throws IllegalArgumentException if {@code forrestSize} is less than 1
     */
    public ERTrees(int forrestSize) {
        this.forrestSize = checkForrestSize(forrestSize);
    }

    /**
     * Validates a forest size.
     *
     * @param forrestSize the candidate size
     * @return {@code forrestSize} unchanged
     * @throws IllegalArgumentException if {@code forrestSize} is less than 1
     */
    private static int checkForrestSize(int forrestSize) {
        if (forrestSize < 1) {
            throw new IllegalArgumentException("Forrest size must be at least 1, not " + forrestSize);
        }
        return forrestSize;
    }

    /**
     * Sets whether the base tree's selection count is chosen automatically at
     * training time: {@code max(round(sqrt(d)), 1)} for classification and
     * {@code d} for regression, where {@code d} is the number of features.
     *
     * @param useDefaultSelectionCount {@code true} to pick the value automatically
     */
    public void setUseDefaultSelectionCount(boolean useDefaultSelectionCount) {
        this.useDefaultSelectionCount = useDefaultSelectionCount;
    }

    /**
     * @return whether the base tree's selection count is chosen automatically
     */
    public boolean getUseDefaultSelectionCount() {
        return this.useDefaultSelectionCount;
    }

    /**
     * Sets whether the base tree's stop size is set automatically at training
     * time: 2 for classification and 5 for regression.
     *
     * @param useDefaultStopSize {@code true} to pick the value automatically
     */
    public void setUseDefaultStopSize(boolean useDefaultStopSize) {
        this.useDefaultStopSize = useDefaultStopSize;
    }

    /**
     * @return whether the base tree's stop size is set automatically
     */
    public boolean getUseDefaultStopSize() {
        return this.useDefaultStopSize;
    }

    /**
     * Sets the number of trees built by the next training call.
     *
     * @param forrestSize the number of trees, must be at least 1
     * @throws IllegalArgumentException if {@code forrestSize} is less than 1
     */
    public void setForrestSize(int forrestSize) {
        this.forrestSize = checkForrestSize(forrestSize);
    }

    /**
     * @return the number of trees built by a training call
     */
    public int getForrestSize() {
        return this.forrestSize;
    }

    /**
     * Classifies by majority vote: each tree's most likely class receives one
     * vote, and the vote counts are normalized into the returned result.
     *
     * @param data the point to classify
     * @return the normalized vote distribution over the categories
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (ExtraTree tree : this.forrest) {
            cr.incProb(tree.classify(data).mostLikely(), 1.0);
        }
        cr.normalize();
        return cr;
    }

    /**
     * Builds a new forest of {@link #forrestSize} trees on {@code dataSet}.
     * <p>
     * The forest is split into at most {@code SystemInfo.LogicalCores}
     * contiguous chunks, and each chunk is trained by one task on
     * {@code threadPool}. This method blocks until every task has finished.
     *
     * @param threadPool executor that runs the training tasks
     * @param dataSet a {@link ClassificationDataSet} or {@link RegressionDataSet}
     * @throws FailedToFitException if the wait is interrupted or any task fails
     */
    private void doTraining(ExecutorService threadPool, DataSet dataSet) throws FailedToFitException {
        this.forrest = new ExtraTree[this.forrestSize];
        // One task per core, but never more tasks than trees (and at least one),
        // so the latch count always equals the number of tasks actually submitted.
        int tasks = Math.max(1, Math.min(SystemInfo.LogicalCores, this.forrestSize));
        int chunkSize = this.forrestSize / tasks;
        int extra = this.forrestSize % tasks;
        CountDownLatch latch = new CountDownLatch(tasks);
        AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        int start = 0;
        for (int t = 0; t < tasks; ++t) {
            // The first 'extra' chunks take one more tree so the chunks tile the forest exactly.
            int end = start + chunkSize + (t < extra ? 1 : 0);
            threadPool.submit(new ForrestPlanter(start, end, dataSet, latch, failure));
            start = end;
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        // The latch's happens-before edge makes the workers' writes (trees and failure) visible here.
        Throwable cause = failure.get();
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        if (cause != null) {
            // ForrestPlanter only records RuntimeExceptions and Errors.
            throw new FailedToFitException((RuntimeException) cause);
        }
    }

    /**
     * Trains the forest for classification. When defaults are enabled, the
     * selection count becomes {@code max(round(sqrt(d)), 1)} and the stop size 2.
     *
     * @param dataSet the training data
     * @param threadPool executor that runs the training tasks
     * @throws FailedToFitException if training fails or is interrupted
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (this.useDefaultSelectionCount) {
            // round(sqrt(d)), but never fewer than one
            this.baseTree.setSelectionCount((int) Math.max(Math.round(Math.sqrt(dataSet.getNumFeatures())), 1L));
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(2);
        }
        this.predicting = dataSet.getPredicting();
        this.doTraining(threadPool, dataSet);
    }

    /**
     * Trains the forest for classification on the calling thread.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if training fails
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Predicts the mean of the regression outputs of all trees.
     *
     * @param data the point to predict
     * @return the average prediction of the forest
     */
    @Override
    public double regress(DataPoint data) {
        double mean = 0.0;
        for (ExtraTree tree : this.forrest) {
            mean += tree.regress(data);
        }
        return mean / (double) this.forrest.length;
    }

    /**
     * Trains the forest for regression. When defaults are enabled, the
     * selection count becomes the number of features and the stop size 5.
     *
     * @param dataSet the training data
     * @param threadPool executor that runs the training tasks
     * @throws FailedToFitException if training fails or is interrupted
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        if (this.useDefaultSelectionCount) {
            this.baseTree.setSelectionCount(dataSet.getNumFeatures());
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(5);
        }
        this.doTraining(threadPool, dataSet);
    }

    /**
     * Trains the forest for regression on the calling thread.
     *
     * @param dataSet the training data
     * @throws FailedToFitException if training fails
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    /**
     * @return a deep copy of this forest, including its settings, base tree and trained trees
     */
    @Override
    public ERTrees clone() {
        ERTrees clone = new ERTrees();
        clone.forrestSize = this.forrestSize;
        clone.useDefaultSelectionCount = this.useDefaultSelectionCount;
        clone.useDefaultStopSize = this.useDefaultStopSize;
        clone.baseTree = this.baseTree.clone();
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        if (this.forrest != null) {
            clone.forrest = new ExtraTree[this.forrest.length];
            for (int i = 0; i < this.forrest.length; ++i) {
                clone.forrest[i] = this.forrest[i].clone();
            }
        }
        return clone;
    }

    /**
     * Not supported: an ensemble has no single tree to visit.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TreeNodeVisitor getTreeNodeVisitor() {
        throw new UnsupportedOperationException("Can not get the tree node vistor becase ERTrees is really a ensemble");
    }

    /**
     * Trains the trees in the half-open index range {@code [start, end)} of
     * {@link #forrest}, each one a fresh clone of {@link #baseTree}.
     * <p>
     * The latch is always counted down, even on failure, so the thread waiting
     * in {@link #doTraining} cannot block forever. The first failure of any
     * planter is recorded in {@code failure} so the waiting thread can rethrow it.
     */
    private class ForrestPlanter implements Runnable {
        private final int start;
        private final int end;
        private final DataSet dataSet;
        private final CountDownLatch latch;
        private final AtomicReference<Throwable> failure;

        ForrestPlanter(int start, int end, DataSet dataSet, CountDownLatch latch, AtomicReference<Throwable> failure) {
            this.start = start;
            this.end = end;
            this.dataSet = dataSet;
            this.latch = latch;
            this.failure = failure;
        }

        @Override
        public void run() {
            try {
                this.plant();
            }
            catch (RuntimeException ex) {
                this.failure.compareAndSet(null, ex);
            }
            catch (Error err) {
                this.failure.compareAndSet(null, err);
            }
            finally {
                this.latch.countDown();
            }
        }

        /** Clones and trains every tree in this planter's range. */
        private void plant() {
            if (this.dataSet instanceof ClassificationDataSet) {
                ClassificationDataSet cds = (ClassificationDataSet) this.dataSet;
                for (int i = this.start; i < this.end; ++i) {
                    ExtraTree tree = ERTrees.this.baseTree.clone();
                    tree.trainC(cds);
                    ERTrees.this.forrest[i] = tree;
                }
            } else if (this.dataSet instanceof RegressionDataSet) {
                RegressionDataSet rds = (RegressionDataSet) this.dataSet;
                for (int i = this.start; i < this.end; ++i) {
                    ExtraTree tree = ERTrees.this.baseTree.clone();
                    tree.train(rds);
                    ERTrees.this.forrest[i] = tree;
                }
            } else {
                throw new RuntimeException("BUG: Please report");
            }
        }
    }
}

