/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.SynchronousQueue;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClustererBase;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.KDTree;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.OnLineStatistics;
import jsat.utils.SystemInfo;

/**
 * Density based clusterer. A cluster is grown outward from any point whose
 * range search of radius {@code eps} returns at least {@code minPts} results;
 * a point that fails that test when visited is marked {@link #NOISE} unless a
 * later cluster absorbs it.
 * <p>
 * When {@code eps} is not supplied it is estimated from the data as
 * {@code mean + stndDevs * standardDeviation} of a per-point statistic taken
 * from {@code search(point, minPts + 1).get(minPts).getPair()} (presumably the
 * distance to that neighbour - confirm against {@code VectorCollection}).
 * <p>
 * The overloads taking an {@link ExecutorService} submit
 * {@code SystemInfo.LogicalCores} long-running worker tasks to the pool and
 * feed them through queues.
 */
public class DBSCAN
extends ClustererBase {
    private static final long serialVersionUID = 1627963360642560455L;
    /** Designation of a point that has not been visited yet. */
    private static final int UNCLASSIFIED = -1;
    /** Designation of a visited point that had fewer than minPts results in range. */
    private static final int NOISE = -2;
    /** minPts used by the overloads that do not take one. */
    private static final int DEFAULT_MIN_PTS = 3;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vecFactory;
    private DistanceMetric dm;
    /** Number of standard deviations added to the mean when eps is estimated. */
    private double stndDevs = 2.0;

    /**
     * Creates a new clusterer.
     *
     * @param dm the distance metric handed to the vector collection
     * @param vecFactory factory that builds the collection used for all searches
     */
    public DBSCAN(DistanceMetric dm, VectorCollectionFactory<VecPaired<Vec, Integer>> vecFactory) {
        this.dm = dm;
        this.vecFactory = vecFactory;
    }

    /**
     * Creates a new clusterer using {@link EuclideanDistance}.
     */
    public DBSCAN() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a new clusterer that builds its search structure with a
     * {@link KDTree.KDTreeFactory}.
     *
     * @param dm the distance metric to use
     */
    public DBSCAN(DistanceMetric dm) {
        this(dm, new KDTree.KDTreeFactory<VecPaired<Vec, Integer>>());
    }

    /**
     * Copy constructor; the factory and the metric are cloned.
     *
     * @param toCopy the instance to copy
     */
    public DBSCAN(DBSCAN toCopy) {
        this.vecFactory = toCopy.vecFactory.clone();
        this.dm = toCopy.dm.clone();
        this.stndDevs = toCopy.stndDevs;
    }

    /**
     * Clusters the data with an estimated eps and returns the clusters as lists.
     *
     * @param dataSet the data to cluster
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @return the result of converting the designation array into cluster lists
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, int minPts) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, minPts, (int[])null), dataSet);
    }

    /**
     * Estimates eps on the calling thread and then clusters the data.
     *
     * @param dataSet the data to cluster
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param designations array to write into, or {@code null} to allocate one
     *        of {@code dataSet.getSampleSize()} entries; any previous content is overwritten
     * @return the designation array: cluster ids counted up from 0, or a
     *         negative value for points that joined no cluster
     */
    public int[] cluster(DataSet dataSet, int minPts, int[] designations) {
        OnLineStatistics stats = new OnLineStatistics();
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        VectorCollection<VecPaired<Vec, Integer>> vc = this.vecFactory.getVectorCollection(this.getVecIndexPairs(dataSet), this.dm);
        List<DataPoint> dps = dataSet.getDataPoints();
        for (DataPoint dp : dps) {
            stats.add(vc.search(dp.getNumericalValues(), minPts + 1).get(minPts).getPair());
        }
        double eps = stats.getMean() + stats.getStandardDeviation() * this.stndDevs;
        return this.cluster(dataSet, eps, minPts, vc, designations);
    }

    /**
     * Clusters with an estimated eps and {@code minPts = 3}.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, DEFAULT_MIN_PTS, designations);
    }

    /**
     * Clusters with an estimated eps and {@code minPts = 3}, using the thread pool.
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        return this.cluster(dataSet, DEFAULT_MIN_PTS, threadpool, designations);
    }

    @Override
    public DBSCAN clone() {
        return new DBSCAN(this);
    }

    /**
     * Clusters the data with an estimated eps using the thread pool and
     * returns the clusters as lists.
     *
     * @param dataSet the data to cluster
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param threadpool the pool the worker tasks are submitted to
     * @return the result of converting the designation array into cluster lists
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, int minPts, ExecutorService threadpool) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, minPts, threadpool, null), dataSet);
    }

    /**
     * Estimates eps with {@code SystemInfo.LogicalCores} worker tasks and then
     * clusters the data using the thread pool.
     *
     * @param dataSet the data to cluster
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param threadpool the pool the worker tasks are submitted to
     * @param designations array to write into, or {@code null} to allocate one;
     *        any previous content is overwritten
     * @return the designation array
     * @throws IllegalStateException if the calling thread is interrupted while
     *         waiting for the estimate, if a worker fails with a checked
     *         exception or an error, or if no worker produced a result
     * @throws RuntimeException the original exception of a worker that failed
     *         with an unchecked exception
     */
    public int[] cluster(DataSet dataSet, int minPts, ExecutorService threadpool, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        VectorCollection<VecPaired<Vec, Integer>> vc = this.vecFactory.getVectorCollection(this.getVecIndexPairs(dataSet), this.dm);
        // Unbounded on purpose: the producer uses add(), which throws
        // IllegalStateException on a full bounded queue when the workers fall behind.
        BlockingQueue<DataPoint> queue = new LinkedBlockingQueue<DataPoint>();
        List<Future<OnLineStatistics>> futures = new ArrayList<Future<OnLineStatistics>>(SystemInfo.LogicalCores);
        for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
            futures.add(threadpool.submit(new StatsWorker(queue, vc, minPts)));
        }
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            queue.add(dataSet.getDataPoint(i));
        }
        // One empty data point per worker tells that worker to stop.
        for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
            queue.add(new DataPoint(new DenseVector(0), new int[0], new CategoricalData[0]));
        }
        OnLineStatistics stats = null;
        for (Future<OnLineStatistics> future : futures) {
            OnLineStatistics partial;
            try {
                partial = future.get();
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while estimating eps", ex);
            }
            catch (ExecutionException ex) {
                // A partial estimate would silently yield a wrong eps, so fail the
                // same way the single threaded overload does.
                Throwable cause = ex.getCause();
                if (cause instanceof RuntimeException) {
                    throw (RuntimeException)cause;
                }
                throw new IllegalStateException("Worker failed while estimating eps", cause);
            }
            stats = stats == null ? partial : OnLineStatistics.add(stats, partial);
        }
        if (stats == null) {
            throw new IllegalStateException("No worker was started to estimate eps");
        }
        double eps = stats.getMean() + stats.getStandardDeviation() * this.stndDevs;
        return this.cluster(dataSet, eps, minPts, vc, threadpool, designations);
    }

    /**
     * Pairs the numeric vector of every data point with its index in the data
     * set, so that search results can be mapped back to designation slots.
     */
    private List<VecPaired<Vec, Integer>> getVecIndexPairs(DataSet dataSet) {
        ArrayList<VecPaired<Vec, Integer>> vecs = new ArrayList<VecPaired<Vec, Integer>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            vecs.add(new VecPaired<Vec, Integer>(dataSet.getDataPoint(i).getNumericalValues(), i));
        }
        return vecs;
    }

    /**
     * Clusters the data with the given eps and returns the clusters as lists.
     *
     * @param dataSet the data to cluster
     * @param eps the range used for every neighbourhood search
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @return the result of converting the designation array into cluster lists
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, double eps, int minPts) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, eps, minPts, (int[])null), dataSet);
    }

    /**
     * Clusters the data with the given eps on the calling thread.
     *
     * @param dataSet the data to cluster
     * @param eps the range used for every neighbourhood search
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param designations array to write into, or {@code null} to allocate one;
     *        any previous content is overwritten
     * @return the designation array
     */
    public int[] cluster(DataSet dataSet, double eps, int minPts, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
        return this.cluster(dataSet, eps, minPts, this.vecFactory.getVectorCollection(this.getVecIndexPairs(dataSet), this.dm), designations);
    }

    /**
     * Clusters the data with the given eps using the thread pool and returns
     * the clusters as lists.
     *
     * @param dataSet the data to cluster
     * @param eps the range used for every neighbourhood search
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param threadpool the pool the worker tasks are submitted to
     * @return the result of converting the designation array into cluster lists
     */
    public List<List<DataPoint>> cluster(DataSet dataSet, double eps, int minPts, ExecutorService threadpool) {
        return DBSCAN.createClusterListFromAssignmentArray(this.cluster(dataSet, eps, minPts, threadpool, null), dataSet);
    }

    /**
     * Clusters the data with the given eps using the thread pool.
     *
     * @param dataSet the data to cluster
     * @param eps the range used for every neighbourhood search
     * @param minPts minimum number of range-search results for a point to expand a cluster
     * @param threadpool the pool the worker tasks are submitted to
     * @param designations array to write into, or {@code null} to allocate one;
     *        any previous content is overwritten
     * @return the designation array
     */
    public int[] cluster(DataSet dataSet, double eps, int minPts, ExecutorService threadpool, int[] designations) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        return this.cluster(dataSet, eps, minPts, this.vecFactory.getVectorCollection(this.getVecIndexPairs(dataSet), this.dm), threadpool, designations);
    }

    /**
     * Single threaded driver: resets every designation, then tries to start a
     * cluster from each point that is still unclassified when it is reached.
     */
    private int[] cluster(DataSet dataSet, double eps, int minPts, VectorCollection<VecPaired<Vec, Integer>> vc, int[] pointCats) {
        if (pointCats == null) {
            pointCats = new int[dataSet.getSampleSize()];
        }
        Arrays.fill(pointCats, UNCLASSIFIED);
        int curClusterID = 0;
        for (int i = 0; i < pointCats.length; ++i) {
            // The id only advances when the point actually started a cluster.
            if (pointCats[i] == UNCLASSIFIED && this.expandCluster(pointCats, dataSet, i, curClusterID, eps, minPts, vc)) {
                ++curClusterID;
            }
        }
        return pointCats;
    }

    /**
     * Multi threaded driver. Range searches are performed by
     * {@code SystemInfo.LogicalCores} {@link ClusterWorker}s that take query
     * vectors from {@code sourceQ} and hand their results back through
     * {@code resultQ}; the designation array is only written by the calling thread.
     */
    private int[] cluster(DataSet dataSet, double eps, int minPts, VectorCollection<VecPaired<Vec, Integer>> vc, ExecutorService threadpool, int[] pointCats) {
        if (pointCats == null) {
            pointCats = new int[dataSet.getSampleSize()];
        }
        Arrays.fill(pointCats, UNCLASSIFIED);
        SynchronousQueue<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> resultQ = new SynchronousQueue<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>>();
        LinkedBlockingQueue<Vec> sourceQ = new LinkedBlockingQueue<Vec>();
        try {
            for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
                threadpool.submit(new ClusterWorker(vc, eps, resultQ, sourceQ));
            }
            int curClusterID = 0;
            for (int i = 0; i < pointCats.length; ++i) {
                if (pointCats[i] == UNCLASSIFIED && this.expandCluster(pointCats, dataSet, i, curClusterID, eps, minPts, vc, threadpool, resultQ, sourceQ)) {
                    ++curClusterID;
                }
            }
        }
        finally {
            // Always release the workers, even when expansion failed or the
            // thread was interrupted. offer() on the unbounded queue neither
            // blocks nor throws InterruptedException, unlike put().
            for (int i = 0; i < SystemInfo.LogicalCores; ++i) {
                sourceQ.offer(new DenseVector(0));
            }
        }
        return pointCats;
    }

    /**
     * Tries to grow cluster {@code clId} from {@code point} on the calling thread.
     *
     * @return {@code false} if the point had fewer than {@code minPts} results
     *         in range and was marked as noise, {@code true} if a cluster was started
     */
    private boolean expandCluster(int[] pointCats, DataSet dataSet, int point, int clId, double eps, int minPts, VectorCollection<VecPaired<Vec, Integer>> vc) {
        Vec queryPoint = dataSet.getDataPoint(point).getNumericalValues();
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> seeds = vc.search(queryPoint, eps);
        if (seeds.size() < minPts) {
            pointCats[point] = NOISE;
            return false;
        }
        pointCats[point] = clId;
        ArrayDeque<VecPaired<VecPaired<Vec, Integer>, Double>> workQue = new ArrayDeque<VecPaired<VecPaired<Vec, Integer>, Double>>(seeds);
        while (!workQue.isEmpty()) {
            VecPaired<VecPaired<Vec, Integer>, Double> currentP = workQue.poll();
            List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> results = vc.search(currentP, eps);
            if (results.size() < minPts) {
                // Not dense enough to extend the cluster through this point.
                continue;
            }
            for (VecPaired<VecPaired<Vec, Integer>, Double> resultP : results) {
                int resultPIndx = resultP.getVector().getPair();
                if (pointCats[resultPIndx] >= 0) {
                    continue;
                }
                // Only never-visited points are expanded further; points that
                // were marked as noise are claimed without being searched again.
                if (pointCats[resultPIndx] == UNCLASSIFIED) {
                    workQue.add(resultP);
                }
                pointCats[resultPIndx] = clId;
            }
        }
        return true;
    }

    /**
     * Tries to grow cluster {@code clId} from {@code point}, delegating the
     * follow-up range searches to the workers behind {@code sourceQ} and
     * {@code resultQ}. If the calling thread is interrupted the expansion is
     * abandoned and the interrupt flag is restored; designations written after
     * that are not reliable.
     * <p>
     * NOTE(review): a worker that dies from a runtime exception never delivers
     * its result, which leaves this method waiting in {@code resultQ.take()}.
     *
     * @return {@code false} if the point had fewer than {@code minPts} results
     *         in range and was marked as noise, {@code true} otherwise
     */
    private boolean expandCluster(int[] pointCats, DataSet dataSet, int point, int clId, double eps, int minPts, VectorCollection<VecPaired<Vec, Integer>> vc, ExecutorService threadpool, BlockingQueue<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> resultQ, BlockingQueue<Vec> sourceQ) {
        Vec queryPoint = dataSet.getDataPoint(point).getNumericalValues();
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> seeds = vc.search(queryPoint, eps);
        if (seeds.size() < minPts) {
            pointCats[point] = NOISE;
            return false;
        }
        try {
            pointCats[point] = clId;
            // Number of queries handed to the workers whose result has not been taken yet.
            int out = seeds.size();
            for (VecPaired<VecPaired<Vec, Integer>, Double> seed : seeds) {
                sourceQ.put(seed.getVector().getVector());
            }
            while (out > 0) {
                List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> results = resultQ.take();
                --out;
                if (results.size() < minPts) {
                    continue;
                }
                for (VecPaired<VecPaired<Vec, Integer>, Double> resultP : results) {
                    int resultPIndx = resultP.getVector().getPair();
                    if (pointCats[resultPIndx] >= 0) {
                        continue;
                    }
                    if (pointCats[resultPIndx] == UNCLASSIFIED) {
                        sourceQ.put(resultP.getVector().getVector());
                        ++out;
                    }
                    pointCats[resultPIndx] = clId;
                }
            }
        }
        catch (InterruptedException ex) {
            // Keep the interruption visible to the caller instead of dropping it.
            Thread.currentThread().interrupt();
        }
        return true;
    }

    /**
     * Answers range queries: takes query vectors from {@code sourceQ}, searches
     * the collection with the fixed range and puts each result list on
     * {@code resultQ}. A vector of length 0 stops the worker.
     * <p>
     * NOTE(review): a worker blocked in {@code resultQ.put} only returns once
     * its result is taken or its thread is interrupted.
     */
    private static final class ClusterWorker
    implements Runnable {
        private final VectorCollection<VecPaired<Vec, Integer>> vc;
        private final double range;
        private final BlockingQueue<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> resultQ;
        private final BlockingQueue<Vec> sourceQ;

        public ClusterWorker(VectorCollection<VecPaired<Vec, Integer>> vc, double range, BlockingQueue<List<? extends VecPaired<VecPaired<Vec, Integer>, Double>>> resultQ, BlockingQueue<Vec> sourceQ) {
            this.vc = vc;
            this.range = range;
            this.resultQ = resultQ;
            this.sourceQ = sourceQ;
        }

        @Override
        public void run() {
            try {
                Vec searchPoint;
                while ((searchPoint = this.sourceQ.take()).length() != 0) {
                    this.resultQ.put(this.vc.search(searchPoint, this.range));
                }
            }
            catch (InterruptedException ex) {
                Logger.getLogger(DBSCAN.class.getName()).log(Level.SEVERE, null, ex);
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Accumulates the eps statistic for the data points taken from the queue.
     * A data point with neither numerical nor categorical values stops the
     * worker, which then returns what it has accumulated.
     */
    private static final class StatsWorker
    implements Callable<OnLineStatistics> {
        final BlockingQueue<DataPoint> queue;
        final VectorCollection<VecPaired<Vec, Integer>> vc;
        final int minPts;

        public StatsWorker(BlockingQueue<DataPoint> queue, VectorCollection<VecPaired<Vec, Integer>> vc, int minPts) {
            this.queue = queue;
            this.vc = vc;
            this.minPts = minPts;
        }

        @Override
        public OnLineStatistics call() throws Exception {
            DataPoint dp;
            OnLineStatistics stats = new OnLineStatistics();
            while ((dp = this.queue.take()).numNumericalValues() != 0 || dp.numCategoricalValues() != 0) {
                stats.add(this.vc.search(dp.getNumericalValues(), this.minPts + 1).get(this.minPts).getPair());
            }
            return stats;
        }
    }
}

