/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.drf;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;

/**
 * MOJO scoring model for the {@code drf} algorithm package.
 *
 * <p>Scoring runs every tree via the inherited {@code scoreAllTrees} and then
 * post-processes the accumulated values in {@link #unifyPreds(double[], double, double[])}.
 * Feature contributions are served by the nested {@link ContributionsPredictorDRF}.
 */
public final class DrfMojoModel extends SharedTreeMojoModelWithContributions implements SharedTreeGraphConverter {
    /**
     * When {@code true}, a two-class model is post-processed through the
     * sum-normalisation path instead of the single-value binomial path, and
     * contribution prediction is refused.
     */
    protected boolean _binomial_double_trees;

    /**
     * Creates the model; all three arguments are passed straight to the superclass.
     *
     * @param columns        forwarded to the superclass constructor
     * @param domains        forwarded to the superclass constructor
     * @param responseColumn forwarded to the superclass constructor
     */
    public DrfMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Builds the contributions predictor for this model.
     *
     * @param treeSHAPPredictor predictor handed to the new {@link ContributionsPredictorDRF}
     * @return a new {@link ContributionsPredictorDRF} bound to this model
     * @throws UnsupportedOperationException if the model uses binomial double trees or its
     *                                       category is neither Regression nor Binomial
     */
    @Override
    protected PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new ContributionsPredictorDRF(this, treeSHAPPredictor);
    }

    /**
     * Scores one row: runs all trees into {@code preds}, then unifies the result.
     *
     * @param row    input row
     * @param offset passed through to {@link #unifyPreds(double[], double, double[])}
     * @param preds  output buffer, modified in place
     * @return the same {@code preds} array
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        super.scoreAllTrees(row, preds);
        return unifyPreds(row, offset, preds);
    }

    /**
     * Converts the values accumulated in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>{@code _nclasses == 1}: {@code preds[0]} is divided by the number of tree groups
     *       and nothing else is touched.</li>
     *   <li>{@code _nclasses == 2} without binomial double trees: {@code preds[1]} is divided
     *       by the number of tree groups and {@code preds[2]} becomes {@code 1 - preds[1]}.</li>
     *   <li>otherwise: {@code preds[1.._nclasses]} are divided by their sum, provided that
     *       sum is strictly positive.</li>
     * </ul>
     * For the two multi-class cases the probabilities are then optionally corrected when
     * {@code _balanceClasses} is set, and {@code preds[0]} receives the value returned by
     * {@link GenModel#getPrediction}.
     *
     * @param row    input row, only forwarded to {@code GenModel.getPrediction}
     * @param offset not used by this implementation
     * @param preds  buffer holding the accumulated tree outputs; overwritten with the result
     * @return the same {@code preds} array
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        if (_nclasses == 1) {
            // Single-output case: just average over the tree groups.
            preds[0] /= _ntree_groups;
            return preds;
        }

        final boolean singleValueBinomial = _nclasses == 2 && !_binomial_double_trees;
        if (singleValueBinomial) {
            preds[1] /= _ntree_groups;
            preds[2] = 1.0 - preds[1];
        } else {
            double total = 0.0;
            for (int c = 1; c <= _nclasses; c++) {
                total += preds[c];
            }
            // A non-positive total leaves the per-class values untouched.
            if (total > 0.0) {
                for (int c = 1; c <= _nclasses; c++) {
                    preds[c] /= total;
                }
            }
        }

        if (_balanceClasses) {
            GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        return preds;
    }

    /**
     * Scores one row with an offset of {@code 0.0}.
     *
     * @param row   input row
     * @param preds output buffer, modified in place
     * @return the same {@code preds} array
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0, preds);
    }

    /**
     * @return the current value of {@link #_binomial_double_trees}
     */
    public boolean isBinomialDoubleTrees() {
        return _binomial_double_trees;
    }

    /**
     * Contributions predictor that rescales every raw contribution as
     * {@code featurePlusBiasRatio + contribution / normalizer}.
     */
    static class ContributionsPredictorDRF
            extends SharedTreeMojoModelWithContributions.SharedTreeContributionsPredictor {
        /** Constant added to every rescaled contribution. */
        private final float _featurePlusBiasRatio;
        /** Divisor applied to every raw contribution; negative for Binomial models. */
        private final int _normalizer;

        /**
         * Derives the rescaling constants from the model's category.
         *
         * <p>Regression uses a zero additive term and the tree-group count as divisor.
         * Binomial uses {@code 1 / (_nfeatures + 1)} as the additive term and the negated
         * tree-group count as divisor.
         *
         * @param model             model whose category, feature count and tree-group count are read
         * @param treeSHAPPredictor forwarded to the superclass constructor
         * @throws UnsupportedOperationException if the model uses binomial double trees, or its
         *                                       category is neither Regression nor Binomial
         */
        private ContributionsPredictorDRF(DrfMojoModel model, TreeSHAPPredictor<double[]> treeSHAPPredictor) {
            super(model, treeSHAPPredictor);
            if (model._binomial_double_trees) {
                throw new UnsupportedOperationException("Calculating contributions is currently not supported for model with binomial_double_trees parameter set.");
            }

            final boolean regression = ModelCategory.Regression.equals(model._category);
            if (!regression && !ModelCategory.Binomial.equals(model._category)) {
                throw new UnsupportedOperationException("Model category " + model._category + " cannot be used to calculate feature contributions.");
            }

            _featurePlusBiasRatio = regression ? 0.0f : 1.0f / (float) (model._nfeatures + 1);
            _normalizer = regression ? model._ntree_groups : -model._ntree_groups;
        }

        /**
         * Rescales the given contributions in place.
         *
         * @param contribs raw contributions; every element is overwritten
         * @return the same {@code contribs} array
         */
        @Override
        public float[] getContribs(float[] contribs) {
            final float additive = _featurePlusBiasRatio;
            final float divisor = (float) _normalizer;
            for (int idx = 0; idx < contribs.length; idx++) {
                contribs[idx] = additive + contribs[idx] / divisor;
            }
            return contribs;
        }
    }
}

