/*
 * Decompiled with CFR 0.152.
 */
package sklearn.neighbors;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.nearest_neighbor.InstanceField;
import org.dmg.pmml.nearest_neighbor.InstanceFields;
import org.dmg.pmml.nearest_neighbor.KNNInput;
import org.dmg.pmml.nearest_neighbor.KNNInputs;
import org.dmg.pmml.nearest_neighbor.NearestNeighborModel;
import org.dmg.pmml.nearest_neighbor.TrainingInstances;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.nearest_neighbor.NearestNeighborModelUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;
import sklearn.neighbors.HasNeighbors;
import sklearn.neighbors.HasTrainingData;

/**
 * Static helpers for encoding scikit-learn nearest-neighbors estimators as PMML
 * {@link NearestNeighborModel} elements.
 */
public class KNeighborsUtil {

    /** Not instantiable: static helpers only. */
    private KNeighborsUtil() {
    }

    /**
     * Returns the number of outputs (target columns) of a fitted estimator, derived from the
     * shape of its training target array.
     *
     * @param estimator an estimator that must also implement {@link HasTrainingData}
     * @return {@code 1} for a one-dimensional target shape, otherwise the size of its second
     *     dimension
     * @throws IllegalArgumentException if the target shape is neither one- nor two-dimensional
     * @throws ClassCastException if the estimator does not implement {@link HasTrainingData}
     */
    public static <E extends Estimator> int getNumberOfOutputs(E estimator) {
        int[] shape = ((HasTrainingData) estimator).getYShape();

        switch (shape.length) {
            case 1:
                return 1;
            case 2:
                return shape[1];
            default:
                throw new IllegalArgumentException(
                    "Expected a 1- or 2-dimensional target shape, got " + shape.length + " dimension(s)");
        }
    }

    /**
     * Encodes a fitted k-nearest-neighbors estimator as a PMML {@link NearestNeighborModel}.
     *
     * <p>The training data is embedded as an inline table. Target columns are named
     * {@code data:y} (single output) or {@code data:y1} .. {@code data:yN} (multiple outputs).
     * Feature columns are named {@code data:x1} .. {@code data:xM}. Every feature is also
     * registered as a {@link KNNInput}. A neighbor {@code Output} element is attached only
     * for single-output models.
     *
     * @param estimator the fitted estimator; it must also implement {@link HasNeighbors}
     * @param miningFunction the mining function of the resulting model
     * @param numberOfInstances the number of training rows stored in the estimator
     * @param numberOfFeatures the number of columns of the stored training feature matrix
     * @param schema supplies the label(s) and the features of the model
     * @return the encoded model
     * @throws IllegalArgumentException if the estimator reports fewer than one output, or if
     *     its neighbor weighting or distance metric cannot be encoded
     */
    public static <E extends Estimator & HasTrainingData> NearestNeighborModel encodeNeighbors(E estimator, MiningFunction miningFunction, int numberOfInstances, int numberOfFeatures, Schema schema) {
        int numberOfNeighbors = ((HasNeighbors) estimator).getNumberOfNeighbors();
        int numberOfOutputs = estimator.getNumberOfOutputs();

        // NOTE(review): both arrays are sliced column-wise via CMatrixUtil.getColumn with
        // (rows, columns) dimensions; presumably they are flattened in row-major (C) order - confirm
        List<? extends Number> fitX = estimator.getFitX();
        List<? extends Number> y = estimator.getY();

        ClassDictUtil.checkSize(numberOfInstances * numberOfOutputs, y);

        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        // Column name -> column values; insertion order fixes the column order of the inline table
        LinkedHashMap<String, List<?>> data = new LinkedHashMap<>();
        InstanceFields instanceFields = new InstanceFields();

        if (numberOfOutputs == 1) {
            ScalarLabel scalarLabel = (ScalarLabel) label;

            // A missing label is tolerated: no target column is written
            if (scalarLabel != null) {
                InstanceField instanceField = new InstanceField(scalarLabel.getName()).setColumn("data:y");
                instanceFields.addInstanceFields(instanceField);

                data.put(instanceField.getColumn(), translateValues(scalarLabel, y));
            }
        } else if (numberOfOutputs >= 2) {
            MultiLabel multiLabel = (MultiLabel) label;

            // Tolerate a missing label here too, consistent with the single-output branch
            if (multiLabel != null) {
                List<? extends Label> labels = multiLabel.getLabels();

                for (int i = 0; i < labels.size(); i++) {
                    ScalarLabel scalarLabel = (ScalarLabel) labels.get(i);

                    // Skipped sub-labels still consume index i, keeping data:yN aligned with output N
                    if (scalarLabel == null) {
                        continue;
                    }

                    InstanceField instanceField = new InstanceField(scalarLabel.getName()).setColumn("data:y" + (i + 1));
                    instanceFields.addInstanceFields(instanceField);

                    List<? extends Number> column = CMatrixUtil.getColumn(y, numberOfInstances, numberOfOutputs, i);
                    data.put(instanceField.getColumn(), translateValues(scalarLabel, column));
                }
            }
        } else {
            throw new IllegalArgumentException("Expected at least one output, got " + numberOfOutputs);
        }

        DataType dataType = estimator.getDataType();

        KNNInputs knnInputs = new KNNInputs();

        for (int i = 0; i < features.size(); i++) {
            ContinuousFeature continuousFeature = features.get(i).toContinuousFeature(dataType);

            String name = continuousFeature.getName();

            InstanceField instanceField = new InstanceField(name).setColumn("data:x" + (i + 1));
            instanceFields.addInstanceFields(instanceField);

            knnInputs.addKNNInputs(new KNNInput(name));

            data.put(instanceField.getColumn(), CMatrixUtil.getColumn(fitX, numberOfInstances, numberOfFeatures, i));
        }

        // Flags the stored instances as already transformed (the PMML isTransformed attribute)
        TrainingInstances trainingInstances = new TrainingInstances(instanceFields, PMMLUtil.createInlineTable(data))
            .setTransformed(Boolean.TRUE);

        ComparisonMeasure comparisonMeasure = encodeComparisonMeasure(estimator);

        NearestNeighborModel nearestNeighborModel = new NearestNeighborModel(miningFunction, numberOfNeighbors, ModelUtil.createMiningSchema(label), trainingInstances, comparisonMeasure, knnInputs);

        if (numberOfOutputs == 1) {
            nearestNeighborModel.setOutput(NearestNeighborModelUtil.createOutput(numberOfNeighbors));
        }

        return nearestNeighborModel;
    }

