/*
 * Decompiled with CFR 0.152.
 */
package sklearn.dummy;

import com.google.common.primitives.Doubles;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ScoreDistributionManager;
import org.jpmml.python.ClassDictUtil;
import sklearn.Classifier;
import sklearn.HasPriorProbability;

/**
 * Converter for a fitted scikit-learn {@code DummyClassifier}.
 *
 * <p>The estimator is encoded as a single-node PMML {@link TreeModel}: the root node's
 * predicate is {@link True}, so it always matches, and it carries a fixed score together
 * with a fixed class probability distribution. Only the {@code "constant"},
 * {@code "most_frequent"} and {@code "prior"} strategies are supported; any other strategy
 * is rejected with an {@link IllegalArgumentException}.
 */
public class DummyClassifier
extends Classifier
implements HasPriorProbability {
    public DummyClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the prior probability of the class at the given position.
     *
     * <p>Only the {@code "prior"} strategy is supported; the value is read from the
     * fitted {@code class_prior_} attribute.
     *
     * @param index zero-based class index
     * @return the prior probability of that class
     * @throws IllegalArgumentException if the strategy is not {@code "prior"}
     */
    @Override
    public Number getPriorProbability(int index) {
        List<?> classes = this.getClasses();
        List<? extends Number> classPrior = this.getClassPrior();
        String strategy = this.getStrategy();
        ClassDictUtil.checkSize(classes, classPrior);
        switch (strategy) {
            case "prior": {
                return classPrior.get(index);
            }
            default: {
                throw new IllegalArgumentException(strategy);
            }
        }
    }

    /**
     * Encodes this estimator as a single-node PMML tree model.
     *
     * <ul>
     *   <li>{@code "constant"}: scores the {@code constant} class with probability 1.</li>
     *   <li>{@code "most_frequent"}: scores the class with the largest prior with probability 1.</li>
     *   <li>{@code "prior"}: scores the class with the largest prior, and reports the whole
     *       {@code class_prior_} vector as the class probability distribution.</li>
     * </ul>
     *
     * @param schema conversion schema; its label is cast to {@link CategoricalLabel}
     * @return a tree model made of a single, always-true root node
     * @throws IllegalArgumentException if the strategy is unsupported, or if the
     *     {@code constant} value is not one of the fitted classes
     */
    public TreeModel encodeModel(Schema schema) {
        List<?> classes = this.getClasses();
        List<? extends Number> classPrior = this.getClassPrior();
        String strategy = this.getStrategy();
        ClassDictUtil.checkSize(classes, classPrior);
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

        int index;
        double[] probabilities;
        switch (strategy) {
            case "constant": {
                // Only this strategy consults the "constant" attribute, so read it lazily here.
                Object constant = this.getConstant();
                index = classes.indexOf(constant);
                if (index < 0) {
                    throw new IllegalArgumentException("Constant value " + constant + " is not one of the classes " + classes);
                }
                probabilities = DummyClassifier.oneHot(classes.size(), index);
                break;
            }
            case "most_frequent": {
                index = DummyClassifier.indexOfMax(classPrior);
                probabilities = DummyClassifier.oneHot(classes.size(), index);
                break;
            }
            case "prior": {
                // The predicted class is still the most probable one; only the reported distribution differs.
                index = DummyClassifier.indexOfMax(classPrior);
                probabilities = Doubles.toArray(classPrior);
                break;
            }
            default: {
                throw new IllegalArgumentException(strategy);
            }
        }

        ClassifierNode root = new ClassifierNode(categoricalLabel.getValue(index), True.INSTANCE);

        ScoreDistributionManager scoreDistributionManager = new ScoreDistributionManager();
        scoreDistributionManager.addScoreDistributions(root, categoricalLabel.getValues(), probabilities);

        return new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), root)
            .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
    }

    /**
     * @return the fitted {@code class_prior_} attribute
     */
    public List<? extends Number> getClassPrior() {
        return this.getNumberArray("class_prior_");
    }

    /**
     * @return the optional {@code constant} attribute (consulted only by the {@code "constant"} strategy)
     */
    public Object getConstant() {
        return this.getOptionalScalar("constant");
    }

    /**
     * @return the {@code strategy} attribute
     */
    public String getStrategy() {
        return this.getString("strategy");
    }

    /**
     * Builds a probability vector of the given length with 1.0 at {@code index} and 0.0 elsewhere.
     */
    private static double[] oneHot(int size, int index) {
        double[] result = new double[size];
        result[index] = 1.0;
        return result;
    }

    /**
     * Returns the position of the largest element, comparing elements by their {@code double}
     * value so that any mix of {@link Number} subtypes is handled. Ties resolve to the first
     * occurrence, like numpy's {@code argmax}.
     *
     * @throws IllegalArgumentException if {@code values} is empty
     */
    private static int indexOfMax(List<? extends Number> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("Expected a non-empty list of values");
        }
        int maxIndex = 0;
        double maxValue = values.get(0).doubleValue();
        int i = 0;
        for (Number number : values) {
            double value = number.doubleValue();
            // Strictly greater: keeps the first occurrence on ties.
            if (Double.compare(value, maxValue) > 0) {
                maxIndex = i;
                maxValue = value;
            }
            i++;
        }
        return maxIndex;
    }
}

