/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.PredicateTranslator;
import org.jpmml.python.Scope;
import org.jpmml.python.TupleUtil;
import sklearn.Estimator;

/**
 * Encodes a list of {@code (name, estimator, predicate)} steps as a PMML {@link MiningModel}
 * whose {@link Segmentation} uses the {@code selectFirst} multiple model method.
 *
 * <p>Per PMML's {@code selectFirst} semantics, the first segment whose predicate evaluates to
 * true is the one used for scoring.
 */
public class SelectFirstUtil {

    private SelectFirstUtil() {
    }

    /**
     * Encodes the steps as a regression {@link MiningModel}.
     *
     * @param steps non-empty list of {@code (name, estimator, predicate)} tuples
     * @param schema schema shared by every step's estimator
     * @return the encoded mining model
     * @throws IllegalArgumentException if {@code steps} is empty or any estimator is not a regressor
     */
    public static MiningModel encodeRegressor(List<Object[]> steps, Schema schema) {
        return encodeModel(MiningFunction.REGRESSION, steps, schema);
    }

    /**
     * Encodes the steps as a classification {@link MiningModel}.
     *
     * @param steps non-empty list of {@code (name, estimator, predicate)} tuples
     * @param schema schema shared by every step's estimator
     * @return the encoded mining model
     * @throws IllegalArgumentException if {@code steps} is empty or any estimator is not a classifier
     */
    public static MiningModel encodeClassifier(List<Object[]> steps, Schema schema) {
        return encodeModel(MiningFunction.CLASSIFICATION, steps, schema);
    }

    /**
     * Builds a {@code selectFirst} mining model with one segment per step, in step order.
     *
     * <p>Each step is a 3-tuple: element 0 is the segment id, element 1 the {@link Estimator},
     * and element 2 a predicate string. The predicate is translated against a scope that exposes
     * the schema's features under the name {@code "X"}.
     *
     * @throws IllegalArgumentException if {@code steps} is empty, or if a step's estimator does
     *     not have the requested mining function
     */
    private static MiningModel encodeModel(MiningFunction miningFunction, List<Object[]> steps, Schema schema) {
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one step");
        }

        Label label = schema.getLabel();

        Segmentation segmentation = new Segmentation(Segmentation.MultipleModelMethod.SELECT_FIRST, null);

        Scope scope = new DataFrameScope("X", schema.getFeatures());

        for (Object[] step : steps) {
            String name = (String) TupleUtil.extractElement(step, 0, String.class);
            Estimator estimator = (Estimator) TupleUtil.extractElement(step, 1, Estimator.class);
            String predicate = (String) TupleUtil.extractElement(step, 2, String.class);

            MiningFunction estimatorMiningFunction = estimator.getMiningFunction();
            if (estimatorMiningFunction != miningFunction) {
                throw new IllegalArgumentException("Step '" + name + "': expected a " + miningFunction
                    + " estimator, got a " + estimatorMiningFunction + " estimator");
            }

            Predicate pmmlPredicate = PredicateTranslator.translate(predicate, scope);

            Model model = estimator.encode(schema);

            Segment segment = new Segment(pmmlPredicate, model)
                .setId(name);

            segmentation.addSegments(segment);
        }

        MiningModel miningModel = new MiningModel(miningFunction, ModelUtil.createMiningSchema(label))
            .setSegmentation(segmentation);

        optimizeOutputFields(miningModel);

        return miningModel;
    }

    /**
     * Hoists output fields that every segment has in common up to the mining model itself.
     *
     * <p>A field is "common" when it is a {@code PROBABILITY} or {@code AFFINITY} output field of
     * the first segment's final model, and every other segment's final model declares an output
     * field with the same name that is equal to it according to {@link ReflectionUtil#equals}.
     * Common fields are removed from the segments, in their original declaration order, and
     * appended to the mining model's {@link Output}. A segment {@code Output} left empty by the
     * removal is dropped. If there are no common fields, the model is left unchanged.
     *
     * @param miningModel the model to optimize in place
     */
    public static void optimizeOutputFields(MiningModel miningModel) {
        Segmentation segmentation = miningModel.requireSegmentation();

        Map<String, OutputField> commonOutputFields = collectCommonOutputFields(segmentation);
        if (commonOutputFields.isEmpty()) {
            return;
        }

        Output output = ModelUtil.ensureOutput(miningModel);

        removeCommonOutputFields(segmentation, commonOutputFields.keySet());

        List<OutputField> outputFields = output.getOutputFields();
        outputFields.addAll(commonOutputFields.values());
    }

    /**
     * Finds the output fields shared by the final models of all segments.
     *
     * @return an insertion-ordered map from field name to field; never {@code null}
     */
    private static Map<String, OutputField> collectCommonOutputFields(Segmentation segmentation) {
        List<Segment> segments = segmentation.requireSegments();

        Map<String, OutputField> result = null;

        for (Segment segment : segments) {
            Model finalModel = MiningModelUtil.getFinalModel(segment.requireModel());

            Output output = finalModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                // A segment without output fields cannot share any field with the others
                return Collections.emptyMap();
            }

            List<OutputField> outputFields = output.getOutputFields();

            if (result == null) {
                // The first segment seeds the candidates. A LinkedHashMap keeps them in declaration
                // order, so the hoisted fields come out in a stable, meaningful order.
                result = outputFields.stream()
                    .filter(SelectFirstUtil::isHoistable)
                    .collect(Collectors.toMap(
                        OutputField::requireName,
                        outputField -> outputField,
                        (left, right) -> {
                            throw new IllegalStateException("Duplicate output field name '" + left.requireName() + "'");
                        },
                        LinkedHashMap::new));
            } else {
                Set<String> names = new LinkedHashSet<>();

                for (OutputField outputField : outputFields) {
                    String name = outputField.requireName();
                    names.add(name);

                    OutputField commonOutputField = result.get(name);
                    if (commonOutputField != null && !ReflectionUtil.equals(outputField, commonOutputField)) {
                        // Same name but a different definition, so the field is not common
                        result.remove(name);
                    }
                }

                // Candidates that this segment does not declare at all are not common either
                result.keySet().retainAll(names);
            }

            if (result.isEmpty()) {
                break;
            }
        }

        // result is still null only when there were no segments to inspect
        return (result != null) ? result : Collections.<String, OutputField>emptyMap();
    }

    /**
     * Returns {@code true} if the output field's kind may be hoisted to the parent model:
     * only {@code PROBABILITY} and {@code AFFINITY} output fields are eligible.
     */
    private static boolean isHoistable(OutputField outputField) {
        ResultFeature resultFeature = outputField.getResultFeature();

        return resultFeature == ResultFeature.PROBABILITY || resultFeature == ResultFeature.AFFINITY;
    }

    /**
     * Removes the named output fields from every segment's final model, and drops any segment
     * {@link Output} that becomes empty as a result.
     */
    private static void removeCommonOutputFields(Segmentation segmentation, Set<String> names) {
        List<Segment> segments = segmentation.requireSegments();

        for (Segment segment : segments) {
            Model finalModel = MiningModelUtil.getFinalModel(segment.requireModel());

            Output output = finalModel.getOutput();
            if (output == null || !output.hasOutputFields()) {
                continue;
            }

            List<OutputField> outputFields = output.getOutputFields();
            outputFields.removeIf(outputField -> names.contains(outputField.requireName()));

            if (outputFields.isEmpty()) {
                finalModel.setOutput(null);
            }
        }
    }
}

