/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBReader;
import com.devsmart.ubjson.UBValue;
import com.google.common.collect.Iterables;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.io.DataInput;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.visitors.NaNAsMissingDecorator;
import org.jpmml.converter.visitors.TreeModelPruner;
import org.jpmml.xgboost.BinaryLoadable;
import org.jpmml.xgboost.BinomialLogisticRegression;
import org.jpmml.xgboost.Dart;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.GBTree;
import org.jpmml.xgboost.GeneralizedLinearRegression;
import org.jpmml.xgboost.HingeClassification;
import org.jpmml.xgboost.JSONLoadable;
import org.jpmml.xgboost.LambdaMART;
import org.jpmml.xgboost.LinearRegression;
import org.jpmml.xgboost.LogisticRegression;
import org.jpmml.xgboost.MultinomialLogisticRegression;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.PoissonRegression;
import org.jpmml.xgboost.UBJSONLoadable;
import org.jpmml.xgboost.UBJSONUtil;
import org.jpmml.xgboost.XGBoostDataInput;
import org.jpmml.xgboost.XGBoostEncoder;
import org.jpmml.xgboost.visitors.TreeModelCompactor;

/**
 * An XGBoost learner that can be loaded from the legacy binary model format or from the
 * JSON/UBJSON model formats, and encoded as a PMML {@link MiningModel}.
 */
