/*
 * Decompiled with CFR 0.152.
 */
package sklearn.preprocessing;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import numpy.DType;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Interval;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ParameterField;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;

/**
 * Converter for scikit-learn's {@code KBinsDiscretizer} transformer.
 *
 * <p>Every input column is first shifted by a small tolerance (the {@code add_eps} PMML function)
 * and then mapped to a bin index through a PMML {@link Discretize} expression built from the
 * fitted {@code bin_edges_}. The bin index is exposed either as one {@link BinaryFeature} per bin
 * ({@code encode} = {@code "onehot"} / {@code "onehot-dense"}) or as a single {@link IndexFeature}
 * ({@code encode} = {@code "ordinal"}).
 */
public class KBinsDiscretizer
extends Transformer {
    /** Absolute tolerance added to every input value before binning (see {@code add_eps}). */
    private static final double ATOL = 1.0E-8;

    /** Relative tolerance, scaled by {@code |x|}, added to every input value before binning. */
    private static final double RTOL = 1.0E-5;

    public KBinsDiscretizer(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes each input feature as a discretized categorical field.
     *
     * @param features the input features; one entry per column the estimator was fitted on
     * @param encoder the encoder that receives the generated PMML functions and derived fields
     * @return the encoded features, grouped per input feature and in input order; for one-hot
     *     encodings one binary feature per bin, for ordinal encoding one index feature per input
     * @throws IllegalArgumentException if {@code encode} is not {@code "onehot"},
     *     {@code "onehot-dense"} or {@code "ordinal"}, or if a column has no bin edges
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        DType dtype = (DType) this.getDType(false);
        String encode = this.getEncode();
        List<Integer> numberOfBins = this.getNumberOfBins();
        List<List<Number>> binEdges = this.getBinEdges();

        // Validate the encoding before any PMML content is registered with the encoder,
        // so that a bad value does not leave orphan derived fields behind.
        boolean onehot = isOneHot(encode);

        // n_bins_ and bin_edges_ are expected to hold one entry per input column.
        ClassDictUtil.checkSize(features.size(), numberOfBins, binEdges);

        List<Feature> result = new ArrayList<>();
        for (int i = 0; i < features.size(); i++) {
            result.addAll(this.encodeFeature(features.get(i), numberOfBins.get(i), binEdges.get(i), dtype, onehot, encoder));
        }
        return result;
    }

    /**
     * Encodes a single input column against its fitted bin edges.
     *
     * @param feature the input feature
     * @param numberOfBins the number of bins fitted for this column
     * @param bins the bin edges for this column; expected size is {@code numberOfBins + 1}
     * @param dtype the configured output dtype, or {@code null} to reuse the input data type
     * @param onehot {@code true} for one binary feature per bin, {@code false} for one index feature
     * @param encoder the encoder that receives the generated PMML content
     * @return the features encoding this column
     */
    private List<Feature> encodeFeature(Feature feature, int numberOfBins, List<Number> bins, DType dtype, boolean onehot, SkLearnEncoder encoder) {
        if (bins.isEmpty()) {
            throw new IllegalArgumentException("Feature " + feature.getName() + " has no bin edges");
        }
        ClassDictUtil.checkSize(numberOfBins + 1, bins);

        ContinuousFeature continuousFeature = addEps(feature.toContinuousFeature(), encoder);

        Discretize discretize = new Discretize(continuousFeature.getName())
            .setDataType(dtype != null ? dtype.getDataType() : continuousFeature.getDataType());

        int binCount = bins.size() - 1;
        List<Integer> labelCategories = new ArrayList<>(binCount);
        for (int i = 0; i < binCount; i++) {
            // The first bin has no left margin and the last bin has no right margin, so values
            // outside the fitted range fall into the outermost bins (this appears to mirror
            // scikit-learn's digitize-then-clip transform).
            // NOTE(review): with a single bin both margins are null; the PMML spec may require at
            // least one margin per Interval -- confirm downstream consumers accept this.
            Number leftMargin = (i > 0) ? bins.get(i) : null;
            Number rightMargin = (i < binCount - 1) ? bins.get(i + 1) : null;

            Interval interval = new Interval(Interval.Closure.CLOSED_OPEN)
                .setLeftMargin(leftMargin)
                .setRightMargin(rightMargin);

            Integer label = i;
            labelCategories.add(label);
            discretize.addDiscretizeBins(new DiscretizeBin(label, interval));
        }

        DerivedField derivedField = encoder.createDerivedField(this.createFieldName("discretize", continuousFeature), OpType.CATEGORICAL, discretize.getDataType(), discretize);

        if (onehot) {
            List<Feature> result = new ArrayList<>(labelCategories.size());
            for (Integer label : labelCategories) {
                result.add(new BinaryFeature(encoder, derivedField, label));
            }
            return result;
        }
        return Collections.singletonList(new IndexFeature(encoder, derivedField, labelCategories));
    }

    /**
     * Maps the {@code encode} attribute to the output layout.
     *
     * @param encode the raw {@code encode} attribute value (may be {@code null})
     * @return {@code true} for the one-hot variants, {@code false} for {@code "ordinal"}
     * @throws IllegalArgumentException for any other value, including {@code null}
     */
    private static boolean isOneHot(String encode) {
        if ("onehot".equals(encode) || "onehot-dense".equals(encode)) {
            return true;
        }
        if ("ordinal".equals(encode)) {
            return false;
        }
        throw new IllegalArgumentException("Expected 'onehot', 'onehot-dense' or 'ordinal' encoding, got " + encode);
    }

    /**
     * Returns the configured {@code dtype}, or {@code null} when the {@code dtype} attribute is
     * absent (in which case the input feature's data type is reused).
     */
    @Override
    public Object getDType(boolean extended) {
        Object dtype = this.get("dtype");
        if (dtype == null) {
            return null;
        }
        return super.getDType(extended);
    }

    /** Returns the {@code encode} attribute. */
    public String getEncode() {
        return this.getString("encode");
    }

    /** Returns the fitted {@code n_bins_} attribute (one entry per input column). */
    public List<Integer> getNumberOfBins() {
        return this.getIntegerArray("n_bins_");
    }

    /**
     * Returns the fitted {@code bin_edges_} attribute, one list of edges per input column.
     *
     * <p>The result is a lazy {@code Lists.transform} view over the underlying arrays.
     */
    public List<List<Number>> getBinEdges() {
        List arrays = this.getArray("bin_edges_", HasArray.class);
        Function<HasArray, List<Number>> function = new Function<HasArray, List<Number>>(){

            public List<Number> apply(HasArray hasArray) {
                return hasArray.getArrayContent();
            }
        };
        return Lists.transform((List)arrays, (Function)function);
    }

    /**
     * Wraps the feature in a derived field that applies the {@code add_eps} function to it.
     * The function definition is created and registered with the encoder only once.
     */
    private static ContinuousFeature addEps(ContinuousFeature continuousFeature, SkLearnEncoder encoder) {
        DefineFunction defineFunction = encoder.getDefineFunction("add_eps");
        if (defineFunction == null) {
            defineFunction = encodeDefineFunction("add_eps");
            encoder.addDefineFunction(defineFunction);
        }

        Apply apply = PMMLUtil.createApply(defineFunction.getName(), continuousFeature.ref());
        DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create(defineFunction.getName(), continuousFeature.getName()), OpType.CONTINUOUS, continuousFeature.getDataType(), apply);
        return new ContinuousFeature(encoder, derivedField);
    }

    /**
     * Builds the PMML function {@code name(x) = x + (ATOL + RTOL * abs(x))}, which nudges values
     * lying on (or numerically just below) a bin edge into the upper bin.
     */
    private static DefineFunction encodeDefineFunction(String name) {
        ParameterField valueField = new ParameterField(FieldName.create("x"));

        Apply eps = PMMLUtil.createApply("+",
            PMMLUtil.createConstant((Number) ATOL),
            PMMLUtil.createApply("*",
                PMMLUtil.createConstant((Number) RTOL),
                PMMLUtil.createApply("abs", new FieldRef(valueField.getName()))));

        Apply apply = PMMLUtil.createApply("+", new FieldRef(valueField.getName()), eps);

        return new DefineFunction(name, OpType.CONTINUOUS, DataType.DOUBLE, null, apply)
            .addParameterFields(valueField);
    }
}

