/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.KMeans;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * K-Means clustering using Hamerly's bound-based acceleration of Lloyd's algorithm.
 * <p>
 * Each point {@code i} keeps an upper bound {@code u[i]} on the distance to its assigned center and a lower
 * bound {@code l[i]} on the distance to its second-closest center. Each center {@code j} keeps {@code s[j]},
 * the distance to its nearest other center. A point whose upper bound does not exceed
 * {@code max(s[a(i)]/2, l[i])} cannot change its assignment, so its distances are not recomputed.
 */
public class HamerlyKMeans extends KMeans {
    private static final long serialVersionUID = -4960453870335145091L;

    /**
     * Creates a new clusterer.
     *
     * @param dm            the distance metric to cluster with
     * @param seedSelection the method used to choose the initial centers
     * @param rand          the source of randomness used for seeding
     */
    public HamerlyKMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection, Random rand) {
        super(dm, seedSelection, rand);
    }

    /**
     * Creates a new clusterer that seeds with a fresh {@code XORWOW} random source.
     *
     * @param dm            the distance metric to cluster with
     * @param seedSelection the method used to choose the initial centers
     */
    public HamerlyKMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection) {
        this(dm, seedSelection, new XORWOW());
    }

    /**
     * Creates a new clusterer using the Euclidean distance and {@code KPP} seed selection.
     */
    public HamerlyKMeans() {
        this(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the object to copy
     */
    public HamerlyKMeans(HamerlyKMeans toCopy) {
        super(toCopy);
    }

    /**
     * Runs Hamerly's k-means on {@code dataSet}.
     *
     * @param dataSet     the data to cluster
     * @param accelCache  the metric's acceleration cache for the data vectors, or {@code null} to compute it here
     * @param k           the number of clusters
     * @param means       the initial centers if it already holds exactly {@code k} vectors; otherwise it is
     *                    cleared and refilled by seed selection. The final centers are left in this list.
     * @param assignment  output: the index of the center each data point is assigned to
     * @param exactTotal  if {@code true}, the returned error uses freshly computed distances instead of the
     *                    upper bounds
     * @param threadpool  the executor for parallel work, or {@code null} to run sequentially
     * @param returnError whether to compute and return the clustering error
     * @return the sum of squared distances from each point to its assigned center when {@code returnError}
     *         is set, otherwise {@code 0.0}
     */
    @Override
    protected double cluster(final DataSet dataSet, List<Double> accelCache, final int k, final List<Vec> means, final int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError) {
        final int N = dataSet.getSampleSize();
        final int D = dataSet.getNumNumericalVars();
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, threadpool);
        final List<Vec> X = dataSet.getDataVectors();
        final boolean sequentialSetup = threadpool == null || threadpool instanceof FakeExecutor;

        final List<Double> distAccel;
        if (accelCache != null) {
            distAccel = accelCache;
        } else if (sequentialSetup) {
            distAccel = this.dm.getAccelerationCache(X);
        } else {
            distAccel = this.dm.getAccelerationCache(X, threadpool);
        }

        if (means.size() != k) {
            means.clear();
            if (sequentialSetup) {
                means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccel, this.rand, this.seedSelection));
            } else {
                means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccel, this.rand, this.seedSelection, threadpool));
            }
        }
        // Centers are overwritten in place by moveCenters, so every center must be a dense vector.
        for (int i = 0; i < means.size(); ++i) {
            if (means.get(i).isSparse()) {
                means.set(i, new DenseVector(means.get(i)));
            }
        }

        final List<List<Double>> meanQI = new ArrayList<List<Double>>(k);
        final Vec[] cP = new Vec[k];
        final Vec[] tmpVecs = new Vec[k];
        final AtomicLongArray q = new AtomicLongArray(k);
        final double[] p = new double[k];
        final double[] s = new double[k];
        final double[] u = new double[N];
        final double[] l = new double[N];
        // Per-thread accumulators so workers can sum their center changes without locking on every point.
        final ThreadLocal<Vec[]> localDeltas = new ThreadLocal<Vec[]>() {
            @Override
            protected Vec[] initialValue() {
                Vec[] toRet = new Vec[k];
                for (int i = 0; i < k; ++i) {
                    toRet[i] = new DenseVector(D);
                }
                return toRet;
            }
        };

        this.Initialize(dataSet, q, means, tmpVecs, cP, u, l, assignment, threadpool, localDeltas, X, distAccel, meanQI);

        final AtomicInteger updates = new AtomicInteger(N);
        while (updates.get() > 0) {
            this.moveCenters(means, tmpVecs, cP, q, p, meanQI);
            this.UpdateBounds(p, assignment, u, l);
            updates.set(0);
            this.updateS(s, means, threadpool, meanQI);

            if (threadpool == null) {
                int localUpdates = 0;
                for (int i = 0; i < N; ++i) {
                    localUpdates += this.mainLoopWork(dataSet, i, s, assignment, u, l, q, cP, X, distAccel, means, meanQI);
                }
                updates.set(localUpdates);
                continue;
            }

            final int tasks = SystemInfo.LogicalCores;
            final CountDownLatch latch = new CountDownLatch(tasks);
            final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
            for (int id = 0; id < tasks; ++id) {
                final int ID = id;
                threadpool.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            Vec[] deltas = localDeltas.get();
                            int localUpdates = 0;
                            for (int i = ID; i < N; i += tasks) {
                                localUpdates += HamerlyKMeans.this.mainLoopWork(dataSet, i, s, assignment, u, l, q, deltas, X, distAccel, means, meanQI);
                            }
                            // deltas only change when some point was reassigned, so nothing to merge otherwise.
                            if (localUpdates > 0) {
                                updates.getAndAdd(localUpdates);
                                for (int j = 0; j < cP.length; ++j) {
                                    synchronized (cP[j]) {
                                        cP[j].mutableAdd(deltas[j]);
                                    }
                                    deltas[j].zeroOut();
                                }
                            }
                        } catch (Throwable t) {
                            failure.compareAndSet(null, t);
                        } finally {
                            latch.countDown();
                        }
                    }
                });
            }
            awaitTasks(latch, failure);
        }

        if (!returnError) {
            return 0.0;
        }
        this.nearestCentroidDist = this.saveCentroidDistance ? new double[N] : null;
        double totalDistance = 0.0;
        for (int i = 0; i < N; ++i) {
            // u[i] is an upper bound on the distance; recompute when an exact total is requested.
            double dist = exactTotal
                    ? this.dm.dist(i, means.get(assignment[i]), meanQI.get(assignment[i]), X, distAccel)
                    : u[i];
            totalDistance += dist * dist;
            if (this.saveCentroidDistance) {
                this.nearestCentroidDist[i] = dist;
            }
        }
        return totalDistance;
    }

    /**
     * Re-examines the assignment of point {@code i} for one iteration.
     * <p>
     * If the bounds cannot rule out a change, the upper bound is first tightened to the exact distance to the
     * current center. If that still does not rule it out, distances to all centers are recomputed. When the
     * nearest center changes, the cluster counts in {@code q} are adjusted and {@code x} is moved between the
     * two centers' entries in {@code deltas}.
     *
     * @return 1 if the point's assignment changed, otherwise 0
     */
    private int mainLoopWork(DataSet dataSet, int i, double[] s, int[] assignment, double[] u, double[] l, AtomicLongArray q, Vec[] deltas, List<Vec> X, List<Double> distAccel, List<Vec> means, List<List<Double>> meanQI) {
        int oldCenter = assignment[i];
        double bound = Math.max(s[oldCenter] / 2.0, l[i]);
        if (!(u[i] > bound)) {
            return 0;
        }
        u[i] = this.dm.dist(i, means.get(oldCenter), meanQI.get(oldCenter), X, distAccel);
        if (!(u[i] > bound)) {
            return 0;
        }
        Vec x = X.get(i);
        int newCenter = this.PointAllCtrs(x, i, means, assignment, u, l, X, distAccel, meanQI);
        if (newCenter == oldCenter) {
            return 0;
        }
        q.decrementAndGet(oldCenter);
        q.incrementAndGet(newCenter);
        deltas[oldCenter].mutableSubtract(x);
        deltas[newCenter].mutableAdd(x);
        return 1;
    }

    /**
     * Fills {@code s[j]} with the distance from center {@code j} to its nearest other center.
     * <p>
     * Every pair (j, jp) is evaluated once and its distance lowers both {@code s[j]} and {@code s[jp]}. With a
     * single center, {@code s[0]} stays at {@code Double.MAX_VALUE}.
     */
    private void updateS(final double[] s, final List<Vec> means, ExecutorService threadpool, List<List<Double>> meanQIs) {
        final int k = means.size();
        Arrays.fill(s, Double.MAX_VALUE);
        // The pairwise dist(int, int, ...) overload takes one flat cache, so concatenate the per-center query info.
        final DoubleList meanCache = meanQIs.get(0).isEmpty() ? null : new DoubleList(meanQIs.size());
        if (meanCache != null) {
            for (List<Double> qi : meanQIs) {
                meanCache.addAll(qi);
            }
        }

        if (threadpool == null) {
            for (int j = 0; j < k; ++j) {
                for (int jp = j + 1; jp < k; ++jp) {
                    double d = this.dm.dist(j, jp, means, meanCache);
                    if (d < s[j]) {
                        s[j] = d;
                    }
                    if (d < s[jp]) {
                        s[jp] = d;
                    }
                }
            }
            return;
        }

        final CountDownLatch latch = new CountDownLatch(k);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        for (int j = 0; j < k; ++j) {
            final int J = j;
            threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        // Compute this row outside the lock, then fold it into both ends of every pair.
                        double[] row = new double[k - J - 1];
                        double rowMin = Double.MAX_VALUE;
                        for (int jp = J + 1; jp < k; ++jp) {
                            double d = HamerlyKMeans.this.dm.dist(J, jp, means, meanCache);
                            row[jp - J - 1] = d;
                            if (d < rowMin) {
                                rowMin = d;
                            }
                        }
                        synchronized (s) {
                            if (rowMin < s[J]) {
                                s[J] = rowMin;
                            }
                            for (int jp = J + 1; jp < k; ++jp) {
                                double d = row[jp - J - 1];
                                if (d < s[jp]) {
                                    s[jp] = d;
                                }
                            }
                        }
                    } catch (Throwable t) {
                        failure.compareAndSet(null, t);
                    } finally {
                        latch.countDown();
                    }
                }
            });
        }
        awaitTasks(latch, failure);
    }

    /**
     * Sets up the per-center accumulators and the initial assignment and bounds.
     * <p>
     * For each center it allocates a zero vector in {@code cP} and a scratch copy in {@code tmp}, and records
     * the metric's query info (an empty list when the metric does not support acceleration). It then assigns
     * every point to its nearest center, counting members in {@code q} and summing members into {@code cP}.
     */
    private void Initialize(DataSet d, final AtomicLongArray q, final List<Vec> means, Vec[] tmp, final Vec[] cP, final double[] u, final double[] l, final int[] a, ExecutorService threadpool, final ThreadLocal<Vec[]> localDeltas, final List<Vec> X, final List<Double> distAccel, final List<List<Double>> meanQI) {
        for (int j = 0; j < means.size(); ++j) {
            cP[j] = new DenseVector(means.get(j).length());
            tmp[j] = cP[j].clone();
            if (this.dm.supportsAcceleration()) {
                meanQI.add(this.dm.getQueryInfo(means.get(j)));
            } else {
                meanQI.add(Collections.<Double>emptyList());
            }
        }

        if (threadpool == null) {
            for (int i = 0; i < u.length; ++i) {
                Vec x = X.get(i);
                int j = this.PointAllCtrs(x, i, means, a, u, l, X, distAccel, meanQI);
                q.incrementAndGet(j);
                cP[j].mutableAdd(x);
            }
            return;
        }

        final int tasks = SystemInfo.LogicalCores;
        final CountDownLatch latch = new CountDownLatch(tasks);
        final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
        for (int id = 0; id < tasks; ++id) {
            final int ID = id;
            threadpool.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        Vec[] deltas = localDeltas.get();
                        for (int i = ID; i < u.length; i += tasks) {
                            Vec x = X.get(i);
                            int j = HamerlyKMeans.this.PointAllCtrs(x, i, means, a, u, l, X, distAccel, meanQI);
                            q.incrementAndGet(j);
                            deltas[j].mutableAdd(x);
                        }
                        for (int j = 0; j < cP.length; ++j) {
                            synchronized (cP[j]) {
                                cP[j].mutableAdd(deltas[j]);
                            }
                            deltas[j].zeroOut();
                        }
                    } catch (Throwable t) {
                        failure.compareAndSet(null, t);
                    } finally {
                        latch.countDown();
                    }
                }
            });
        }
        awaitTasks(latch, failure);
    }

    /**
     * Computes the distance from point {@code i} to every center and records the result.
     * <p>
     * The nearest center's index goes into {@code a[i]} and its distance into {@code u[i]}. The
     * second-nearest distance goes into {@code l[i]}, which is {@code +Infinity} when there is only one
     * center.
     *
     * @return the index of the nearest center
     */
    private int PointAllCtrs(Vec x, int i, List<Vec> means, int[] a, double[] u, double[] l, List<Vec> X, List<Double> distAccel, List<List<Double>> meanQI) {
        double secondLowest = Double.POSITIVE_INFINITY;
        double lowest = Double.MAX_VALUE;
        int lIndex = -1;
        for (int j = 0; j < means.size(); ++j) {
            double dist = this.dm.dist(i, means.get(j), meanQI.get(j), X, distAccel);
            if (!(dist < secondLowest)) {
                continue;
            }
            if (dist < lowest) {
                secondLowest = lowest;
                lowest = dist;
                lIndex = j;
            } else {
                secondLowest = dist;
            }
        }
        a[i] = lIndex;
        u[i] = lowest;
        l[i] = secondLowest;
        return lIndex;
    }

    /**
     * Moves every center to the mean of its members and refreshes its query info.
     * <p>
     * The new position is {@code cP[j] / q[j]}. An empty cluster is moved to the zero vector and its
     * accumulator is cleared. {@code p[j]} receives the distance each center moved, and the center vector in
     * {@code means} is overwritten in place.
     */
    private void moveCenters(List<Vec> means, Vec[] tmpSpace, Vec[] cP, AtomicLongArray q, double[] p, List<List<Double>> meanQI) {
        for (int j = 0; j < means.size(); ++j) {
            long count = q.get(j);
            if (count > 0L) {
                cP[j].copyTo(tmpSpace[j]);
                tmpSpace[j].mutableDivide(count);
            } else {
                cP[j].zeroOut();
                tmpSpace[j].zeroOut();
            }
            p[j] = this.dm.dist(means.get(j), tmpSpace[j]);
            tmpSpace[j].copyTo(means.get(j));
            if (this.dm.supportsAcceleration()) {
                meanQI.set(j, this.dm.getQueryInfo(means.get(j)));
            }
        }
    }

    /**
     * Loosens each point's bounds by how far the centers moved.
     * <p>
     * The upper bound grows by the movement of the point's own center. The lower bound shrinks by the largest
     * movement of any other center: the second-largest movement when the point's center moved the most,
     * otherwise the largest.
     */
    private void UpdateBounds(double[] p, int[] a, double[] u, double[] l) {
        double secondHighest = Double.NEGATIVE_INFINITY;
        int rP = -1;
        double highest = -Double.MAX_VALUE;
        int r = -1;
        for (int j = 0; j < p.length; ++j) {
            double dist = p[j];
            if (!(dist > secondHighest)) {
                continue;
            }
            if (dist > highest) {
                secondHighest = highest;
                rP = r;
                highest = dist;
                r = j;
            } else {
                secondHighest = dist;
                rP = j;
            }
        }
        // With a single center there is no "other" center (rP == -1); its lower bound is +Infinity anyway.
        double secondMove = rP >= 0 ? p[rP] : 0.0;
        for (int i = 0; i < u.length; ++i) {
            int j = a[i];
            u[i] += p[j];
            if (j == r) {
                l[i] -= secondMove;
            } else {
                l[i] -= p[r];
            }
        }
    }

    /**
     * Waits for every task counted by {@code latch} to finish. It then rethrows the first failure any task
     * recorded in {@code failure}.
     * <p>
     * Waiting continues across interrupts because the tasks still write to arrays the caller reads next. The
     * thread's interrupt status is restored before returning.
     */
    private static void awaitTasks(CountDownLatch latch, AtomicReference<Throwable> failure) {
        boolean interrupted = false;
        while (true) {
            try {
                latch.await();
                break;
            } catch (InterruptedException ex) {
                Logger.getLogger(HamerlyKMeans.class.getName()).log(Level.SEVERE, null, ex);
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        Throwable t = failure.get();
        if (t == null) {
            return;
        }
        if (t instanceof RuntimeException) {
            throw (RuntimeException) t;
        }
        if (t instanceof Error) {
            throw (Error) t;
        }
        throw new RuntimeException(t);
    }

    @Override
    public HamerlyKMeans clone() {
        return new HamerlyKMeans(this);
    }
}

