/*
 * Decompiled with CFR 0.152.
 */
package sklego.meta;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.ClassifierUtil;
import sklearn.Estimator;
import sklearn.HasApplyField;
import sklearn.HasDecisionFunctionField;
import sklearn.HasEstimator;
import sklearn.HasMultiApplyField;
import sklearn.HasPredictField;
import sklearn.Transformer;
import sklearn.tree.HasTreeOptions;

/**
 * Transformer that encodes a fitted estimator (read from the {@code estimator_} attribute)
 * into the PMML document and exposes the output of one of its prediction functions
 * ({@code predict_func}: {@code "apply"}, {@code "decision_function"} or {@code "predict"})
 * as new features.
 */
public class EstimatorTransformer
extends Transformer
implements HasEstimator<Estimator> {
    public EstimatorTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes the wrapped estimator as a transformer model and returns the features that are
     * derived from the model's output fields.
     *
     * @param features the input features of the wrapped estimator
     * @param encoder the encoder that collects the generated PMML elements
     * @return one feature per selected output field
     * @throws IllegalArgumentException if the predict function is unknown or not supported by
     *     the estimator, or if the encoded model does not expose the required output fields
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Estimator estimator = this.getEstimator();
        String predictFunc = this.getPredictFunc();

        // Resolved before encoding: the "apply" branch sets an estimator option, which is
        // presumably consulted by encode() below.
        List<FieldName> inputNames = EstimatorTransformer.resolveInputNames(estimator, predictFunc);

        Schema schema = EstimatorTransformer.createSchema(estimator, features, encoder);
        Model model = estimator.encode(schema);

        Map<FieldName, DerivedOutputField> derivedOutputFields = EstimatorTransformer.extractDerivedOutputFields(model, encoder);

        encoder.addTransformer(model);

        if (inputNames == null) {
            if (estimator.isSupervised()) {
                // Supervised "predict" without a dedicated predict field: expose the predicted value.
                if (!derivedOutputFields.isEmpty()) {
                    throw new IllegalArgumentException("Expected no derived output fields for supervised estimator " + estimator.getClass().getName() + ", got " + derivedOutputFields.keySet());
                }
                return this.encodePredictedFeature(estimator, schema, model, encoder);
            }
            if (derivedOutputFields.isEmpty()) {
                throw new IllegalArgumentException("Estimator " + estimator.getClass().getName() + " does not define any output fields to predict from");
            }
            // Unsupervised "predict": fall back to the last collected output field.
            inputNames = Collections.singletonList(Iterables.getLast(derivedOutputFields.keySet()));
        }

        List<Feature> result = new ArrayList<Feature>(inputNames.size());
        for (FieldName inputName : inputNames) {
            DerivedOutputField inputField = derivedOutputFields.get(inputName);
            if (inputField == null) {
                throw new IllegalArgumentException("Output field " + inputName + " is not defined; available fields: " + derivedOutputFields.keySet());
            }
            result.add(EstimatorTransformer.createFeature(inputField, encoder));
        }
        return result;
    }

    @Override
    public Estimator getEstimator() {
        return (Estimator)this.get("estimator_", Estimator.class);
    }

    /**
     * Returns the name of the prediction function to expose (the {@code predict_func} attribute).
     */
    public String getPredictFunc() {
        return this.getString("predict_func");
    }

    /**
     * Determines which output fields of the encoded estimator become this transformer's features.
     *
     * @return the output field names, or {@code null} when {@code "predict"} is requested but the
     *     estimator does not declare a dedicated predict field (the caller then falls back to
     *     the model's prediction)
     * @throws IllegalArgumentException if the predict function is unknown, or the estimator does
     *     not support it
     */
    private static List<FieldName> resolveInputNames(Estimator estimator, String predictFunc) {
        switch (predictFunc) {
            case "apply": {
                if (estimator instanceof HasTreeOptions) {
                    // Ask tree-based estimators for winner (entity) ids; presumably what the
                    // apply field(s) below refer to.
                    estimator.putOption("winner_id", Boolean.TRUE);
                }
                if (estimator instanceof HasApplyField) {
                    return Collections.singletonList(((HasApplyField)estimator).getApplyField());
                }
                if (estimator instanceof HasMultiApplyField) {
                    return ((HasMultiApplyField)estimator).getApplyFields();
                }
                throw new IllegalArgumentException("Estimator " + estimator.getClass().getName() + " does not support predict function '" + predictFunc + "'");
            }
            case "decision_function": {
                if (estimator instanceof HasDecisionFunctionField) {
                    return Collections.singletonList(((HasDecisionFunctionField)estimator).getDecisionFunctionField());
                }
                throw new IllegalArgumentException("Estimator " + estimator.getClass().getName() + " does not support predict function '" + predictFunc + "'");
            }
            case "predict": {
                if (estimator instanceof HasPredictField) {
                    return Collections.singletonList(((HasPredictField)estimator).getPredictField());
                }
                return null;
            }
            default: {
                throw new IllegalArgumentException(predictFunc);
            }
        }
    }

    /**
     * Detaches every output field from the model and re-registers the prediction-related ones
     * (predicted value, transformed value, decision, entity id) as derived output fields.
     *
     * @return the registered derived output fields keyed by name, in model output order
     */
    private static Map<FieldName, DerivedOutputField> extractDerivedOutputFields(Model model, SkLearnEncoder encoder) {
        Map<FieldName, DerivedOutputField> result = new LinkedHashMap<FieldName, DerivedOutputField>();

        Output output = model.getOutput();
        if (output == null || !output.hasOutputFields()) {
            return result;
        }

        List<OutputField> outputFields = output.getOutputFields();
        for (Iterator<OutputField> it = outputFields.iterator(); it.hasNext(); ) {
            OutputField outputField = it.next();
            ResultFeature resultFeature = outputField.getResultFeature();
            switch (resultFeature) {
                case PREDICTED_VALUE:
                case TRANSFORMED_VALUE:
                case DECISION:
                case ENTITY_ID: {
                    DerivedOutputField derivedOutputField = encoder.createDerivedField(model, outputField, true);
                    result.put(derivedOutputField.getName(), derivedOutputField);
                    break;
                }
                default: {
                    break;
                }
            }
            // Every output field is removed from the model, whether or not it was registered above.
            it.remove();
        }
        return result;
    }

    /**
     * Creates a single feature for the predicted value of a supervised estimator, typed after the
     * schema's label (categorical for classification, continuous for regression).
     *
     * @throws IllegalArgumentException if the estimator's mining function is neither
     *     classification nor regression
     */
    private List<Feature> encodePredictedFeature(Estimator estimator, Schema schema, Model model, SkLearnEncoder encoder) {
        Label label = schema.getLabel();
        FieldName name = this.createFieldName("predict", new Object[0]);

        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                List<?> categories = categoricalLabel.getValues();
                OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CATEGORICAL, categoricalLabel.getDataType());
                DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
                return Collections.<Feature>singletonList(new CategoricalFeature(encoder, predictedField, categories));
            }
            case REGRESSION: {
                ContinuousLabel continuousLabel = (ContinuousLabel)label;
                OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, continuousLabel.getDataType());
                DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
                return Collections.<Feature>singletonList(new ContinuousFeature(encoder, predictedField));
            }
            default: {
                throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
            }
        }
    }

    /**
     * Wraps a derived output field as a feature according to its operational type.
     *
     * @throws IllegalArgumentException if the field is neither categorical nor continuous
     */
    private static Feature createFeature(DerivedOutputField inputField, SkLearnEncoder encoder) {
        OpType opType = inputField.getOpType();
        switch (opType) {
            case CATEGORICAL: {
                // NOTE(review): unlike the CONTINUOUS branch this wraps the underlying OutputField
                // rather than the DerivedOutputField itself; confirm this is intentional.
                return new CategoricalFeature(encoder, inputField.getOutputField());
            }
            case CONTINUOUS: {
                return new ContinuousFeature(encoder, inputField);
            }
            default: {
                throw new IllegalArgumentException("Unsupported operational type " + opType + " of output field " + inputField.getName());
            }
        }
    }

    /**
     * Builds the schema for encoding the estimator. Supervised estimators get a categorical
     * label (classes from {@link ClassifierUtil#getClasses}) or a continuous double label.
     * Unsupervised estimators get no label.
     *
     * @throws IllegalArgumentException if a supervised estimator's mining function is neither
     *     classification nor regression
     */
    private static Schema createSchema(Estimator estimator, List<Feature> features, SkLearnEncoder encoder) {
        Label label = null;
        if (estimator.isSupervised()) {
            MiningFunction miningFunction = estimator.getMiningFunction();
            switch (miningFunction) {
                case CLASSIFICATION: {
                    List<?> categories = ClassifierUtil.getClasses(estimator);
                    DataType dataType = TypeUtil.getDataType(categories, DataType.STRING);
                    label = new CategoricalLabel(null, dataType, categories);
                    break;
                }
                case REGRESSION: {
                    label = new ContinuousLabel(null, DataType.DOUBLE);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unsupported mining function " + miningFunction);
                }
            }
        }
        return new Schema(encoder, label, features);
    }
}

