/*
 * Decompiled with CFR 0.152.
 */
package jsat.linear.vectorcollection;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Stack;
import java.util.concurrent.ExecutorService;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.ChebyshevDistance;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.ManhattanDistance;
import jsat.linear.distancemetrics.MinkowskiDistance;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.OnLineStatistics;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.ModifiableCountDownLatch;
import jsat.utils.ProbailityMatch;

/**
 * A KD-tree over a list of vectors supporting k-nearest-neighbor and range queries.
 * <p>
 * Only {@link EuclideanDistance}, {@link ChebyshevDistance}, {@link ManhattanDistance} and
 * {@link MinkowskiDistance} are accepted. For each of them the absolute difference along a single
 * coordinate is a lower bound on the full distance, which is what makes the per-axis pruning used
 * by both searches valid.
 * <p>
 * Nodes store integer indices into an internal copy of the input list rather than the vectors
 * themselves. The same index is handed to the metric together with its acceleration cache.
 *
 * @param <V> the vector type stored in the tree
 */
public class KDTree<V extends Vec>
implements VectorCollection<V> {
    private static final long serialVersionUID = -7401342201406776463L;
    private DistanceMetric distanceMetric;
    private KDNode root;
    private PivotSelection pvSelection;
    private int size;
    private List<V> allVecs;
    private List<Double> distCache;

    /**
     * Builds a tree over the given vectors.
     *
     * @param vecs           the vectors to index; the list is copied, so later changes to it do not
     *                       affect the tree
     * @param distanceMetric the metric used for searches
     * @param pvSelection    how the split axis is chosen at each node
     * @param threadpool     pool used to build the acceleration cache and subtrees in parallel, or
     *                       {@code null} for a single-threaded build
     * @throws ArithmeticException if the metric is not one the KD-tree can prune with
     */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric, PivotSelection pvSelection, ExecutorService threadpool) {
        if (!(distanceMetric instanceof EuclideanDistance || distanceMetric instanceof ChebyshevDistance
                || distanceMetric instanceof ManhattanDistance || distanceMetric instanceof MinkowskiDistance)) {
            throw new ArithmeticException("KD Trees are not compatible with the given distance metric.");
        }
        this.distanceMetric = distanceMetric;
        this.pvSelection = pvSelection;
        this.size = vecs.size();
        this.allVecs = new ArrayList<V>(vecs);
        if (threadpool == null || threadpool instanceof FakeExecutor) {
            this.distCache = distanceMetric.getAccelerationCache(this.allVecs);
        } else {
            this.distCache = distanceMetric.getAccelerationCache(this.allVecs, threadpool);
        }
        if (threadpool == null) {
            this.root = this.buildTree(this.allIndices(), 0, null, null);
            return;
        }
        // The latch starts at 1 for the root call. buildTree counts up once for every subtree it
        // hands to the pool and counts down once for every leaf or empty subtree, so the latch
        // reaches zero exactly when the whole tree has been built.
        ModifiableCountDownLatch mcdl = new ModifiableCountDownLatch(1);
        this.root = this.buildTree(this.allIndices(), 0, threadpool, mcdl);
        try {
            mcdl.await();
        } catch (InterruptedException ex) {
            // Pool tasks from the parallel build may still be sorting their index sub-lists, so the
            // serial fallback works on a fresh index list those tasks cannot touch.
            this.root = this.buildTree(this.allIndices(), 0, null, null);
            // Preserve the interrupt so the caller can still observe it.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Builds a tree single-threaded.
     *
     * @param vecs           the vectors to index (copied)
     * @param distanceMetric the metric used for searches
     * @param pvSelection    how the split axis is chosen at each node
     */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric, PivotSelection pvSelection) {
        this(vecs, distanceMetric, pvSelection, null);
    }

    /**
     * Builds a tree single-threaded using {@link PivotSelection#Variance}.
     *
     * @param vecs           the vectors to index (copied)
     * @param distanceMetric the metric used for searches
     */
    public KDTree(List<V> vecs, DistanceMetric distanceMetric) {
        this(vecs, distanceMetric, PivotSelection.Variance);
    }

    /** Creates an empty tree; used by {@link #clone()} and the no-argument constructor. */
    private KDTree(DistanceMetric distanceMetric, PivotSelection pvSelection) {
        this.distanceMetric = distanceMetric;
        this.pvSelection = pvSelection;
    }

    /** Creates an empty tree using {@link EuclideanDistance} and {@link PivotSelection#Variance}. */
    public KDTree() {
        this(new EuclideanDistance(), PivotSelection.Variance);
    }

    /** Returns a fresh list holding the indices {@code 0 .. size-1} into {@link #allVecs}. */
    private IntList allIndices() {
        IntList indices = new IntList(this.size);
        ListUtils.addRange(indices, 0, this.size, 1);
        return indices;
    }

    /**
     * Recursively builds the subtree over the given indices.
     * <p>
     * The index list is sorted in place along the chosen axis. The median becomes the node, and
     * the two halves, as sub-list views, become its children.
     *
     * @param data       indices into {@link #allVecs}; reordered by this call
     * @param depth      depth of the node being built
     * @param threadpool pool to hand the right subtree to, or {@code null} to build serially
     * @param mcdl       completion latch; only used when {@code threadpool} is non-null
     * @return the subtree root, or {@code null} for empty input
     */
    private KDNode buildTree(final List<Integer> data, final int depth, final ExecutorService threadpool, final ModifiableCountDownLatch mcdl) {
        if (data == null || data.isEmpty()) {
            if (threadpool != null) {
                mcdl.countDown();
            }
            return null;
        }
        int mod = this.allVecs.get(0).length();
        if (data.size() == 1) {
            if (threadpool != null) {
                mcdl.countDown();
            }
            return new KDNode(data.get(0), depth % mod);
        }
        int pivot = this.selectPivot(data, depth, mod);
        Collections.sort(data, new VecIndexComparator(pivot));
        final int medianIndex = data.size() / 2;
        final KDNode node = new KDNode(data.get(medianIndex), pivot);
        if (threadpool == null) {
            node.setLeft(this.buildTree(data.subList(0, medianIndex), depth + 1, null, null));
            node.setRight(this.buildTree(data.subList(medianIndex + 1, data.size()), depth + 1, null, null));
        } else {
            // One extra pending unit for the subtree handed to the pool.
            mcdl.countUp();
            threadpool.submit(new Runnable() {

                @Override
                public void run() {
                    node.setRight(KDTree.this.buildTree(data.subList(medianIndex + 1, data.size()), depth + 1, threadpool, mcdl));
                }
            });
            node.setLeft(this.buildTree(data.subList(0, medianIndex), depth + 1, threadpool, mcdl));
        }
        return node;
    }

    /**
     * Chooses the split axis for a node.
     * <p>
     * {@link PivotSelection#Incremental} cycles through the axes by depth. Otherwise the dimension
     * with the largest variance over {@code data} is used, falling back to the depth-cycled axis if
     * no variance compares greater than -1 (for example, when every variance is NaN).
     */
    private int selectPivot(List<Integer> data, int depth, int mod) {
        if (this.pvSelection == PivotSelection.Incremental) {
            return depth % mod;
        }
        OnLineStatistics[] allStats = new OnLineStatistics[mod];
        for (int j = 0; j < mod; j++) {
            allStats[j] = new OnLineStatistics();
        }
        for (int idx : data) {
            Vec vec = this.allVecs.get(idx);
            for (int j = 0; j < mod; j++) {
                allStats[j].add(vec.get(j));
            }
        }
        int pivot = -1;
        double maxVariance = -1.0;
        for (int j = 0; j < mod; j++) {
            double variance = allStats[j].getVarance();
            if (variance > maxVariance) {
                maxVariance = variance;
                pivot = j;
            }
        }
        return pivot < 0 ? depth % mod : pivot;
    }

    /** Collects the nearest neighbors of {@code query} into {@code knns}, whose bound is k. */
    private void knnKDSearch(Vec query, BoundedSortedList<ProbailityMatch<V>> knns) {
        List<Double> qi = this.distanceMetric.supportsAcceleration() ? this.distanceMetric.getQueryInfo(query) : null;
        this.knnKDSearch(query, qi, this.root, knns);
    }

    /**
     * Recursive k-NN step.
     * <p>
     * Visits the node, then the child on the query's side of the split, and then the other child
     * only if it could still hold a closer point.
     */
    private void knnKDSearch(Vec query, List<Double> qi, KDNode node, BoundedSortedList<ProbailityMatch<V>> knns) {
        if (node == null) {
            return;
        }
        V curData = this.allVecs.get(node.locatin);
        double distance = this.distanceMetric.dist(node.locatin, query, qi, this.allVecs, this.distCache);
        knns.add(new ProbailityMatch<V>(distance, curData));
        double diff = query.get(node.axis) - curData.get(node.axis);
        KDNode close = diff <= 0.0 ? node.left : node.right;
        KDNode far = diff <= 0.0 ? node.right : node.left;
        this.knnKDSearch(query, qi, close, knns);
        // Checked after the close side has been searched, so the current k-th distance
        // (knns.last()) is as small as possible. Every far-side point is at least |diff| away.
        if (knns.size() < knns.maxSize() || Math.abs(diff) <= knns.last().getProbability()) {
            this.knnKDSearch(query, qi, far, knns);
        }
    }

    /**
     * Finds the {@code neighbors} vectors closest to {@code query}.
     *
     * @param query     the query point
     * @param neighbors how many neighbors to return; must be at least 1
     * @return the nearest vectors paired with their distances, nearest first
     * @throws IllegalArgumentException if {@code neighbors < 1}
     */
    @Override
    public List<? extends VecPaired<V, Double>> search(Vec query, int neighbors) {
        if (neighbors < 1) {
            throw new IllegalArgumentException("Invalid number of neighbors to search for");
        }
        BoundedSortedList<ProbailityMatch<V>> knns = new BoundedSortedList<ProbailityMatch<V>>(neighbors);
        this.knnKDSearch(query, knns);
        List<VecPaired<V, Double>> knnsList = new ArrayList<VecPaired<V, Double>>(knns.size());
        for (int i = 0; i < knns.size(); ++i) {
            ProbailityMatch<V> pm = knns.get(i);
            knnsList.add(new VecPaired<V, Double>(pm.getMatch(), pm.getProbability()));
        }
        return knnsList;
    }

    /** Recursively adds every vector within {@code range} of {@code query} under {@code node}. */
    private void distanceSearch(Vec query, List<Double> qi, KDNode node, List<VecPairedComparable<V, Double>> knns, double range) {
        if (node == null) {
            return;
        }
        V curData = this.allVecs.get(node.locatin);
        double distance = this.distanceMetric.dist(node.locatin, query, qi, this.allVecs, this.distCache);
        if (distance <= range) {
            knns.add(new VecPairedComparable<V, Double>(curData, distance));
        }
        double diff = query.get(node.axis) - curData.get(node.axis);
        KDNode close = diff > 0.0 ? node.right : node.left;
        KDNode far = diff > 0.0 ? node.left : node.right;
        this.distanceSearch(query, qi, close, knns, range);
        // Every point on the far side is at least |diff| away along this axis. For the supported
        // metrics that gap is a lower bound on the full distance, so compare it to range directly.
        // Squaring the gap would wrongly prune subtrees whenever range > 1.
        if (Math.abs(diff) <= range) {
            this.distanceSearch(query, qi, far, knns, range);
        }
    }

    /** @return the number of vectors the tree was built over */
    @Override
    public int size() {
        return this.size;
    }

    /**
     * Finds every vector within {@code range} of {@code query}.
     *
     * @param query the query point
     * @param range the maximum distance (inclusive); must be a positive number
     * @return the matching vectors paired with their distances, sorted by distance
     * @throws IllegalArgumentException if {@code range} is not positive (including NaN)
     */
    @Override
    public List<? extends VecPaired<V, Double>> search(Vec query, double range) {
        if (!(range > 0.0)) {
            throw new IllegalArgumentException("Range must be a positive number");
        }
        List<VecPairedComparable<V, Double>> vecs = new ArrayList<VecPairedComparable<V, Double>>();
        List<Double> qi = this.distanceMetric.supportsAcceleration() ? this.distanceMetric.getQueryInfo(query) : null;
        this.distanceSearch(query, qi, this.root, vecs, range);
        Collections.sort(vecs);
        return vecs;
    }

    /**
     * Returns a copy with its own vector list, acceleration cache and node structure.
     * The vectors themselves and the distance metric are shared with this tree.
     */
    @Override
    public KDTree<V> clone() {
        KDTree<V> clone = new KDTree<V>(this.distanceMetric, this.pvSelection);
        if (this.distCache != null) {
            clone.distCache = new DoubleList(this.distCache);
        }
        if (this.allVecs != null) {
            clone.allVecs = new ArrayList<V>(this.allVecs);
        }
        clone.size = this.size;
        if (this.root != null) {
            clone.root = this.root.clone();
        }
        return clone;
    }

    /**
     * Factory that builds {@link KDTree}s with a configurable pivot selection method.
     *
     * @param <V> the vector type
     */
    public static class KDTreeFactory<V extends Vec>
    implements VectorCollectionFactory<V> {
        private static final long serialVersionUID = 3508731608962277804L;
        private PivotSelection pivotSelectionMethod;

        /** @param pvSelectionMethod the pivot selection method for trees this factory builds */
        public KDTreeFactory(PivotSelection pvSelectionMethod) {
            this.pivotSelectionMethod = pvSelectionMethod;
        }

        /** Creates a factory using {@link PivotSelection#Variance}. */
        public KDTreeFactory() {
            this(PivotSelection.Variance);
        }

        public PivotSelection getPivotSelectionMethod() {
            return this.pivotSelectionMethod;
        }

        public void setPivotSelectionMethod(PivotSelection pivotSelectionMethod) {
            this.pivotSelectionMethod = pivotSelectionMethod;
        }

        @Override
        public VectorCollection<V> getVectorCollection(List<V> source, DistanceMetric distanceMetric) {
            return this.getVectorCollection(source, distanceMetric, null);
        }

        @Override
        public VectorCollection<V> getVectorCollection(List<V> source, DistanceMetric distanceMetric, ExecutorService threadpool) {
            return new KDTree<V>(source, distanceMetric, this.pivotSelectionMethod, threadpool);
        }

        @Override
        public KDTreeFactory<V> clone() {
            return new KDTreeFactory<V>(this.pivotSelectionMethod);
        }
    }

    /** Orders indices into {@link #allVecs} by the value of one coordinate of the indexed vector. */
    private class VecIndexComparator
    implements Comparator<Integer> {
        private final int index;

        public VecIndexComparator(int index) {
            this.index = index;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            return Double.compare(KDTree.this.allVecs.get(o1).get(this.index), KDTree.this.allVecs.get(o2).get(this.index));
        }
    }

    /**
     * A tree node.
     * <p>
     * The field names (including the misspelled {@code locatin}) and the member signatures are kept
     * as-is. This class declares no serialVersionUID, so its computed stream identifier depends on
     * them.
     */
    private class KDNode
    implements Cloneable,
    Serializable {
        /** Index of this node's vector in {@link #allVecs}. */
        int locatin;
        /** The coordinate this node splits on. */
        int axis;
        KDNode left;
        KDNode right;

        public KDNode(int locatin, int axis) {
            this.locatin = locatin;
            this.axis = axis;
        }

        public void setAxis(int axis) {
            this.axis = axis;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public void setLocatin(int locatin) {
            this.locatin = locatin;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public int getAxis() {
            return this.axis;
        }

        public KDNode getLeft() {
            return this.left;
        }

        public int getLocatin() {
            return this.locatin;
        }

        public KDNode getRight() {
            return this.right;
        }

        /** Deep-copies this node and its subtrees. */
        @Override
        protected KDNode clone() {
            KDNode clone = new KDNode(this.locatin, this.axis);
            if (this.left != null) {
                clone.left = this.left.clone();
            }
            if (this.right != null) {
                clone.right = this.right.clone();
            }
            return clone;
        }
    }

    /** Strategy for choosing the split axis at each node. */
    public static enum PivotSelection {
        /** Cycle through the axes by node depth. */
        Incremental,
        /** Split on the dimension with the largest variance among the node's points. */
        Variance;

    }
}