public class Learner
implements BinaryLoadable,
JSONLoadable,
UBJSONLoadable {
    /** Base score; for 1.x+ models it is passed through {@link ObjFunction#probToMargin} after loading. */
    private float base_score;
    private int num_feature;
    private int num_class;
    private int contain_extra_attrs;
    private int contain_eval_metrics;
    private int major_version;
    private int minor_version;
    /** Number of targets; always at least 1. */
    private int num_target;
    private ObjFunction obj;
    private GBTree gbtree;
    private Map<String, String> attributes = null;
    private String[] feature_names = null;
    private String[] feature_types = null;
    private String[] metrics = null;

    /**
     * Loads the learner from the legacy binary model format.
     *
     * @param input binary model input, positioned at the learner parameters
     * @throws IOException if reading fails
     * @throws IllegalArgumentException if the major version is not 0 or 1, or the
     *         objective/booster name is not supported
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.base_score = input.readFloat();
        this.num_feature = input.readInt();
        this.num_class = input.readInt();
        this.contain_extra_attrs = input.readInt();
        this.contain_eval_metrics = input.readInt();
        this.major_version = input.readInt();
        this.minor_version = input.readInt();
        if (this.major_version < 0 || this.major_version > 1) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }
        // A stored value of 0 (or less) is treated as a single target
        this.num_target = Math.max(input.readInt(), 1);
        // NOTE(review): skips the remaining reserved parameter slots; count presumably mirrors XGBoost's layout
        input.readReserved(26);

        // parseObjective reads num_class (multi:* objectives), so it must already be set here
        String name_obj = input.readString();
        this.obj = this.parseObjective(name_obj);
        if (this.major_version >= 1) {
            // Presumably a probability-space value in 1.x files; map it through the objective.
            // Adding 0.0f turns a negative-zero result into positive zero.
            this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;
        }

        String name_gbm = input.readString();
        this.gbtree = this.parseGradientBooster(name_gbm);
        this.gbtree.loadBinary(input);

        if (this.contain_extra_attrs != 0) {
            this.attributes = input.readStringMap();
        }

        // The trailing Poisson/metrics sections are only read for pre-1.0 files
        if (this.major_version >= 1) {
            return;
        }
        if (this.obj instanceof PoissonRegression) {
            // A "max_delta_step" string may follow; its value is not needed and its absence is tolerated
            try {
                input.readString();
            } catch (EOFException ignored) {
                // Deliberate best-effort: the value is optional
            }
        }
        if (this.contain_eval_metrics != 0) {
            this.metrics = input.readStringVector();
        }
    }

    /**
     * Loads the learner from a parsed JSON model by converting it to the UBJSON object model.
     *
     * @param root the JSON model root object
     */
    @Override
    public void loadJSON(JsonObject root) {
        UBValue value = GsonUtil.toUBValue((JsonElement) root);
        this.loadUBJSON(value.asObject());
    }

    /**
     * Loads the learner from a UBJSON (or converted JSON) model.
     *
     * @param root the model root object; must contain "version" and "learner"
     * @throws IllegalArgumentException if the version is missing, malformed or older than 1.3,
     *         or the objective/booster name is not supported
     */
    @Override
    public void loadUBJSON(UBObject root) {
        if (!root.containsKey("version")) {
            throw new IllegalArgumentException("Property \"version\" not found among " + root.keySet());
        }
        int[] version = UBJSONUtil.toIntArray(root.get("version"));
        if (version.length < 2) {
            throw new IllegalArgumentException("Expected a [major, minor, ...] version, got " + Arrays.toString(version));
        }
        this.major_version = version[0];
        this.minor_version = version[1];
        // Require at least version 1.3; later majors (2.x, ...) must not be rejected because of their minor number
        if (this.major_version < 1 || (this.major_version == 1 && this.minor_version < 3)) {
            throw new IllegalArgumentException(this.major_version + "." + this.minor_version);
        }

        UBObject learner = root.get("learner").asObject();
        UBObject learnerModelParam = learner.get("learner_model_param").asObject();
        this.base_score = learnerModelParam.get("base_score").asFloat32();
        this.num_feature = learnerModelParam.get("num_feature").asInt();
        this.num_class = learnerModelParam.get("num_class").asInt();
        this.num_target = learnerModelParam.containsKey("num_target") ? learnerModelParam.get("num_target").asInt() : 1;

        // parseObjective reads num_class (multi:* objectives), so it must already be set here
        UBObject objective = learner.get("objective").asObject();
        String name_obj = objective.get("name").asString();
        this.obj = this.parseObjective(name_obj);
        // Adding 0.0f turns a negative-zero result into positive zero
        this.base_score = this.obj.probToMargin(this.base_score) + 0.0f;

        UBObject gradientBooster = learner.get("gradient_booster").asObject();
        String name_gbm = gradientBooster.get("name").asString();
        this.gbtree = this.parseGradientBooster(name_gbm);
        this.gbtree.loadUBJSON(gradientBooster);

        if (learner.containsKey("feature_names")) {
            this.feature_names = UBJSONUtil.toStringArray(learner.get("feature_names"));
        }
        if (learner.containsKey("feature_types")) {
            this.feature_types = UBJSONUtil.toStringArray(learner.get("feature_types"));
        }
    }

    /**
     * Loads the learner from a binary stream, consuming the optional "CONFIG-offset:" and
     * "binf" headers first.
     *
     * <p>The stream must support {@code mark}/{@code reset} and must also implement
     * {@link DataInput}; it is cast to that interface for reading.
     *
     * @param is the binary model stream
     * @param charset charset for strings inside the model
     * @throws IOException if reading fails, the config offset is negative, or (for files without
     *         the "CONFIG-offset:" header) data remains after the model
     */
    public <DIS extends InputStream> void loadBinary(DIS is, String charset) throws IOException {
        boolean hasSerializationHeader = Learner.consumeHeader(is, "CONFIG-offset:");
        if (hasSerializationHeader) {
            long offset = ((DataInput) is).readLong();
            if (offset < 0L) {
                throw new IOException("Invalid CONFIG-offset: " + offset);
            }
        }
        // The "binf" magic is optional: consumed when present, otherwise the stream is rewound
        Learner.consumeHeader(is, "binf");

        try (XGBoostDataInput input = new XGBoostDataInput(is, charset)) {
            this.loadBinary(input);
            // NOTE(review): serialized files presumably carry trailing config data, so only plain model files are checked for EOF
            if (!hasSerializationHeader) {
                int eof = is.read();
                if (eof != -1) {
                    throw new IOException("Expected end of stream after the model, found trailing data");
                }
            }
        }
    }

    /**
     * Loads the learner from a JSON stream.
     *
     * @param is the JSON stream
     * @param charset stream charset; {@code null} means UTF-8
     * @param jsonPath dot-separated path to the model object, optionally starting with "$";
     *        {@code null} means the document root
     * @throws IOException if reading fails or data remains after the document
     * @throws IllegalArgumentException if a path element is not found
     */
    public void loadJSON(InputStream is, String charset, String jsonPath) throws IOException {
        JsonParser parser = new JsonParser();
        if (charset == null) {
            charset = "UTF-8";
        }
        if (jsonPath == null) {
            jsonPath = "$";
        }
        try (Reader reader = new InputStreamReader(is, charset)) {
            JsonElement element = parser.parse(reader);
            JsonObject object = element.getAsJsonObject();
            String[] names = jsonPath.split("\\.");
            for (int i = 0; i < names.length; ++i) {
                String name = names[i];
                // A leading "$" denotes the root object
                if (i == 0 && "$".equals(name)) {
                    continue;
                }
                JsonElement childElement = object.get(name);
                if (childElement == null) {
                    throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
                }
                object = childElement.getAsJsonObject();
            }
            this.loadJSON(object);
            int eof = is.read();
            if (eof != -1) {
                throw new IOException("Expected end of stream after the JSON document, found trailing data");
            }
        }
    }

    /**
     * Loads the learner from a UBJSON stream.
     *
     * @param is the UBJSON stream
     * @param jsonPath dot-separated path to the model object, optionally starting with "$";
     *        {@code null} means the document root
     * @throws IOException if reading fails or data remains after the document
     * @throws IllegalArgumentException if a path element is not found
     */
    public void loadUBJSON(InputStream is, String jsonPath) throws IOException {
        if (jsonPath == null) {
            jsonPath = "$";
        }
        try (UBReader reader = new UBReader(is)) {
            UBObject object = reader.read().asObject();
            String[] names = jsonPath.split("\\.");
            for (int i = 0; i < names.length; ++i) {
                String name = names[i];
                // A leading "$" denotes the root object
                if (i == 0 && "$".equals(name)) {
                    continue;
                }
                UBValue childValue = object.get(name);
                if (childValue == null) {
                    throw new IllegalArgumentException("Property \"" + name + "\" not among " + object.keySet());
                }
                object = childValue.asObject();
            }
            this.loadUBJSON(object);
            int eof = is.read();
            if (eof != -1) {
                throw new IOException("Expected end of stream after the UBJSON document, found trailing data");
            }
        }
    }

    /**
     * Builds a feature map from the feature names and types embedded in a JSON/UBJSON model.
     *
     * @return a feature map with one entry per embedded feature
     * @throws IllegalArgumentException if names or types are absent, or their counts differ
     */
    public FeatureMap encodeFeatureMap() {
        if (this.feature_names == null || this.feature_types == null) {
            throw new IllegalArgumentException("The model does not contain embedded feature names and feature types");
        }
        if (this.feature_names.length != this.feature_types.length) {
            throw new IllegalArgumentException("Expected the same number of feature names and feature types, got " + this.feature_names.length + " and " + this.feature_types.length);
        }
        FeatureMap result = new FeatureMap();
        for (int i = 0; i < this.feature_names.length; ++i) {
            result.addEntry(this.feature_names[i], this.feature_types[i]);
        }
        return result;
    }

    /**
     * Encodes the label (via the objective function) and features (via the feature map) into a schema.
     *
     * @param targetName target field name; {@code null} means "_target"
     * @param targetCategories target categories, passed to the objective function
     * @param featureMap source of the features
     * @param encoder the encoder
     * @return the schema
     * @throws IllegalArgumentException if the model has more than one target
     */
    public Schema encodeSchema(String targetName, List<String> targetCategories, FeatureMap featureMap, XGBoostEncoder encoder) {
        if (targetName == null) {
            targetName = "_target";
        }
        if (this.num_target != 1) {
            throw new IllegalArgumentException("Expected a single-target model, got " + this.num_target + " targets");
        }
        Label label = this.obj.encodeLabel(targetName, targetCategories, (PMMLEncoder) encoder);
        List<Feature> features = featureMap.encodeFeatures((PMMLEncoder) encoder);
        return new Schema((PMMLEncoder) encoder, label, features);
    }

    /**
     * Transforms a schema's features according to the split type the boosters use for each one.
     *
     * <p>Split type 0 is treated as numerical and 1 as categorical. Numerical features are kept as-is
     * when binary, missing-value, or (when not {@code numeric}) threshold features. Otherwise they are
     * converted to continuous features, with double narrowed to float.
     *
     * @param numeric whether threshold features must also be converted to continuous features
     * @param schema the schema to transform
     * @return the transformed schema
     * @throws IllegalArgumentException if a feature is unknown, has conflicting or unsupported split
     *         types, or has an incompatible feature/data type
     */
    public Schema toXGBoostSchema(final boolean numeric, final Schema schema) {
        final GBTree gbtree = this.gbtree;
        Function<Feature, Feature> function = new Function<Feature, Feature>() {
            private final List<? extends Feature> features = schema.getFeatures();

            @Override
            public Feature apply(Feature feature) {
                int splitType = this.splitTypeOf(feature);
                switch (splitType) {
                    case 0:
                        return this.applyNumerical(feature);
                    case 1:
                        return this.applyCategorical(feature);
                    default:
                        throw new IllegalArgumentException("Unsupported split type " + splitType + " for feature " + feature.getName());
                }
            }

            private Feature applyNumerical(Feature feature) {
                if (feature instanceof BinaryFeature || feature instanceof MissingValueFeature) {
                    return feature;
                }
                if (feature instanceof ThresholdFeature && !numeric) {
                    return feature;
                }
                ContinuousFeature continuousFeature = feature.toContinuousFeature();
                DataType dataType = continuousFeature.getDataType();
                switch (dataType) {
                    case INTEGER:
                    case FLOAT:
                        break;
                    case DOUBLE:
                        // Narrow to float
                        continuousFeature = continuousFeature.toContinuousFeature(DataType.FLOAT);
                        break;
                    default:
                        throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
                }
                return continuousFeature;
            }

            private Feature applyCategorical(Feature feature) {
                if (feature instanceof CategoricalFeature) {
                    return feature;
                }
                throw new IllegalArgumentException("Expected a categorical feature for categorical splits, got feature " + feature.getName());
            }

            private int splitTypeOf(Feature feature) {
                int splitIndex = this.features.indexOf(feature);
                if (splitIndex < 0) {
                    throw new IllegalArgumentException("Feature " + feature.getName() + " is not part of the schema");
                }
                return this.splitTypeAt(splitIndex);
            }

            private int splitTypeAt(int splitIndex) {
                Set<Integer> splitTypes = gbtree.getSplitType(splitIndex);
                // Features never used for splitting default to numerical
                if (splitTypes.isEmpty()) {
                    return 0;
                }
                if (splitTypes.size() == 1) {
                    return (Integer) Iterables.getOnlyElement(splitTypes);
                }
                throw new IllegalArgumentException("Feature at index " + splitIndex + " has conflicting split types " + splitTypes);
            }
        };
        return schema.toTransformedSchema((Function) function);
    }

    /**
     * Encodes the learner as a PMML document.
     *
     * @param options conversion options; "nan_as_missing" (Boolean) applies {@link NaNAsMissingDecorator},
     *        the rest are passed to {@link #encodeMiningModel(Map, Schema)}
     * @param targetName target field name; {@code null} means "_target"
     * @param targetCategories target categories
     * @param featureMap source of the features
     * @return the PMML document
     */
    public PMML encodePMML(Map<String, ?> options, String targetName, List<String> targetCategories, FeatureMap featureMap) {
        XGBoostEncoder encoder = new XGBoostEncoder();
        Boolean nanAsMissing = (Boolean) options.get("nan_as_missing");
        Schema schema = this.encodeSchema(targetName, targetCategories, featureMap, encoder);
        MiningModel miningModel = this.encodeMiningModel(options, schema);
        PMML pmml = encoder.encodePMML((Model) miningModel);
        if (Boolean.TRUE.equals(nanAsMissing)) {
            NaNAsMissingDecorator decorator = new NaNAsMissingDecorator();
            decorator.applyTo((Visitable) pmml);
        }
        return pmml;
    }

    /**
     * Encodes the boosters as a mining model.
     *
     * @param options conversion options: "compact", "numeric" (defaults to true), "prune" (Booleans)
     *        and "ntree_limit" (Integer)
     * @param schema the schema
     * @return the mining model
     * @throws IllegalArgumentException if "compact" is requested together with "numeric" = false
     */
    public MiningModel encodeMiningModel(Map<String, ?> options, Schema schema) {
        Boolean compact = (Boolean) options.get("compact");
        Boolean numeric = (Boolean) options.get("numeric");
        Boolean prune = (Boolean) options.get("prune");
        Integer ntreeLimit = (Integer) options.get("ntree_limit");
        if (numeric == null) {
            numeric = Boolean.TRUE;
        }
        // Validate option combinations before the (potentially expensive) encoding
        if (Boolean.TRUE.equals(compact) && Boolean.FALSE.equals(numeric)) {
            throw new IllegalArgumentException("Conflicting XGBoost options");
        }
        MiningModel miningModel = this.gbtree.encodeMiningModel(this.obj, this.base_score, ntreeLimit, numeric, schema).setAlgorithmName("XGBoost (" + this.gbtree.getAlgorithmName() + ")");
        if (Boolean.TRUE.equals(compact)) {
            TreeModelCompactor compactor = new TreeModelCompactor();
            compactor.applyTo((Visitable) miningModel);
        }
        if (Boolean.TRUE.equals(prune)) {
            TreeModelPruner pruner = new TreeModelPruner();
            pruner.applyTo((Visitable) miningModel);
        }
        return miningModel;
    }

    public int num_feature() {
        return this.num_feature;
    }

    public int num_class() {
        return this.num_class;
    }

    public ObjFunction obj() {
        return this.obj;
    }

    /**
     * Maps a booster name to a new, empty booster instance.
     *
     * @throws IllegalArgumentException if the name is not "gbtree" or "dart"
     */
    private GBTree parseGradientBooster(String name_gbm) {
        switch (name_gbm) {
            case "gbtree":
                return new GBTree();
            case "dart":
                return new Dart();
            default:
                throw new IllegalArgumentException(name_gbm);
        }
    }

    /**
     * Maps an objective name to its objective function. Multiclass objectives read {@code num_class},
     * so it must be set before this is called.
     *
     * @throws IllegalArgumentException if the objective is not supported
     */
    private ObjFunction parseObjective(String name_obj) {
        switch (name_obj) {
            case "reg:linear":
            case "reg:pseudohubererror":
            case "reg:squarederror":
            case "reg:squaredlogerror":
                return new LinearRegression(name_obj);
            case "reg:logistic":
                return new LogisticRegression(name_obj);
            case "reg:gamma":
            case "reg:tweedie":
                return new GeneralizedLinearRegression(name_obj);
            case "count:poisson":
                return new PoissonRegression(name_obj);
            case "binary:hinge":
                return new HingeClassification(name_obj);
            case "binary:logistic":
                return new BinomialLogisticRegression(name_obj);
            case "rank:map":
            case "rank:ndcg":
            case "rank:pairwise":
                return new LambdaMART(name_obj);
            case "multi:softmax":
            case "multi:softprob":
                return new MultinomialLogisticRegression(name_obj, this.num_class);
            default:
                throw new IllegalArgumentException(name_obj);
        }
    }

    /**
     * Consumes {@code header} (UTF-8 encoded) if the stream starts with it; otherwise rewinds the stream.
     *
     * <p>The stream must support {@code mark}/{@code reset} and also implement {@link DataInput}.
     *
     * @return true if the header was present and consumed
     * @throws java.io.EOFException if the stream is shorter than the header
     */
    private static <DIS extends InputStream> boolean consumeHeader(DIS is, String header) throws IOException {
        byte[] expected = header.getBytes(StandardCharsets.UTF_8);
        byte[] actual = new byte[expected.length];
        is.mark(actual.length);
        ((DataInput) is).readFully(actual);
        boolean matches = Arrays.equals(expected, actual);
        if (!matches) {
            is.reset();
        }
        return matches;
    }
}

