/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import hex.genmodel.algos.tree.NaSplitDir;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.utils.ByteBufferWrapper;
import hex.genmodel.utils.GenmodelBitSet;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.h2o.Converter;

/**
 * Base class for converters of H2O shared-tree MOJO models.
 *
 * <p>Decodes the compressed tree byte arrays held by a {@link SharedTreeMojoModel} into PMML
 * regression {@link TreeModel}s, one per compressed tree.
 *
 * @param <M> the concrete MOJO model type
 */
public abstract class SharedTreeMojoModelConverter<M extends SharedTreeMojoModel>
extends Converter<M> {
    /** Reflective handle on {@code SharedTreeMojoModel._compressed_trees}. */
    private static final Field FIELD_COMPRESSEDTREES;
    /** Reflective handle on {@code SharedTreeMojoModel._ntree_groups}. */
    private static final Field FIELD_NTREEGROUPS;
    /** Reflective handle on {@code SharedTreeMojoModel._ntrees_per_group}. */
    private static final Field FIELD_NTREESPERGROUP;

    public SharedTreeMojoModelConverter(M model) {
        super(model);
    }

    /**
     * Encodes every compressed tree of the model as a PMML tree model.
     *
     * <p>The returned list has one element per entry of the model's compressed-tree array, in the
     * same order. All trees are encoded with one shared {@link PredicateManager} instance.
     *
     * @param schema the schema whose features are indexed by the trees' column ids
     * @return the encoded tree models
     * @throws IllegalArgumentException if the MOJO version is older than 1.2
     */
    public List<TreeModel> encodeTreeModels(Schema schema) {
        SharedTreeMojoModel model = getModel();
        if (model._mojo_version < 1.2) {
            throw new IllegalArgumentException("Version " + model._mojo_version + " is not supported");
        }

        byte[][] compressedTrees = getCompressedTrees(model);
        PredicateManager predicateManager = new PredicateManager();

        return Stream.of(compressedTrees)
            .map(compressedTree -> encodeTreeModel(compressedTree, predicateManager, schema))
            .collect(Collectors.toList());
    }

    /**
     * Encodes a single compressed tree as a regression {@link TreeModel} that uses the
     * {@code defaultChild} missing-value strategy.
     *
     * <p>A {@code null} tree (H2O leaves the slot empty when the tree blob is absent and skips it
     * when scoring) is encoded as a single leaf scoring {@code 0}. Such a tree contributes nothing,
     * and the list returned by {@link #encodeTreeModels(Schema)} keeps one entry per slot.
     *
     * @param compressedTree the compressed tree bytes, may be {@code null}
     * @param predicateManager the predicate manager used to create split predicates
     * @param schema the schema whose features are indexed by the tree's column ids
     * @return the encoded tree model; node ids start at 1
     */
    public static TreeModel encodeTreeModel(byte[] compressedTree, PredicateManager predicateManager, Schema schema) {
        Label label = new ContinuousLabel(null, DataType.DOUBLE);
        AtomicInteger idSequence = new AtomicInteger(1);

        Node root;
        if (compressedTree == null) {
            root = new LeafNode(0d, True.INSTANCE).setId(nextId(idSequence));
        } else {
            ByteBufferWrapper buffer = new ByteBufferWrapper(compressedTree);
            root = encodeNode(True.INSTANCE, idSequence, compressedTree, buffer, predicateManager, new CategoryManager(), schema);
        }

        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(label), root)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
    }

    /**
     * Decodes one node, and recursively its subtree, starting at the current position of
     * {@code byteBuffer}.
     *
     * <p>{@code byteBuffer} is advanced past the node header and split data only. Each child is
     * decoded from a fresh buffer positioned relative to {@code byteBuffer}'s position after that
     * header. Ids are assigned in pre-order: this node first, then the left subtree, then the right
     * subtree.
     *
     * @param predicate the predicate that routes records into this node
     * @param idSequence source of node ids
     * @param compressedTree the full compressed tree bytes
     * @param byteBuffer a buffer over {@code compressedTree}, positioned at this node
     * @param predicateManager the predicate manager used to create split predicates
     * @param categoryManager the categorical values still reachable at this node
     * @param schema the schema whose features are indexed by column ids
     * @return the decoded node
     * @throws IllegalArgumentException if the node type bits are unsupported, or a bitset split
     *     refers to a non-categorical feature
     */
    public static Node encodeNode(org.dmg.pmml.Predicate predicate, AtomicInteger idSequence, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema) {
        Integer id = nextId(idSequence);

        int nodeType = byteBuffer.get1U();
        // 0x33: left-child descriptor (0..3 = width-1 of the left-subtree size field, 0x30 = inline leaf)
        int lmask = nodeType & 0x33;
        // 0xC0, shifted into the same bit positions as lmask: right-child descriptor
        int rmask = (nodeType & 0xC0) >> 2;
        // 0x0C: split kind (0 = numeric threshold, 8/12 = categorical bitset)
        int equal = nodeType & 0x0C;

        int colId = byteBuffer.get2();
        if (colId == 0xFFFF) {
            // Sentinel column id marks a leaf
            double score = byteBuffer.get4f();
            return new LeafNode(score, predicate).setId(id);
        }

        int naSplitDir = byteBuffer.get1U();
        boolean naVsRest = naSplitDir == NaSplitDir.NAvsREST.value();
        boolean leftward = naSplitDir == NaSplitDir.NALeft.value() || naSplitDir == NaSplitDir.Left.value();

        Feature feature = schema.getFeature(colId);

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        org.dmg.pmml.Predicate leftPredicate;
        org.dmg.pmml.Predicate rightPredicate;

        if (naVsRest) {
            leftPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_NOT_MISSING, null);
            rightPredicate = predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null);
        } else if (equal != 0) {
            if (!(feature instanceof CategoricalFeature)) {
                throw new IllegalArgumentException("Bitset split on column " + colId + " requires a categorical feature, got " + feature);
            }
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            GenmodelBitSet bitSet = readBitSet(equal, compressedTree, byteBuffer);

            FieldName name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();
            Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            List<Object> leftValues = new ArrayList<>();
            List<Object> rightValues = new ArrayList<>();
            for (int i = 0; i < values.size(); i++) {
                Object value = values.get(i);
                // Skip values already excluded by an ancestor split
                if (!valueFilter.test(value)) {
                    continue;
                }

                boolean goesRight;
                if (bitSet.isInRange(i)) {
                    goesRight = bitSet.contains(i);
                } else {
                    // H2O's scorer sends a level outside the bitset's range in the missing-value direction
                    goesRight = !leftward;
                }

                if (goesRight) {
                    rightValues.add(value);
                } else {
                    leftValues.add(value);
                }
            }

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = predicateManager.createSimpleSetPredicate(feature, leftValues);
            rightPredicate = predicateManager.createSimpleSetPredicate(feature, rightValues);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            double splitVal = byteBuffer.get4f();

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, splitVal);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, splitVal);
        }

        // Left child: follows the left-subtree size field (if any) directly
        ByteBufferWrapper leftByteBuffer = new ByteBufferWrapper(compressedTree);
        leftByteBuffer.skip(byteBuffer.position());
        if (lmask <= 3) {
            leftByteBuffer.skip(lmask + 1);
        }
        Node leftChild = encodeChild(lmask, leftPredicate, idSequence, compressedTree, leftByteBuffer, predicateManager, leftCategoryManager, schema);

        // Right child: located by skipping over the whole left subtree
        ByteBufferWrapper rightByteBuffer = new ByteBufferWrapper(compressedTree);
        rightByteBuffer.skip(byteBuffer.position());
        skipLeftSubtree(rightByteBuffer, lmask);
        Node rightChild = encodeChild(rmask, rightPredicate, idSequence, compressedTree, rightByteBuffer, predicateManager, rightCategoryManager, schema);

        return new BranchNode(null, predicate)
            .setId(id)
            .setDefaultChild(leftward ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);
    }

    /** Reads the split bitset that follows the node header; {@code equal} selects the encoding. */
    private static GenmodelBitSet readBitSet(int equal, byte[] compressedTree, ByteBufferWrapper byteBuffer) {
        GenmodelBitSet bitSet = new GenmodelBitSet(0);
        if (equal == 8) {
            bitSet.fill2(compressedTree, byteBuffer);
        } else if (equal == 12) {
            bitSet.fill3(compressedTree, byteBuffer);
        } else {
            throw new IllegalArgumentException("Node type " + equal + " is not supported");
        }
        return bitSet;
    }

    /** Decodes a child either as an inline leaf score (descriptor bit 0x10 set) or as a nested node. */
    private static Node encodeChild(int mask, org.dmg.pmml.Predicate predicate, AtomicInteger idSequence, byte[] compressedTree, ByteBufferWrapper byteBuffer, PredicateManager predicateManager, CategoryManager categoryManager, Schema schema) {
        if ((mask & 0x10) != 0) {
            double score = byteBuffer.get4f();
            return new LeafNode(score, predicate).setId(nextId(idSequence));
        }
        return encodeNode(predicate, idSequence, compressedTree, byteBuffer, predicateManager, categoryManager, schema);
    }

    /** Advances {@code byteBuffer} from just after the node header past the entire left child. */
    private static void skipLeftSubtree(ByteBufferWrapper byteBuffer, int lmask) {
        switch (lmask) {
            case 0:
                byteBuffer.skip(byteBuffer.get1U());
                break;
            case 1:
                byteBuffer.skip(byteBuffer.get2());
                break;
            case 2:
                byteBuffer.skip(byteBuffer.get3());
                break;
            case 3:
                byteBuffer.skip(byteBuffer.get4());
                break;
            case 48:
                // Left child is an inline 4-byte leaf score
                byteBuffer.skip(4);
                break;
            default:
                throw new IllegalArgumentException("Node type " + lmask + " is not supported");
        }
    }

    /** Returns the model's {@code _compressed_trees} array (entries may be {@code null}). */
    public static byte[][] getCompressedTrees(SharedTreeMojoModel model) {
        return (byte[][]) getFieldValue(FIELD_COMPRESSEDTREES, model);
    }

    /** Returns the model's {@code _ntree_groups} value. */
    public static int getNTreeGroups(SharedTreeMojoModel model) {
        return (Integer) getFieldValue(FIELD_NTREEGROUPS, model);
    }

    /** Returns the model's {@code _ntrees_per_group} value. */
    public static int getNTreesPerGroup(SharedTreeMojoModel model) {
        return (Integer) getFieldValue(FIELD_NTREESPERGROUP, model);
    }

    /** Returns the current value of {@code id} and increments it. */
    private static Integer nextId(AtomicInteger id) {
        return id.getAndIncrement();
    }

    static {
        try {
            FIELD_COMPRESSEDTREES = SharedTreeMojoModel.class.getDeclaredField("_compressed_trees");
            FIELD_NTREEGROUPS = SharedTreeMojoModel.class.getDeclaredField("_ntree_groups");
            FIELD_NTREESPERGROUP = SharedTreeMojoModel.class.getDeclaredField("_ntrees_per_group");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
    }
}

