/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.neural_network;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.neural_network.NeuralNetworkUtil;
import org.jpmml.python.HasArray;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Transformer;
import sklearn.neural_network.MLPRegressor;
import sklearn.neural_network.MultilayerPerceptronUtil;

public class MLPTransformer
extends Transformer {
    public MLPTransformer(String module, String name) {
        super(module, name);
    }

    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        MLPRegressor mlp = this.getMLP();
        int transformerOutputLayer = this.getTransformerOutputLayer();
        String activation = mlp.getActivation();
        NeuralNetwork.ActivationFunction activationFunction = MultilayerPerceptronUtil.parseActivationFunction(activation);
        List<? extends HasArray> coefs = mlp.getCoefs();
        List<? extends HasArray> intercepts = mlp.getIntercepts();
        MiningSchema miningSchema = new MiningSchema();
        NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, (DataType)DataType.DOUBLE);
        List<NeuralLayer> neuralLayers = transformerOutputLayer < 0 ? MultilayerPerceptronUtil.encodeNeuralLayers(neuralInputs, coefs, intercepts) : MultilayerPerceptronUtil.encodeNeuralLayers(neuralInputs, transformerOutputLayer, coefs, intercepts);
        NeuralOutputs neuralOutputs = new NeuralOutputs();
        NeuralLayer neuralLayer = (NeuralLayer)Iterables.getLast(neuralLayers);
        neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);
        List neurons = neuralLayer.getNeurons();
        ArrayList<DataField> dataFields = new ArrayList<DataField>();
        for (int i = 0; i < neurons.size(); ++i) {
            Neuron neuron = (Neuron)neurons.get(i);
            DataField dataField = encoder.createDataField(FieldNameUtil.create((String)"mlp", (Object[])new Object[]{i}), OpType.CONTINUOUS, DataType.DOUBLE);
            MiningField miningField = ModelUtil.createMiningField((String)dataField.requireName(), (MiningField.UsageType)MiningField.UsageType.TARGET);
            miningSchema.addMiningFields(new MiningField[]{miningField});
            DerivedField derivedField = new DerivedField(null, OpType.CONTINUOUS, DataType.DOUBLE, (Expression)new FieldRef((Field)dataField));
            NeuralOutput neuralOutput = new NeuralOutput().setOutputNeuron(neuron.requireId()).setDerivedField(derivedField);
            neuralOutputs.addNeuralOutputs(new NeuralOutput[]{neuralOutput});
            dataFields.add(dataField);
        }
        NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunction.REGRESSION, activationFunction, miningSchema, neuralInputs, neuralLayers).setNeuralOutputs(neuralOutputs);
        encoder.addTransformer((Model)neuralNetwork);
        ArrayList<Feature> result = new ArrayList<Feature>();
        for (int i = 0; i < dataFields.size(); ++i) {
            DataField dataField = (DataField)dataFields.get(i);
            OutputField outputField = ModelUtil.createPredictedField((String)FieldNameUtil.create((String)"predict", (Object[])new Object[]{dataField.requireName()}), (OpType)dataField.requireOpType(), (DataType)dataField.requireDataType()).setFinalResult(Boolean.valueOf(false)).setTargetField(dataField.requireName());
            DerivedOutputField derivedOutputField = encoder.createDerivedField((Model)neuralNetwork, outputField, false);
            result.add((Feature)new ContinuousFeature((PMMLEncoder)encoder, (Field)derivedOutputField));
        }
        return result;
    }

    public MLPRegressor getMLP() {
        return (MLPRegressor)this.get("mlp_", MLPRegressor.class);
    }

    public int getTransformerOutputLayer() {
        return this.getInteger("transformer_output_layer");
    }
}

