/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import hex.genmodel.algos.drf.DrfMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.h2o.SharedTreeMojoModelConverter;

public class DrfMojoModelConverter
extends SharedTreeMojoModelConverter<DrfMojoModel> {
    private static final Field FIELD_BOOLEANDOUBLETREES;

    public DrfMojoModelConverter(DrfMojoModel model) {
        super(model);
    }

    @Override
    public Model encodeModel(Schema schema) {
        DrfMojoModel model = (DrfMojoModel)this.getModel();
        boolean binomialDoubleTrees = DrfMojoModelConverter.getBinomialDoubleTrees(model);
        int ntreeGroups = DrfMojoModelConverter.getNTreeGroups((SharedTreeMojoModel)model);
        int ntreesPerGroup = DrfMojoModelConverter.getNTreesPerGroup((SharedTreeMojoModel)model);
        Label label = schema.getLabel();
        List<TreeModel> treeModels = this.encodeTreeModels(schema);
        if (model._nclasses == 1) {
            ContinuousLabel continuousLabel = (ContinuousLabel)label;
            return DrfMojoModelConverter.encodeTreeEnsemble(treeModels, ensembleTreeModels -> {
                MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.AVERAGE, (List)ensembleTreeModels));
                return miningModel;
            });
        }
        if (model._nclasses == 2 && !binomialDoubleTrees) {
            ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);
            Model pmmlModel = DrfMojoModelConverter.encodeTreeEnsemble(treeModels, ensembleTreeModels -> {
                MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.AVERAGE, (List)ensembleTreeModels));
                return miningModel;
            });
            pmmlModel.setOutput(ModelUtil.createPredictedOutput((String)"drfValue", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            return MiningModelUtil.createBinaryLogisticClassification((Model)pmmlModel, (double)-1.0, (double)1.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.NONE, (boolean)true, (Schema)schema);
        }
        CategoricalLabel categoricalLabel = (CategoricalLabel)label;
        ArrayList<Model> models = new ArrayList<Model>();
        for (int i = 0; i < categoricalLabel.size(); ++i) {
            Model pmmlModel = DrfMojoModelConverter.encodeTreeEnsemble(CMatrixUtil.getRow(treeModels, (int)ntreesPerGroup, (int)ntreeGroups, (int)i), ensembleTreeModels -> {
                MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, (List)ensembleTreeModels));
                return miningModel;
            });
            pmmlModel.setOutput(ModelUtil.createPredictedOutput((String)FieldNameUtil.create((String)"drfValue", (Object[])new Object[]{categoricalLabel.getValue(i)}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            models.add(pmmlModel);
        }
        return MiningModelUtil.createClassification(models, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SIMPLEMAX, (boolean)true, (Schema)schema);
    }

    public static boolean getBinomialDoubleTrees(DrfMojoModel model) {
        return (Boolean)DrfMojoModelConverter.getFieldValue(FIELD_BOOLEANDOUBLETREES, model);
    }

    static {
        try {
            FIELD_BOOLEANDOUBLETREES = DrfMojoModel.class.getDeclaredField("_binomial_double_trees");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
    }
}

