/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.List;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.math.SpecialMath;

/**
 * Static utilities for pruning a decision tree in place, accessed through its
 * {@link TreeNodeVisitor} interface, using a held-out labeled test set.
 */
public class TreePruner {
    /**
     * Confidence value handed to {@link #computeBinomialUpperBound} by
     * {@link PruningMethod#ERROR_BASED} when no explicit value is supplied.
     */
    private static final double DEFAULT_ERROR_BASED_CONFIDENCE = 0.25;

    private TreePruner() {
    }

    /**
     * Prunes the tree rooted at {@code root} in place.
     *
     * @param root the root of the tree to prune
     * @param method the pruning strategy; {@link PruningMethod#NONE} leaves the tree untouched
     * @param testSet labeled data used to decide whether pruning helps
     */
    public static void prune(TreeNodeVisitor root, PruningMethod method, ClassificationDataSet testSet) {
        TreePruner.prune(root, method, testSet.getAsDPPList());
    }

    /**
     * Prunes the tree rooted at {@code root} in place, using the default confidence of
     * {@value #DEFAULT_ERROR_BASED_CONFIDENCE} for {@link PruningMethod#ERROR_BASED}.
     *
     * @param root the root of the tree to prune
     * @param method the pruning strategy; {@link PruningMethod#NONE} leaves the tree untouched
     * @param testSet labeled (data point, class index) pairs used to decide whether pruning helps
     */
    public static void prune(TreeNodeVisitor root, PruningMethod method, List<DataPointPair<Integer>> testSet) {
        TreePruner.prune(root, method, testSet, DEFAULT_ERROR_BASED_CONFIDENCE);
    }

    /**
     * Prunes the tree rooted at {@code root} in place.
     *
     * @param root the root of the tree to prune
     * @param method the pruning strategy; {@link PruningMethod#NONE} leaves the tree untouched
     * @param testSet labeled (data point, class index) pairs used to decide whether pruning helps
     * @param alpha confidence value used by {@link PruningMethod#ERROR_BASED} for its binomial
     *        upper-bound error estimates; ignored by the other methods
     * @throws IllegalArgumentException if {@code method} is {@link PruningMethod#ERROR_BASED}
     *         and {@code alpha} is not strictly between 0 and 1
     * @throws RuntimeException if {@code method} is not a recognized pruning method
     */
    public static void prune(TreeNodeVisitor root, PruningMethod method, List<DataPointPair<Integer>> testSet, double alpha) {
        if (method == PruningMethod.NONE) {
            return;
        }
        if (method == PruningMethod.REDUCED_ERROR) {
            pruneReduceError(null, -1, root, testSet);
        } else if (method == PruningMethod.ERROR_BASED) {
            if (!(alpha > 0.0 && alpha < 1.0)) {
                throw new IllegalArgumentException("alpha must be in (0, 1), not " + alpha);
            }
            pruneErrorBased(null, -1, root, testSet, alpha);
        } else {
            throw new RuntimeException("BUG: please report");
        }
    }