    /**
     * Builds the distance-based comparison measure of the model.
     *
     * @param estimator an estimator that must also implement {@link HasNeighbors}
     * @return a {@code DISTANCE} comparison measure using the absolute-difference compare function
     * @throws IllegalArgumentException if the weighting is not {@code "uniform"}, or if the
     *     metric is unsupported
     */
    private static <E extends Estimator> ComparisonMeasure encodeComparisonMeasure(E estimator) {
        String weights = ((HasNeighbors) estimator).getWeights();

        // Only uniform neighbor weighting can be encoded; any other scheme is rejected
        if (!"uniform".equals(weights)) {
            throw new IllegalArgumentException(weights);
        }

        Measure measure = encodeMeasure(estimator);

        return new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, measure)
            .setCompareFunction(CompareFunction.ABS_DIFF);
    }

    /**
     * Maps the estimator's distance metric onto a PMML {@link Measure}.
     *
     * @param estimator an estimator that must also implement {@link HasNeighbors}
     * @return {@link Euclidean}, {@link CityBlock} or {@link Minkowski}
     * @throws IllegalArgumentException if the metric is missing or is not one of
     *     {@code "euclidean"}, {@code "manhattan"} or {@code "minkowski"}
     */
    private static <E extends Estimator> Measure encodeMeasure(E estimator) {
        HasNeighbors hasNeighbors = (HasNeighbors) estimator;

        String metric = hasNeighbors.getMetric();
        int p = hasNeighbors.getP();

        // Report a missing metric explicitly rather than with the NPE that switch-on-null throws
        if (metric == null) {
            throw new IllegalArgumentException("Distance metric is not specified");
        }

        switch (metric) {
            case "euclidean":
                return new Euclidean();
            case "manhattan":
                return new CityBlock();
            case "minkowski":
                // Minkowski distance with p = 1 (p = 2) is the Manhattan (Euclidean) distance
                if (p == 1) {
                    return new CityBlock();
                }
                if (p == 2) {
                    return new Euclidean();
                }
                return new Minkowski(Integer.valueOf(p));
            default:
                throw new IllegalArgumentException(metric);
        }
    }

    /**
     * Converts raw target values into the values written to the training-instances table.
     *
     * <p>Continuous targets are returned unchanged. Categorical targets are treated as integer
     * category indices and mapped to the label's category values through a lazy
     * {@link Lists#transform} view.
     *
     * @param scalarLabel the label describing the target column
     * @param y the raw target values of that column
     * @return the translated values
     * @throws IllegalArgumentException if the label is neither continuous nor categorical
     */
    private static List<?> translateValues(ScalarLabel scalarLabel, List<? extends Number> y) {
        if (scalarLabel instanceof ContinuousLabel) {
            return y;
        }

        if (scalarLabel instanceof CategoricalLabel) {
            final CategoricalLabel categoricalLabel = (CategoricalLabel) scalarLabel;

            Function<Number, Object> function = new Function<Number, Object>() {

                @Override
                public Object apply(Number number) {
                    int index = ValueUtil.asInt(number);

                    return categoricalLabel.getValue(index);
                }
            };

            return Lists.transform(y, function);
        }

        throw new IllegalArgumentException("Expected a continuous or categorical label, got "
            + (scalarLabel != null ? scalarLabel.getClass().getName() : null));
    }
}

