/*
 * Decompiled with CFR 0.152.
 */
package sklearn.ensemble.stacking;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Estimator;

/**
 * Helpers for encoding a stacking ensemble as a PMML model chain.
 */
public final class StackingUtil {

    private StackingUtil() {
    }

    /**
     * Encodes a stacking ensemble.
     *
     * <p>Each base estimator is encoded against {@code schema}. The {@code predictFunction} then turns
     * the resulting model into zero or more "stack" features. If {@code passthrough} is set, the
     * original schema features are appended after all stack features. The final estimator is encoded
     * against a new schema made of the original label plus the stack features. All models (base
     * models in order, then the final model) are combined with
     * {@link MiningModelUtil#createModelChain(List)}.
     *
     * @param estimators the base estimators, in order
     * @param stackMethods one stack method name per base estimator. Its size is validated against
     *     {@code estimators} with {@link ClassDictUtil#checkSize}; presumably a mismatch is rejected
     *     (TODO confirm).
     * @param predictFunction maps (estimator index, encoded model, stack method, encoder) to the
     *     features the model contributes. It may return {@code null} or an empty list to contribute
     *     nothing.
     * @param finalEstimator the meta-estimator trained on the stacked features
     * @param passthrough whether the original input features are also fed to the final estimator
     * @param schema the input schema shared by all base estimators
     * @param <E> the estimator type
     * @return the model chain of all base models followed by the final model
     */
    public static <E extends Estimator> MiningModel encodeStacking(List<? extends E> estimators, List<String> stackMethods, PredictFunction predictFunction, E finalEstimator, boolean passthrough, Schema schema) {
        ClassDictUtil.checkSize(estimators, stackMethods);

        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        // All input features must share exactly one encoder (see getEncoder).
        SkLearnEncoder encoder = (SkLearnEncoder)getEncoder(features);

        List<Feature> stackFeatures = new ArrayList<>();
        List<Model> models = new ArrayList<>();

        for (int i = 0; i < estimators.size(); i++) {
            E estimator = estimators.get(i);
            String stackMethod = stackMethods.get(i);

            Model model = estimator.encodeModel(schema);

            List<Feature> predictFeatures = predictFunction.apply(i, model, stackMethod, encoder);
            // A null or empty result means this base model contributes no stack features.
            if (predictFeatures != null && !predictFeatures.isEmpty()) {
                stackFeatures.addAll(predictFeatures);
            }

            models.add(model);
        }

        if (passthrough) {
            // Original features come after all stack features.
            stackFeatures.addAll(features);
        }

        Schema stackSchema = new Schema(label, stackFeatures);
        Model finalModel = finalEstimator.encodeModel(stackSchema);
        models.add(finalModel);

        return MiningModelUtil.createModelChain(models);
    }

    /**
     * Returns the single encoder shared by all the given features.
     *
     * @param features the features to inspect
     * @return the one distinct encoder referenced by {@code features}
     * @throws java.util.NoSuchElementException if {@code features} is empty
     *     (from {@code Iterables.getOnlyElement})
     * @throws IllegalArgumentException if the features reference more than one distinct encoder
     *     (from {@code Iterables.getOnlyElement})
     */
    public static PMMLEncoder getEncoder(List<? extends Feature> features) {
        Set<PMMLEncoder> encoders = features.stream()
            .map(feature -> feature.getEncoder())
            .collect(Collectors.toSet());

        return Iterables.getOnlyElement(encoders);
    }

    /**
     * Converts an encoded base model into the features it contributes to the stacking schema.
     */
    @FunctionalInterface
    public interface PredictFunction {

        /**
         * @param index zero-based position of the base estimator
         * @param model the encoded base model
         * @param stackMethod the stack method configured for this base estimator
         * @param encoder the encoder shared by the input features
         * @return the contributed features; {@code null} or empty to contribute none
         */
        List<Feature> apply(int index, Model model, String stackMethod, SkLearnEncoder encoder);
    }
}

