/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.ensemble.StackedEnsembleMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.ConverterFactory;
import org.jpmml.h2o.H2OEncoder;

public class StackedEnsembleMojoModelConverter
extends Converter<StackedEnsembleMojoModel> {
    private static final Field FIELD_BASEMODELS;
    private static final Field FIELD_METALEARNER;
    private static final Class<?> CLASS_STACKEDENSEMBLESUBMODEL;
    private static final Field FIELD_MOJOMODEL;
    private static final Field FIELD_MAPPING;

    public StackedEnsembleMojoModelConverter(StackedEnsembleMojoModel model) {
        super(model);
    }

    @Override
    public Schema encodeSchema(H2OEncoder encoder) {
        StackedEnsembleMojoModel model = (StackedEnsembleMojoModel)this.getModel();
        ConverterFactory converterFactory = ConverterFactory.newConverterFactory();
        Schema schema = super.encodeSchema(encoder);
        Label label = schema.getLabel();
        ArrayList<ContinuousFeature> features = new ArrayList<ContinuousFeature>();
        Schema segmentSchema = schema.toAnonymousSchema();
        Object[] baseModels = StackedEnsembleMojoModelConverter.getBaseModels(model);
        for (int i = 0; i < baseModels.length; ++i) {
            Object baseModel = baseModels[i];
            MojoModel mojoModel = StackedEnsembleMojoModelConverter.getMojoModel(baseModel);
            double[] mapping = StackedEnsembleMojoModelConverter.getMapping(baseModel);
            if (!(mojoModel instanceof SharedTreeMojoModel)) {
                throw new IllegalArgumentException("Stacking of models other than decision tree models is not supported");
            }
            if (mapping != null) {
                throw new IllegalArgumentException("Feature re-indexing is not supported");
            }
            Converter<? extends MojoModel> converter = converterFactory.newConverter(mojoModel);
            Schema baseModelSchema = converter.toMojoModelSchema(segmentSchema);
            Model segmentModel = converter.encodeModel(baseModelSchema);
            if (model._nclasses == 1) {
                ContinuousLabel continuousLabel = (ContinuousLabel)label;
                OutputField predictedOutputField = ModelUtil.createPredictedField((FieldName)FieldName.create((String)("stack(" + i + ")")), (OpType)OpType.CONTINUOUS, (DataType)DataType.DOUBLE).setFinalResult(Boolean.valueOf(false));
                DerivedOutputField predictedField = encoder.createDerivedField(segmentModel, predictedOutputField, false);
                features.add(new ContinuousFeature((PMMLEncoder)encoder, (org.dmg.pmml.Field)predictedField));
            } else {
                CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                SchemaUtil.checkSize((int)model._nclasses, (CategoricalLabel)categoricalLabel);
                List values = categoricalLabel.getValues();
                if (model._nclasses == 2) {
                    values = values.subList(1, 2);
                }
                for (Object value : values) {
                    OutputField probabilityOutputField = ModelUtil.createProbabilityField((FieldName)FieldName.create((String)("stack(" + i + ", " + value + ")")), (DataType)DataType.DOUBLE, value).setFinalResult(Boolean.valueOf(false));
                    DerivedOutputField probabilityField = encoder.createDerivedField(segmentModel, probabilityOutputField, false);
                    features.add(new ContinuousFeature((PMMLEncoder)encoder, (org.dmg.pmml.Field)probabilityField));
                }
            }
            encoder.addTransformer(segmentModel);
        }
        return new Schema(label, features);
    }

    @Override
    public Model encodeModel(Schema schema) {
        StackedEnsembleMojoModel model = (StackedEnsembleMojoModel)this.getModel();
        ConverterFactory converterFactory = ConverterFactory.newConverterFactory();
        MojoModel metaLearner = StackedEnsembleMojoModelConverter.getMetaLearner(model);
        if (metaLearner == null) {
            throw new IllegalArgumentException();
        }
        Converter<? extends MojoModel> converter = converterFactory.newConverter(metaLearner);
        Schema metaLearnerSchema = converter.toMojoModelSchema(schema);
        return converter.encodeModel(metaLearnerSchema);
    }

    public static Object[] getBaseModels(StackedEnsembleMojoModel model) {
        return (Object[])StackedEnsembleMojoModelConverter.getFieldValue(FIELD_BASEMODELS, model);
    }

    public static MojoModel getMetaLearner(StackedEnsembleMojoModel model) {
        return (MojoModel)StackedEnsembleMojoModelConverter.getFieldValue(FIELD_METALEARNER, model);
    }

    public static MojoModel getMojoModel(Object baseModel) {
        return (MojoModel)StackedEnsembleMojoModelConverter.getFieldValue(FIELD_MOJOMODEL, baseModel);
    }

    public static double[] getMapping(Object baseModel) {
        return (double[])StackedEnsembleMojoModelConverter.getFieldValue(FIELD_MAPPING, baseModel);
    }

    static {
        try {
            FIELD_BASEMODELS = StackedEnsembleMojoModel.class.getDeclaredField("_baseModels");
            FIELD_METALEARNER = StackedEnsembleMojoModel.class.getDeclaredField("_metaLearner");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
        try {
            CLASS_STACKEDENSEMBLESUBMODEL = StackedEnsembleMojoModelConverter.getDeclaredClass(StackedEnsembleMojoModel.class, "StackedEnsembleMojoSubModel");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
        try {
            FIELD_MOJOMODEL = CLASS_STACKEDENSEMBLESUBMODEL.getDeclaredField("_mojoModel");
            FIELD_MAPPING = CLASS_STACKEDENSEMBLESUBMODEL.getDeclaredField("_mapping");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
    }
}

