/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.DecisionStump;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.classifiers.trees.TreePruner;
import jsat.exceptions.FailedToFitException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.IntSet;
import jsat.utils.ModifiableCountDownLatch;

/**
 * A decision tree whose nodes are {@link DecisionStump}s. Each node clones the configured base
 * stump, trains it on the data points that reach the node, and recursively builds one child per
 * stump path from the data the stump routes down that path. The tree supports both
 * classification ({@link #trainC}) and regression ({@link #train}). After classification
 * training, the tree is handed to {@link TreePruner#prune} together with a set of held-out points
 * chosen according to {@link #getTestProportion()}.
 */
public class DecisionTree
implements Classifier,
Regressor,
Parameterized,
TreeLearner {
    private static final long serialVersionUID = 9220980056440500214L;
    /** Nodes at a depth greater than this are not created. */
    private int maxDepth;
    /** No node is created from fewer than this many data points. */
    private int minSamples;
    /** Root of the learned tree; {@code null} until trained (or if no node could be built). */
    private Node root;
    /** Target variable information of the classification data set last trained on. */
    private CategoricalData predicting;
    /** Pruning applied after classification training. */
    private TreePruner.PruningMethod pruningMethod;
    /** Fraction of the classification data held out and passed to the pruner. */
    private double testProportion;
    /** Template stump; every node trains its own clone of it. */
    private DecisionStump baseStump = new DecisionStump();
    private List<Parameter> params = new ArrayList<Parameter>(Parameter.getParamsFromMethods(this));
    private Map<String, Parameter> paramMap = Parameter.toParameterMap(this.params);

    /**
     * Predicts a numeric value for the given data point.
     *
     * @param data the point to predict
     * @return the value predicted by the root node
     */
    @Override
    public double regress(DataPoint data) {
        return this.root.regress(data);
    }

    /**
     * Trains a regression tree that may split on any feature of the data set.
     *
     * @param dataSet the regression data to learn from
     * @param threadPool the executor used to build sub-trees
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        IntSet options = new IntSet(dataSet.getNumFeatures());
        for (int i = 0; i < dataSet.getNumFeatures(); ++i) {
            options.add(i);
        }
        this.train(dataSet, options, threadPool);
    }

    /**
     * Trains a regression tree in the calling thread, restricted to the given features.
     *
     * @param dataSet the regression data to learn from
     * @param options indices of the features the tree may split on
     */
    public void train(RegressionDataSet dataSet, Set<Integer> options) {
        this.train(dataSet, options, new FakeExecutor());
    }

    /**
     * Trains a regression tree restricted to the given features, and waits until every
     * sub-tree has been built.
     *
     * @param dataSet the regression data to learn from
     * @param options indices of the features the tree may split on
     * @param threadPool the executor used to build sub-trees
     */
    public void train(RegressionDataSet dataSet, Set<Integer> options, ExecutorService threadPool) {
        ModifiableCountDownLatch mcdl = new ModifiableCountDownLatch(1);
        this.root = this.makeNodeR(dataSet.getDPPList(), options, 0, threadPool, mcdl);
        try {
            mcdl.await();
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            Logger.getLogger(DecisionTree.class.getName()).log(Level.SEVERE, null, ex);
        }
    }

    /**
     * Trains a regression tree in the calling thread on all features.
     *
     * @param dataSet the regression data to learn from
     */
    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, new FakeExecutor());
    }

    /**
     * Creates a tree with unlimited depth, a minimum of 10 samples per node,
     * {@link TreePruner.PruningMethod#REDUCED_ERROR} pruning on a 10% hold-out, and
     * {@link DecisionStump.NumericHandlingC#BINARY_BEST_GAIN} numeric handling.
     */
    public DecisionTree() {
        this(Integer.MAX_VALUE, 10, TreePruner.PruningMethod.REDUCED_ERROR, 0.1);
        this.baseStump.setNumericHandling(DecisionStump.NumericHandlingC.BINARY_BEST_GAIN);
    }

    /**
     * Creates an unpruned tree of limited depth, with a minimum of 10 samples per node and
     * {@link DecisionStump.NumericHandlingC#BINARY_BEST_GAIN} numeric handling.
     *
     * @param maxDepth the maximum depth of the tree
     */
    public DecisionTree(int maxDepth) {
        this(maxDepth, 10, TreePruner.PruningMethod.NONE, 1.0E-5);
        this.baseStump.setNumericHandling(DecisionStump.NumericHandlingC.BINARY_BEST_GAIN);
    }

    /**
     * Creates a tree with the given settings.
     *
     * @param maxDepth the maximum depth of the tree; must be non-negative
     * @param minSamples the minimum number of points needed to create a node
     * @param pruningMethod the pruning applied after classification training
     * @param testProportion the fraction of classification data held out for pruning, in [0, 1]
     */
    public DecisionTree(int maxDepth, int minSamples, TreePruner.PruningMethod pruningMethod, double testProportion) {
        this.setMaxDepth(maxDepth);
        this.setMinSamples(minSamples);
        this.setPruningMethod(pruningMethod);
        this.setTestProportion(testProportion);
    }

    /**
     * Copy constructor: deep-copies the learned tree, the target information, and the base
     * stump.
     *
     * @param toCopy the tree to copy
     */
    protected DecisionTree(DecisionTree toCopy) {
        this.maxDepth = toCopy.maxDepth;
        this.minSamples = toCopy.minSamples;
        if (toCopy.root != null) {
            this.root = toCopy.root.clone();
        }
        if (toCopy.predicting != null) {
            this.predicting = toCopy.predicting.clone();
        }
        this.pruningMethod = toCopy.pruningMethod;
        this.testProportion = toCopy.testProportion;
        this.baseStump = toCopy.baseStump.clone();
    }

    /**
     * Returns a tree configured in a C4.5-like way: information gain ratio splits, binary numeric
     * splits, error-based pruning evaluated on the full training set, a minimum of 3 samples per
     * node, and a minimum result split size of 2.
     *
     * @return a new, untrained C4.5-style tree
     */
    public static DecisionTree getC45Tree() {
        DecisionTree tree = new DecisionTree();
        tree.setMinResultSplitSize(2);
        tree.setMinSamples(3);
        tree.setTestProportion(1.0);
        tree.setPruningMethod(TreePruner.PruningMethod.ERROR_BASED);
        tree.baseStump.setGainMethod(ImpurityScore.ImpurityMeasure.INFORMATION_GAIN_RATIO);
        tree.baseStump.setNumericHandling(DecisionStump.NumericHandlingC.BINARY_BEST_GAIN);
        return tree;
    }

    /** @param handling how the node stumps handle numeric features */
    public void setNumericHandling(DecisionStump.NumericHandlingC handling) {
        this.baseStump.setNumericHandling(handling);
    }

    /** @return how the node stumps handle numeric features */
    public DecisionStump.NumericHandlingC getNumericHandling() {
        return this.baseStump.getNumericHandling();
    }

    /** @param gainMethod the impurity measure the node stumps use to choose splits */
    public void setGainMethod(ImpurityScore.ImpurityMeasure gainMethod) {
        this.baseStump.setGainMethod(gainMethod);
    }

    /** @return the impurity measure the node stumps use to choose splits */
    public ImpurityScore.ImpurityMeasure getGainMethod() {
        return this.baseStump.getGainMethod();
    }

    /** @param size the minimum result split size forwarded to the node stumps */
    public void setMinResultSplitSize(int size) {
        this.baseStump.setMinResultSplitSize(size);
    }

    /** @return the minimum result split size used by the node stumps */
    public int getMinResultSplitSize() {
        return this.baseStump.getMinResultSplitSize();
    }

    /**
     * Sets the maximum depth of the tree. Nodes at a depth greater than this are not created.
     *
     * @param maxDepth the maximum depth; 0 is allowed
     * @throws IllegalArgumentException if {@code maxDepth} is negative
     */
    public void setMaxDepth(int maxDepth) {
        if (maxDepth < 0) {
            throw new IllegalArgumentException("The maximum depth must be a non-negative number, not " + maxDepth);
        }
        this.maxDepth = maxDepth;
    }

    /** @return the maximum depth of the tree */
    public int getMaxDepth() {
        return this.maxDepth;
    }

    /** @param minSamples the minimum number of data points required to create a node */
    public void setMinSamples(int minSamples) {
        this.minSamples = minSamples;
    }

    /** @return the minimum number of data points required to create a node */
    public int getMinSamples() {
        return this.minSamples;
    }

    /** @param pruningMethod the pruning applied after classification training */
    public void setPruningMethod(TreePruner.PruningMethod pruningMethod) {
        this.pruningMethod = pruningMethod;
    }

    /** @return the pruning applied after classification training */
    public TreePruner.PruningMethod getPruningMethod() {
        return this.pruningMethod;
    }

    /** @return the fraction of classification data held out for pruning */
    public double getTestProportion() {
        return this.testProportion;
    }

    /**
     * Sets the fraction of classification data held out for pruning. A value of 1.0 means all
     * training points are also used as the pruning set.
     *
     * @param testProportion a finite value in [0, 1]
     * @throws ArithmeticException if the value is outside [0, 1] or is NaN/infinite
     */
    public void setTestProportion(double testProportion) {
        if (testProportion < 0.0 || testProportion > 1.0 || Double.isInfinite(testProportion) || Double.isNaN(testProportion)) {
            throw new ArithmeticException("Proportion must be in the range [0, 1], not " + testProportion);
        }
        this.testProportion = testProportion;
    }

    /**
     * Classifies the given data point.
     *
     * @param data the point to classify
     * @return the result produced by the root node
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.root.classify(data);
    }

    /**
     * Trains a classification tree that may split on any feature of the data set.
     *
     * @param dataSet the classification data to learn from
     * @param threadPool the executor used to build sub-trees
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        IntSet options = new IntSet(dataSet.getNumFeatures());
        for (int i = 0; i < dataSet.getNumFeatures(); ++i) {
            options.add(i);
        }
        this.trainC(dataSet, options, threadPool);
    }

    /**
     * Trains a classification tree restricted to the given features, then prunes it.
     *
     * <p>If pruning is enabled and the test proportion is non-zero, a test set is chosen for the
     * pruner. With a proportion below 1, that fraction of points is removed from the training
     * data using a {@link Random} seeded with the test-set size, so the choice is deterministic
     * for a given data set. With a proportion of exactly 1, every point is both trained on and
     * used for pruning.
     *
     * @param dataSet the classification data to learn from
     * @param options indices of the features the tree may split on
     * @param threadPool the executor used to build sub-trees
     * @throws FailedToFitException if the data set has fewer than {@link #getMinSamples()} points
     */
    protected void trainC(ClassificationDataSet dataSet, Set<Integer> options, ExecutorService threadPool) {
        if (dataSet.getSampleSize() < this.minSamples) {
            throw new FailedToFitException("There are only " + dataSet.getSampleSize() + " data points in the sample set, at least " + this.minSamples + " are needed to make a tree");
        }
        this.predicting = dataSet.getPredicting();
        ModifiableCountDownLatch mcdl = new ModifiableCountDownLatch(1);
        List<DataPointPair<Integer>> dataPoints = dataSet.getAsDPPList();
        ArrayList<DataPointPair<Integer>> testPoints = new ArrayList<DataPointPair<Integer>>();
        if (this.pruningMethod != TreePruner.PruningMethod.NONE && this.testProportion != 0.0) {
            if (this.testProportion != 1.0) {
                int testSize = (int)((double)dataPoints.size() * this.testProportion);
                Random rand = new Random(testSize);
                for (int i = 0; i < testSize; ++i) {
                    testPoints.add(dataPoints.remove(rand.nextInt(dataPoints.size())));
                }
            } else {
                testPoints.addAll(dataPoints);
            }
        }
        this.root = this.makeNodeC(dataPoints, options, 0, threadPool, mcdl);
        try {
            mcdl.await();
        }
        catch (InterruptedException ex) {
            // Restore the interrupt status so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            Logger.getLogger(DecisionTree.class.getName()).log(Level.SEVERE, null, ex);
        }
        TreePruner.prune((TreeNodeVisitor)this.root, this.pruningMethod, testPoints);
    }

    /**
     * Builds the classification sub-tree for the given points. Child sub-trees are submitted to
     * {@code threadPool}.
     *
     * <p>Latch contract: this method releases exactly one count of {@code mcdl} for itself.
     * Every count it adds for children is released once that child has been built and stored in
     * the returned node.
     *
     * @param dataPoints the points reaching this node
     * @param options indices of the features that may be split on
     * @param depth the depth of this node (root is 0)
     * @param threadPool the executor used to build child sub-trees
     * @param mcdl the latch tracking outstanding work
     * @return the new node, or {@code null} if a stopping condition holds
     */
    protected Node makeNodeC(List<DataPointPair<Integer>> dataPoints, final Set<Integer> options, final int depth, final ExecutorService threadPool, final ModifiableCountDownLatch mcdl) {
        if (depth > this.maxDepth || options.isEmpty() || dataPoints.size() < this.minSamples || dataPoints.isEmpty()) {
            mcdl.countDown();
            return null;
        }
        DecisionStump stump = this.baseStump.clone();
        stump.setPredicting(this.predicting);
        List<List<DataPointPair<Integer>>> splits = stump.trainC(dataPoints, options);
        final Node node = new Node(stump);
        // Without a split list there is nothing to recurse on: keep the stump as a leaf
        // rather than failing with a NullPointerException on splits.get(i).
        if (splits != null && stump.getNumberOfPaths() > 1) {
            for (int i = 0; i < node.paths.length; ++i) {
                final int ii = i;
                final List<DataPointPair<Integer>> splitI = splits.get(i);
                // Two counts per child: one is released by the child's own makeNodeC call
                // (which happens *before* it returns). The other is released below only after
                // the child has been stored in node.paths. This way the latch cannot reach zero
                // while a link is still missing, and the store happens before the release
                // (assuming the latch has CountDownLatch-like memory semantics).
                mcdl.countUp();
                mcdl.countUp();
                threadPool.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            node.paths[ii] = DecisionTree.this.makeNodeC(splitI, new IntSet(options), depth + 1, threadPool, mcdl);
                        }
                        finally {
                            mcdl.countDown();
                        }
                    }
                });
            }
        }
        mcdl.countDown();
        return node;
    }

    /**
     * Builds the regression sub-tree for the given points. Child sub-trees are submitted to
     * {@code threadPool}. It follows the same latch contract as
     * {@link #makeNodeC(List, Set, int, ExecutorService, ModifiableCountDownLatch)}.
     *
     * @param dataPoints the points reaching this node
     * @param options indices of the features that may be split on
     * @param depth the depth of this node (root is 0)
     * @param threadPool the executor used to build child sub-trees
     * @param mcdl the latch tracking outstanding work
     * @return the new node, or {@code null} if a stopping condition holds or the stump produced
     *         no splits
     */
    protected Node makeNodeR(List<DataPointPair<Double>> dataPoints, final Set<Integer> options, final int depth, final ExecutorService threadPool, final ModifiableCountDownLatch mcdl) {
        if (depth > this.maxDepth || options.isEmpty() || dataPoints.size() < this.minSamples || dataPoints.isEmpty()) {
            mcdl.countDown();
            return null;
        }
        DecisionStump stump = this.baseStump.clone();
        List<List<DataPointPair<Double>>> splits = stump.trainR(dataPoints, options);
        if (splits == null) {
            mcdl.countDown();
            return null;
        }
        final Node node = new Node(stump);
        if (stump.getNumberOfPaths() > 1) {
            for (int i = 0; i < node.paths.length; ++i) {
                final int ii = i;
                final List<DataPointPair<Double>> splitI = splits.get(i);
                // See makeNodeC: the second count is held until the child is linked into node.paths.
                mcdl.countUp();
                mcdl.countUp();
                threadPool.submit(new Runnable(){

                    @Override
                    public void run() {
                        try {
                            node.paths[ii] = DecisionTree.this.makeNodeR(splitI, new IntSet(options), depth + 1, threadPool, mcdl);
                        }
                        finally {
                            mcdl.countDown();
                        }
                    }
                });
            }
        }
        mcdl.countDown();
        return node;
    }

    /**
     * Trains a classification tree in the calling thread on all features.
     *
     * @param dataSet the classification data to learn from
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, new FakeExecutor());
    }

    /**
     * Trains a classification tree in the calling thread, restricted to the given features.
     *
     * @param dataSet the classification data to learn from
     * @param options indices of the features the tree may split on
     */
    public void trainC(ClassificationDataSet dataSet, Set<Integer> options) {
        this.trainC(dataSet, options, new FakeExecutor());
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Returns a deep copy of this tree, including its settings and any learned model.
     *
     * @return the copy
     */
    @Override
    public DecisionTree clone() {
        return new DecisionTree(this);
    }

    @Override
    public TreeNodeVisitor getTreeNodeVisitor() {
        return this.root;
    }

    /**
     * Returns the tree's own parameters followed by the base stump's parameters. Stump
     * parameters whose names contain "Gain Method" or "Numeric Handling" are excluded
     * (presumably because the tree already exposes {@link #setGainMethod} and
     * {@link #setNumericHandling} itself).
     *
     * @return an unmodifiable list of parameters
     */
    @Override
    public List<Parameter> getParameters() {
        ArrayList<Parameter> toRet = new ArrayList<Parameter>(this.params);
        for (Parameter param : this.baseStump.getParameters()) {
            if (param.getName().contains("Gain Method") || param.getName().contains("Numeric Handling")) continue;
            toRet.add(param);
        }
        return Collections.unmodifiableList(toRet);
    }

    /**
     * Looks up a parameter by name. The tree's own parameters are checked first. After that,
     * every parameter listed by {@link #getParameters()} is checked, which includes the
     * forwarded base-stump parameters.
     *
     * @param paramName the exact parameter name
     * @return the parameter, or {@code null} if none has that name
     */
    @Override
    public Parameter getParameter(String paramName) {
        Parameter own = this.paramMap.get(paramName);
        if (own != null) {
            return own;
        }
        for (Parameter param : this.getParameters()) {
            if (param.getName().equals(paramName)) {
                return param;
            }
        }
        return null;
    }

    /**
     * A tree node. It holds a trained {@link DecisionStump} and one (possibly {@code null})
     * child per stump path.
     */
    protected static class Node
    extends TreeNodeVisitor {
        private static final long serialVersionUID = -7507748424627088734L;
        /** The stump deciding which path a point follows, and its local prediction. */
        protected final DecisionStump stump;
        /** Child per path; a {@code null} entry is a missing or disabled path. */
        protected Node[] paths;

        /**
         * @param stump the trained stump for this node; one child slot is allocated per path
         */
        public Node(DecisionStump stump) {
            this.stump = stump;
            this.paths = new Node[stump.getNumberOfPaths()];
        }

        /** @return {@code true} if there is no child array or every child is {@code null} */
        @Override
        public boolean isLeaf() {
            if (this.paths == null) {
                return true;
            }
            for (int i = 0; i < this.paths.length; ++i) {
                if (this.paths[i] == null) continue;
                return false;
            }
            return true;
        }

        @Override
        public int childrenCount() {
            return this.paths.length;
        }

        @Override
        public CategoricalResults localClassify(DataPoint dp) {
            return this.stump.classify(dp);
        }

        @Override
        public double localRegress(DataPoint dp) {
            return this.stump.regress(dp);
        }

        /** @return a deep copy of this node, its stump and all non-null children */
        @Override
        public Node clone() {
            Node copy = new Node(this.stump.clone());
            for (int i = 0; i < this.paths.length; ++i) {
                copy.paths[i] = this.paths[i] == null ? null : this.paths[i].clone();
            }
            return copy;
        }

        /** @return the child on the given path, or {@code null} if this node is a leaf */
        @Override
        public TreeNodeVisitor getChild(int child) {
            if (this.isLeaf()) {
                return null;
            }
            return this.paths[child];
        }

        /**
         * Stores a {@link Node} child directly; any other visitor type is delegated to the
         * superclass implementation.
         */
        @Override
        public void setPath(int child, TreeNodeVisitor node) {
            if (node instanceof Node) {
                this.paths[child] = (Node)node;
            } else {
                super.setPath(child, node);
            }
        }

        @Override
        public void disablePath(int child) {
            this.paths[child] = null;
        }

        @Override
        public int getPath(DataPoint dp) {
            return this.stump.whichPath(dp);
        }

        /** @return {@code true} for every path of a leaf, otherwise whether that child is {@code null} */
        @Override
        public boolean isPathDisabled(int child) {
            if (this.isLeaf()) {
                return true;
            }
            return this.paths[child] == null;
        }
    }
}