    /**
     * Reduced-error pruning. Recursively prunes the children of {@code current} first. If
     * {@code current} is (or has become) a leaf, it is removed from {@code parent} when the
     * parent's local prediction gets at least as much weighted accuracy on the test points
     * reaching {@code current} as {@code current}'s own prediction does.
     *
     * @param parent the parent of {@code current}, or {@code null} at the root
     * @param pathFollowed the index of the path from {@code parent} to {@code current}
     * @param current the node being examined; {@code null} (e.g. a disabled path) is a no-op
     * @param testSet the test points routed to {@code current}
     * @return the number of nodes pruned in this subtree, including {@code current}
     */
    private static int pruneReduceError(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, List<DataPointPair<Integer>> testSet) {
        if (current == null) {
            return 0;
        }
        int nodesPruned = 0;
        if (!current.isLeaf()) {
            List<List<DataPointPair<Integer>>> splits = partitionByPath(current, testSet);
            // Paths are visited from last to first, as in the original implementation;
            // presumably so disabling a path cannot disturb indices still to be visited.
            for (int path = splits.size() - 1; path >= 0; --path) {
                nodesPruned += pruneReduceError(current, path, current.getChild(path), splits.get(path));
            }
        }
        // Pruning the children may have turned this node into a leaf: consider removing it too.
        if (current.isLeaf() && parent != null) {
            double childCorrect = 0.0;
            double parentCorrect = 0.0;
            for (DataPointPair<Integer> dpp : testSet) {
                DataPoint dp = dpp.getDataPoint();
                int truth = dpp.getPair();
                if (current.localClassify(dp).mostLikely() == truth) {
                    childCorrect += dp.getWeight();
                }
                if (parent.localClassify(dp).mostLikely() == truth) {
                    parentCorrect += dp.getWeight();
                }
            }
            // Ties favour the smaller tree.
            if (parentCorrect >= childCorrect) {
                parent.disablePath(pathFollowed);
                return nodesPruned + 1;
            }
        }
        return nodesPruned;
    }

    /**
     * Error-based pruning. For each internal node, compares binomial upper-bound error
     * estimates of three options and applies the cheapest:
     * <ol>
     * <li>keep the subtree (sum of the recursively pruned children's estimates),</li>
     * <li>collapse the node into a leaf using its local prediction,</li>
     * <li>replace the node in its parent by its most populated enabled child (only when a
     * parent exists and supports {@code setPath}).</li>
     * </ol>
     *
     * @param parent the parent of {@code current}, or {@code null} at the root
     * @param pathFollowed the index of the path from {@code parent} to {@code current}
     * @param current the node being examined
     * @param testSet the test points routed to {@code current}
     * @param alpha confidence value passed to {@link #computeBinomialUpperBound}
     * @return the estimated error of the subtree rooted here after pruning; 0 when
     *         {@code current} is null or no test points reach it
     */
    private static double pruneErrorBased(TreeNodeVisitor parent, int pathFollowed, TreeNodeVisitor current, List<DataPointPair<Integer>> testSet, double alpha) {
        if (current == null || testSet.isEmpty()) {
            return 0.0;
        }
        int n = testSet.size();
        if (current.isLeaf()) {
            return computeBinomialUpperBound(n, alpha, countLocalErrors(current, testSet));
        }
        // Counted before the children are pruned, matching the original evaluation order.
        int localErrors = countLocalErrors(current, testSet);
        List<List<DataPointPair<Integer>>> splitSet = partitionByPath(current, testSet);

        // Option 1: keep the subtree. Also find the enabled child that receives the most points.
        // NOTE(review): points with a negative path reach no child, so they do not contribute
        // to subTreeScore, while options 2 and 3 count them -- confirm this asymmetry is intended.
        double subTreeScore = 0.0;
        int maxChildCount = 0;
        int maxChild = -1;
        for (int path = 0; path < splitSet.size(); ++path) {
            if (current.isPathDisabled(path)) {
                continue;
            }
            List<DataPointPair<Integer>> childSet = splitSet.get(path);
            subTreeScore += pruneErrorBased(current, path, current.getChild(path), childSet, alpha);
            if (childSet.size() > maxChildCount) {
                maxChildCount = childSet.size();
                maxChild = path;
            }
        }

        // Option 2: collapse this node into a leaf.
        double prunedTreeScore = computeBinomialUpperBound(n, alpha, localErrors);

        // Option 3: replace this node by its most populated child. The raised child would have
        // to classify every point that reaches this node, so count errors over the whole test
        // set to stay consistent with the n used in the bound.
        TreeNodeVisitor maxChildNode = null;
        double maxChildTreeScore = Double.POSITIVE_INFINITY;
        if (maxChild != -1) {
            // Fetched after the recursion above, which may have replaced this child.
            maxChildNode = current.getChild(maxChild);
            int raisedErrors = 0;
            for (DataPointPair<Integer> dpp : testSet) {
                if (maxChildNode.classify(dpp.getDataPoint()).mostLikely() != dpp.getPair().intValue()) {
                    ++raisedErrors;
                }
            }
            maxChildTreeScore = computeBinomialUpperBound(n, alpha, raisedErrors);
        }

        if (parent != null && maxChildTreeScore < prunedTreeScore && maxChildTreeScore < subTreeScore) {
            try {
                parent.setPath(pathFollowed, maxChildNode);
                return maxChildTreeScore;
            } catch (UnsupportedOperationException ignored) {
                // Best effort: this tree cannot replace a child in place, so fall back to
                // choosing between collapsing this node and keeping its subtree.
            }
        }
        if (prunedTreeScore < subTreeScore) {
            for (int path = 0; path < current.childrenCount(); ++path) {
                current.disablePath(path);
            }
            return prunedTreeScore;
        }
        return subTreeScore;
    }

