/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.lightgbm;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.lightgbm.BinaryCategoricalFeature;
import org.jpmml.lightgbm.DirectCategoricalFeature;
import org.jpmml.lightgbm.Section;

public class Tree {
    private int num_leaves_;
    private int num_cat_;
    private int[] left_child_;
    private int[] right_child_;
    private int[] split_feature_real_;
    private double[] threshold_;
    private int[] decision_type_;
    private double[] leaf_value_;
    private int[] leaf_count_;
    private double[] internal_value_;
    private int[] internal_count_;
    private int[] cat_boundaries_;
    private long[] cat_threshold_;
    private static final int MASK_CATEGORICAL = 1;
    private static final int MASK_DEFAULT_LEFT = 2;

    public void load(Section section) {
        this.num_leaves_ = section.getInt("num_leaves");
        this.num_cat_ = section.getInt("num_cat");
        this.left_child_ = section.getIntArray("left_child", this.num_leaves_ - 1);
        this.right_child_ = section.getIntArray("right_child", this.num_leaves_ - 1);
        this.split_feature_real_ = section.getIntArray("split_feature", this.num_leaves_ - 1);
        this.threshold_ = section.getDoubleArray("threshold", this.num_leaves_ - 1);
        this.decision_type_ = section.getIntArray("decision_type", this.num_leaves_ - 1);
        this.leaf_value_ = section.getDoubleArray("leaf_value", this.num_leaves_);
        this.leaf_count_ = section.getIntArray("leaf_count", this.num_leaves_);
        this.internal_value_ = section.getDoubleArray("internal_value", this.num_leaves_ - 1);
        this.internal_count_ = section.getIntArray("internal_count", this.num_leaves_ - 1);
        if (this.num_cat_ > 0) {
            this.cat_boundaries_ = section.getIntArray("cat_boundaries", this.num_cat_ + 1);
            this.cat_threshold_ = section.getUnsignedIntArray("cat_threshold", -1);
        }
    }

