/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.utils.DistributionFamily;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.h2o.SharedTreeMojoModelConverter;

public class GbmMojoModelConverter
extends SharedTreeMojoModelConverter<GbmMojoModel> {
    public GbmMojoModelConverter(GbmMojoModel model) {
        super(model);
    }

    public MiningModel encodeModel(Schema schema) {
        GbmMojoModel model = (GbmMojoModel)this.getModel();
        int ntreeGroups = GbmMojoModelConverter.getNTreeGroups((SharedTreeMojoModel)model);
        int ntreesPerGroup = GbmMojoModelConverter.getNTreesPerGroup((SharedTreeMojoModel)model);
        Label label = schema.getLabel();
        List<TreeModel> treeModels = this.encodeTreeModels(schema);
        if (model._family == DistributionFamily.gaussian) {
            ContinuousLabel continuousLabel = (ContinuousLabel)label;
            MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, treeModels)).setTargets(ModelUtil.createRescaleTargets(null, (Number)model._init_f, (ContinuousLabel)continuousLabel));
            return miningModel;
        }
        if (model._family == DistributionFamily.poisson || model._family == DistributionFamily.gamma || model._family == DistributionFamily.tweedie) {
            ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);
            MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, treeModels)).setTargets(ModelUtil.createRescaleTargets(null, (Number)model._init_f, (ContinuousLabel)continuousLabel)).setOutput(ModelUtil.createPredictedOutput((String)"gbmValue", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            return MiningModelUtil.createRegression((Model)miningModel, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.EXP, (Schema)schema);
        }
        if (model._family == DistributionFamily.bernoulli) {
            ContinuousLabel continuousLabel = new ContinuousLabel(DataType.DOUBLE);
            MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, treeModels)).setTargets(ModelUtil.createRescaleTargets(null, (Number)model._init_f, (ContinuousLabel)continuousLabel)).setOutput(ModelUtil.createPredictedOutput((String)"gbmValue", (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
            return MiningModelUtil.createBinaryLogisticClassification((Model)miningModel, (double)1.0, (double)0.0, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.LOGIT, (boolean)true, (Schema)schema);
        }
        if (model._family == DistributionFamily.multinomial) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)label;
            ArrayList<MiningModel> models = new ArrayList<MiningModel>();
            for (int i = 0; i < categoricalLabel.size(); ++i) {
                MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(null)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, (List)CMatrixUtil.getRow(treeModels, (int)ntreesPerGroup, (int)ntreeGroups, (int)i))).setOutput(ModelUtil.createPredictedOutput((String)FieldNameUtil.create((String)"gbmValue", (Object[])new Object[]{categoricalLabel.getValue(i)}), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE, (Transformation[])new Transformation[0]));
                models.add(miningModel);
            }
            return MiningModelUtil.createClassification(models, (RegressionModel.NormalizationMethod)RegressionModel.NormalizationMethod.SOFTMAX, (boolean)true, (Schema)schema);
        }
        throw new IllegalArgumentException("Distribution family " + model._family + " is not supported");
    }
}