    /**
     * Routes each test point to the path that {@code node} would send it down. Points for
     * which {@link TreeNodeVisitor#getPath} returns a negative value are not assigned to any
     * path.
     *
     * @return one list per child of {@code node}, indexed by path
     */
    private static List<List<DataPointPair<Integer>>> partitionByPath(TreeNodeVisitor node, List<DataPointPair<Integer>> testSet) {
        int numPaths = node.childrenCount();
        List<List<DataPointPair<Integer>>> splits = new ArrayList<List<DataPointPair<Integer>>>(numPaths);
        for (int path = 0; path < numPaths; ++path) {
            splits.add(new ArrayList<DataPointPair<Integer>>());
        }
        for (DataPointPair<Integer> dpp : testSet) {
            int path = node.getPath(dpp.getDataPoint());
            if (path >= 0) {
                splits.get(path).add(dpp);
            }
        }
        return splits;
    }

    /**
     * Counts the test points whose label differs from {@code node}'s local (node-only)
     * prediction.
     */
    private static int countLocalErrors(TreeNodeVisitor node, List<DataPointPair<Integer>> testSet) {
        int errors = 0;
        for (DataPointPair<Integer> dpp : testSet) {
            if (node.localClassify(dpp.getDataPoint()).mostLikely() != dpp.getPair().intValue()) {
                ++errors;
            }
        }
        return errors;
    }

    /**
     * Computes {@code n * (1 - invBetaIncReg(alpha, n - errors + 1e-9, errors + 1))}.
     * Assuming {@code invBetaIncReg(p, a, b)} returns the x with I_x(a, b) = p (TODO confirm
     * against SpecialMath), this is n times the error rate at which observing at most
     * {@code errors} errors in {@code n} trials has probability {@code alpha}. That is a
     * pessimistic estimate of the number of errors.
     *
     * @param n number of test points (trials)
     * @param alpha confidence value
     * @param errors number of observed misclassifications
     * @return the upper-bound estimate of the number of errors
     */
    private static double computeBinomialUpperBound(int n, double alpha, int errors) {
        // The 1e-9 keeps the first shape parameter positive when errors == n.
        return (double) n * (1.0 - SpecialMath.invBetaIncReg(alpha, (double) (n - errors) + 1.0E-9, (double) errors + 1.0));
    }

    /** The pruning strategies supported by {@link TreePruner}. */
    public static enum PruningMethod {
        /** Leave the tree unchanged. */
        NONE,
        /**
         * Bottom-up removal of leaves whose parent's local prediction is at least as accurate
         * (by test-point weight) on the points reaching the leaf.
         */
        REDUCED_ERROR,
        /**
         * Bottom-up comparison of binomial upper-bound error estimates for keeping a subtree,
         * collapsing it to a leaf, or replacing it by its most populated child.
         */
        ERROR_BASED;

    }
}