    public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema) {
        Node root = this.encodeNode((org.dmg.pmml.Predicate)True.INSTANCE, predicateManager, new CategoryManager(), 0, schema);
        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)schema.getLabel()), root).setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT).setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
        return treeModel;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    public Node encodeNode(org.dmg.pmml.Predicate predicate, PredicateManager predicateManager, CategoryManager categoryManager, int index, Schema schema) {
        FieldName name;
        org.dmg.pmml.Predicate rightPredicate;
        org.dmg.pmml.Predicate leftPredicate;
        Object value;
        Integer id = ~index;
        if (index < 0) return new CountingLeafNode((Object)this.leaf_value_[index ^= 0xFFFFFFFF], predicate).setId((Object)id).setRecordCount((Number)this.leaf_count_[index]);
        Feature feature = schema.getFeature(this.split_feature_real_[index]);
        double threshold_ = this.threshold_[index];
        int decision_type_ = this.decision_type_[index];
        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;
        boolean defaultLeft = Tree.hasDefaultLeftMask(decision_type_);
        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature)feature;
            if (Tree.hasCategoricalMask(decision_type_)) {
                throw new IllegalArgumentException("Expected a false (off) categorical split mask for binary feature " + FeatureUtil.getName((Feature)binaryFeature) + ", got true (on)");
            }
            if (threshold_ != 0.5) {
                throw new IllegalArgumentException("Expected 0.5 as a threshold value for binary feature " + FeatureUtil.getName((Feature)binaryFeature) + ", got " + threshold_);
            }
            value = binaryFeature.getValue();
            leftPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate((Feature)binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if (feature instanceof BinaryCategoricalFeature) {
            BinaryCategoricalFeature binaryCategoricalFeature = (BinaryCategoricalFeature)feature;
            if (!Tree.hasCategoricalMask(decision_type_)) {
                throw new IllegalArgumentException("Expected a true (on) categorical split mask for binary categorical feature " + FeatureUtil.getName((Feature)binaryCategoricalFeature) + ", got false (off)");
            }
            name = binaryCategoricalFeature.getName();
            List values = binaryCategoricalFeature.getValues();
            int cat_idx = ValueUtil.asInt((Number)threshold_);
            List<Object> leftValues = this.selectValues(false, values, Objects::nonNull, cat_idx, true);
            List<Object> rightValues = this.selectValues(false, values, Objects::nonNull, cat_idx, false);
            Object value2 = values.get(1);
            if (leftValues.size() == 0 && rightValues.size() == 1) {
                rightCategoryManager = rightCategoryManager.fork(name, rightValues);
                leftPredicate = predicateManager.createSimplePredicate((Feature)binaryCategoricalFeature, SimplePredicate.Operator.NOT_EQUAL, value2);
                rightPredicate = predicateManager.createSimplePredicate((Feature)binaryCategoricalFeature, SimplePredicate.Operator.EQUAL, value2);
                defaultLeft = true;
            } else {
                if (leftValues.size() != 1) throw new IllegalArgumentException("Neither left nor right branch is selectable");
                if (rightValues.size() != 0) throw new IllegalArgumentException("Neither left nor right branch is selectable");
                leftCategoryManager = leftCategoryManager.fork(name, leftValues);
                leftPredicate = predicateManager.createSimplePredicate((Feature)binaryCategoricalFeature, SimplePredicate.Operator.EQUAL, value2);
                rightPredicate = predicateManager.createSimplePredicate((Feature)binaryCategoricalFeature, SimplePredicate.Operator.NOT_EQUAL, value2);
                defaultLeft = false;
            }
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            if (!Tree.hasCategoricalMask(decision_type_)) {
                throw new IllegalArgumentException("Expected a true (on) categorical split mask for categorical feature " + FeatureUtil.getName((Feature)categoricalFeature) + ", got false (off)");
            }
            name = categoricalFeature.getName();
            boolean indexAsValue = categoricalFeature instanceof DirectCategoricalFeature;
            List values = categoricalFeature.getValues();
            Predicate valueFilter = categoryManager.getValueFilter(name);
            int cat_idx = ValueUtil.asInt((Number)threshold_);
            List<Object> leftValues = this.selectValues(indexAsValue, values, valueFilter, cat_idx, true);
            List<Object> rightValues = this.selectValues(indexAsValue, values, valueFilter, cat_idx, false);
            Set parentValues = (Set)categoryManager.getValue(name);
            if (leftValues.size() == 0) {
                throw new IllegalArgumentException("Left branch is not selectable");
            }
            if (parentValues != null && rightValues.size() == parentValues.size()) {
                throw new IllegalArgumentException("Right branch is not selectable");
            }
            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);
            leftPredicate = predicateManager.createSimpleSetPredicate((Feature)categoricalFeature, leftValues);
            rightPredicate = predicateManager.createSimpleSetPredicate((Feature)categoricalFeature, rightValues);
            defaultLeft = false;
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            if (Tree.hasCategoricalMask(decision_type_)) {
                throw new IllegalArgumentException("Expected a false (off) categorical split mask for continuous feature " + FeatureUtil.getName((Feature)continuousFeature) + ", got true (on)");
            }
            value = threshold_;
            leftPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }
        Node leftChild = this.encodeNode(leftPredicate, predicateManager, leftCategoryManager, this.left_child_[index], schema);
        Node rightChild = this.encodeNode(rightPredicate, predicateManager, rightCategoryManager, this.right_child_[index], schema);
        return new CountingBranchNode(null, predicate).setId((Object)id).setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId()).setRecordCount((Number)this.internal_count_[index]).addNodes(leftChild, rightChild);
    }

    private List<Object> selectValues(boolean indexAsValue, List<?> values, Predicate<Object> valueFilter, int cat_idx, boolean left) {
        ArrayList<Object> result = left ? new ArrayList() : new ArrayList(values);
        int n = this.cat_boundaries_[cat_idx + 1] - this.cat_boundaries_[cat_idx];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < 32; ++j) {
                int cat = i * 32 + j;
                if (!Tree.findInBitset(this.cat_threshold_, this.cat_boundaries_[cat_idx], n, cat)) continue;
                Integer value2 = indexAsValue ? Integer.valueOf(cat) : values.get(cat);
                if (left) {
                    result.add(value2);
                    continue;
                }
                result.remove(value2);
            }
        }
        result.removeIf(value -> !valueFilter.test(value));
        return result;
    }

    Boolean isBinary(int feature) {
        Boolean result = null;
        for (int i = 0; i < this.split_feature_real_.length; ++i) {
            if (this.split_feature_real_[i] != feature) continue;
            if (Tree.hasCategoricalMask(this.decision_type_[i])) {
                return Boolean.FALSE;
            }
            if (this.threshold_[i] != 0.5) {
                return Boolean.FALSE;
            }
            result = Boolean.TRUE;
        }
        return result;
    }

    Boolean isCategorical(int feature) {
        Boolean result = null;
        for (int i = 0; i < this.split_feature_real_.length; ++i) {
            if (this.split_feature_real_[i] != feature) continue;
            if (!Tree.hasCategoricalMask(this.decision_type_[i])) {
                return Boolean.FALSE;
            }
            result = Boolean.TRUE;
        }
        return result;
    }

    private static boolean hasCategoricalMask(int decision_type) {
        return Tree.getDecisionType(decision_type, 1) == 1;
    }

    private static boolean hasDefaultLeftMask(int decision_type) {
        return Tree.getDecisionType(decision_type, 2) == 2;
    }

    static int getDecisionType(int decision_type, int mask) {
        return decision_type & mask;
    }

    static int getMissingType(int decision_type) {
        return Tree.getDecisionType(decision_type >> 2, 3);
    }

    private static boolean findInBitset(long[] bits, int bitOffset, int n, int pos) {
        int i1 = pos / 32;
        if (i1 >= n) {
            return false;
        }
        int i2 = pos % 32;
        return (bits[bitOffset + i1] >> i2 & 1L) == 1L;
    }
}

