/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.clustering.kmeans.KMeans;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;

/**
 * Clustering that chooses the number of clusters itself: it starts from an
 * initial set of means and repeatedly tries to split each cluster in two,
 * keeping a split only when the BIC score of the two-cluster model is not
 * worse than the BIC score of the single-cluster model. All actual k-means
 * work is delegated to a wrapped {@link KMeans} instance.
 */
public class XMeans
extends KMeans {
    private static final long serialVersionUID = -2577160317892141870L;
    /** When true, a cluster whose split test failed is never tested again in that run. */
    private boolean stopAfterFail = false;
    /** When true, k-means is re-run on the whole data set after every round of split attempts. */
    private boolean iterativeRefine = true;
    /** Clusters with fewer points than this are never considered for splitting. */
    private int minClusterSize = 25;
    /** The k-means implementation used for the initial, sub-cluster and refinement runs. */
    private KMeans kmeans;

    /**
     * Creates an X-Means clusterer backed by a default {@link HamerlyKMeans}.
     */
    public XMeans() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates an X-Means clusterer that delegates to the given k-means object.
     * The object is kept by reference (not copied) and is reconfigured here so
     * that it saves centroid distances and stores its means.
     *
     * @param kmeans the k-means implementation to drive
     */
    public XMeans(KMeans kmeans) {
        super(kmeans.dm, kmeans.seedSelection, kmeans.rand);
        this.kmeans = kmeans;
        this.kmeans.saveCentroidDistance = true;
        this.kmeans.setStoreMeans(true);
    }

    /**
     * Copy constructor. The wrapped k-means object is cloned.
     *
     * @param toCopy the instance to copy
     */
    public XMeans(XMeans toCopy) {
        super(toCopy);
        this.kmeans = toCopy.kmeans.clone();
        this.stopAfterFail = toCopy.stopAfterFail;
        this.iterativeRefine = toCopy.iterativeRefine;
        this.minClusterSize = toCopy.minClusterSize;
    }

    /**
     * @param stopAfterFail {@code true} to stop re-testing a cluster once its
     * split test has failed
     */
    public void setStopAfterFail(boolean stopAfterFail) {
        this.stopAfterFail = stopAfterFail;
    }

    /**
     * @return {@code true} if a cluster is no longer tested after a failed split
     */
    public boolean isStopAfterFail() {
        return this.stopAfterFail;
    }

    /**
     * Sets the smallest cluster size for which a split is attempted.
     *
     * @param minClusterSize the minimum size, at least 2
     * @throws IllegalArgumentException if {@code minClusterSize} is below 2
     */
    public void setMinClusterSize(int minClusterSize) {
        if (minClusterSize < 2) {
            throw new IllegalArgumentException("min cluster size that could be split is 2, not " + minClusterSize);
        }
        this.minClusterSize = minClusterSize;
    }

    /**
     * @return the smallest cluster size for which a split is attempted
     */
    public int getMinClusterSize() {
        return this.minClusterSize;
    }

    /**
     * @param refineCenters {@code true} to re-run k-means over the whole data
     * set after every round of split attempts; {@code false} to run it only
     * once after all splitting has finished
     */
    public void setIterativeRefine(boolean refineCenters) {
        this.iterativeRefine = refineCenters;
    }

    /**
     * @return {@code true} if k-means is re-run after every round of splits
     */
    public boolean getIterativeRefine() {
        return this.iterativeRefine;
    }

    /**
     * Clusters with a lower bound of 2 clusters and an upper bound of
     * {@code max(sampleSize / 20, 10)} clusters.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, 2, Math.max(dataSet.getSampleSize() / 20, 10), designations);
    }

    /**
     * Clusters with a lower bound of 2 clusters and an upper bound of
     * {@code max(sampleSize / 20, 10)} clusters, using the given thread pool.
     */
    @Override
    public int[] cluster(DataSet dataSet, ExecutorService threadpool, int[] designations) {
        return this.cluster(dataSet, 2, Math.max(dataSet.getSampleSize() / 20, 10), threadpool, designations);
    }

    /**
     * Number of free parameters charged in the BIC penalty term, computed as
     * {@code (K - 1) + D * K + 1}.
     *
     * @param K the number of clusters in the model
     * @param D the number of numeric dimensions
     * @return the parameter count
     */
    private static int freeParameters(int K, int D) {
        return K - 1 + D * K + 1;
    }

    /**
     * BIC score of keeping a cluster of {@code n} points as a single cluster.
     * Products are formed in {@code double} so that {@code n * D} cannot
     * overflow {@code int} for large clusters.
     *
     * @param n the number of points in the cluster
     * @param D the number of numeric dimensions
     * @param sumSqDist the sum of squared distances of the points to the cluster mean
     * @return the score; larger is better
     */
    private static double unsplitBic(int n, int D, double sumSqDist) {
        double scaled = Math.PI * 2 * sumSqDist / ((double) D * (n - 1));
        return -((double) n * D) / 2.0 * Math.log(scaled)
            - (double) D / 2.0 * (double) (n - 1)
            - (double) XMeans.freeParameters(1, D) / 2.0 * Math.log(n);
    }

    /**
     * BIC score of replacing a cluster of {@code n} points by two children.
     * Products are formed in {@code double} so that {@code n * D} cannot
     * overflow {@code int} for large clusters.
     *
     * @param n the number of points in the parent cluster
     * @param D the number of numeric dimensions
     * @param sizeC1 the number of points in the first child
     * @param sizeC2 the number of points in the second child
     * @param sumSqDist the sum of squared distances of all points to their child mean
     * @return the score; larger is better. NaN when a child is empty.
     */
    private static double splitBic(int n, int D, int sizeC1, int sizeC2, double sumSqDist) {
        double variance = sumSqDist / ((double) D * (n - 2));
        return (double) sizeC1 * Math.log(sizeC1) + (double) sizeC2 * Math.log(sizeC2) - (double) n * Math.log(n)
            - (double) n * D / 2.0 * Math.log(Math.PI * 2 * variance)
            - (double) D / 2.0 * (double) (n - 2)
            - (double) XMeans.freeParameters(2, D) / 2.0 * Math.log(n);
    }

    /**
     * Resets {@code localVar} and fills it with, per cluster, the sum of the
     * squared nearest-centroid distances saved by the last k-means run.
     *
     * @param localVar the per-cluster accumulator to overwrite
     * @param designations the cluster index of every point
     * @param count the number of points to accumulate
     */
    private void recomputeLocalVar(double[] localVar, int[] designations, int count) {
        Arrays.fill(localVar, 0.0);
        double[] nearest = this.kmeans.nearestCentroidDist;
        for (int i = 0; i < count; ++i) {
            localVar[designations[i]] += Math.pow(nearest[i], 2.0);
        }
    }

    /**
     * Clusters the data, searching for a number of clusters between
     * {@code lowK} and {@code highK}.
     *
     * @param dataSet the data to cluster
     * @param lowK the number of clusters to start from; values below 2 start
     * from a single cluster located at the mean of the data
     * @param highK the largest number of clusters that may be produced
     * @param threadpool the source of threads, passed through to the wrapped
     * k-means and the distance metric
     * @param designations the array to store cluster indices in; a new array
     * is allocated when it is {@code null} or shorter than the sample size
     * @return the array holding the cluster index of every point
     * @throws IllegalArgumentException if {@code highK < 1} or {@code lowK > highK}
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, ExecutorService threadpool, int[] designations) {
        // The per-cluster bookkeeping below is sized by highK; reject bounds that would overrun it.
        if (highK < 1 || lowK > highK) {
            throw new IllegalArgumentException("highK must be at least 1 and not smaller than lowK, but lowK = " + lowK + " and highK = " + highK);
        }
        final int N = dataSet.getSampleSize();
        final int D = dataSet.getNumNumericalVars();
        if (designations == null || designations.length < N) {
            designations = new int[N];
        }
        List<Vec> data = dataSet.getDataVectors();
        List<Double> accelCache = this.dm.getAccelerationCache(data, threadpool);
        // Sum of squared point-to-mean distances for every cluster currently kept.
        double[] localVar = new double[highK];

        if (lowK >= 2) {
            this.means = new ArrayList<Vec>();
            this.kmeans.cluster(dataSet, accelCache, lowK, this.means, designations, true, threadpool, true);
            this.recomputeLocalVar(localVar, designations, data.size());
        } else {
            // Start from one cluster that owns every point and sits at the overall mean.
            Arrays.fill(designations, 0);
            Vec overallMean = MatrixStatistics.meanVector(dataSet);
            this.means = new ArrayList<Vec>(Arrays.asList(overallMean));
            List<Double> qi = this.dm.getQueryInfo(overallMean);
            for (int i = 0; i < data.size(); ++i) {
                localVar[0] += Math.pow(this.dm.dist(i, overallMean, qi, data, accelCache), 2.0);
            }
        }

        // subS is filled by getDatapointsFromCluster and is used below to map a
        // position in the sub data set back to an index into designations;
        // subC holds the 0/1 label of each sub data set point.
        int[] subS = new int[designations.length];
        int[] subC = new int[designations.length];
        // Marks clusters that must not be tested again (only used with stopAfterFail).
        List<Boolean> dontRedo = new ArrayList<Boolean>(Collections.nCopies(this.means.size(), false));

        int origMeans;
        do {
            origMeans = this.means.size();
            for (int c = 0; c < origMeans; ++c) {
                if (dontRedo.get(c)) {
                    continue;
                }
                List<DataPoint> X = XMeans.getDatapointsFromCluster(c, designations, dataSet, subS);
                // Size of this cluster only; not the same as N.
                final int n = X.size();
                if (n < this.minClusterSize || this.means.size() == highK) {
                    continue;
                }
                // NOTE(review): the labels from this first call look like they are
                // overwritten by the second call below; it is kept so that behaviour
                // (including any random state consumed) is unchanged - confirm
                // whether it can be dropped.
                subC = this.kmeans.cluster((DataSet)new SimpleDataSet(X), 2, threadpool, subC);
                List<Vec> subMean = new ArrayList<Vec>(2);
                this.kmeans.cluster(new SimpleDataSet(X), null, 2, subMean, subC, true, threadpool, true);
                double[] nearDist = this.kmeans.nearestCentroidDist;
                Vec c1 = subMean.get(0);
                Vec c2 = subMean.get(1);

                double totalSqDist = 0.0;
                // Per-child sums, needed so the bookkeeping stays correct after a split.
                double sqDistC1 = 0.0;
                double sqDistC2 = 0.0;
                int sizeC1 = 0;
                for (int i = 0; i < n; ++i) {
                    double sq = Math.pow(nearDist[i], 2.0);
                    totalSqDist += sq;
                    if (subC[i] == 0) {
                        sqDistC1 += sq;
                        ++sizeC1;
                    } else if (subC[i] == 1) {
                        sqDistC2 += sq;
                    }
                }
                int sizeC2 = n - sizeC1;

                double localNewBic = XMeans.splitBic(n, D, sizeC1, sizeC2, totalSqDist);
                double localOldBic = XMeans.unsplitBic(n, D, localVar[c]);

                // An empty child makes the split score NaN, and "old > NaN" is false,
                // which would wrongly accept the split and add an empty cluster.
                boolean emptyChild = sizeC1 == 0 || sizeC2 == 0;
                if (emptyChild || localOldBic > localNewBic) {
                    if (this.stopAfterFail) {
                        dontRedo.set(c, true);
                    }
                    continue;
                }

                // Accept the split: child 0 keeps index c, child 1 gets a new index.
                int newIndex = this.means.size();
                for (int i = 0; i < n; ++i) {
                    if (subC[i] == 1) {
                        designations[subS[i]] = newIndex;
                    }
                }
                // Clone because the wrapped k-means may reuse the vectors it returned.
                this.means.set(c, c1.clone());
                this.means.add(c2.clone());
                // Keep the squared-error bookkeeping in step with the new clusters so
                // later rounds score them correctly even when no refinement pass
                // recomputes it (iterativeRefine == false).
                localVar[c] = sqDistC1;
                localVar[newIndex] = sqDistC2;
                dontRedo.add(false);
            }

            if (this.iterativeRefine && this.means.size() > 1) {
                // Refine all current means against the full data set, then rebuild the bookkeeping.
                this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, true, threadpool, true);
                this.recomputeLocalVar(localVar, designations, data.size());
            }
        } while (origMeans < this.means.size());

        if (!this.iterativeRefine) {
            // No refinement happened between rounds, so run k-means once from the final means.
            this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, false, threadpool, false);
        }
        return designations;
    }

    /**
     * Same as the thread pool variant, called with a {@code null} thread pool.
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        return this.cluster(dataSet, lowK, highK, null, designations);
    }

    /**
     * @return the iteration limit of the wrapped k-means object
     */
    @Override
    public int getIterationLimit() {
        return this.kmeans.getIterationLimit();
    }

    /**
     * Sets the iteration limit on the wrapped k-means object.
     *
     * @param iterLimit the new iteration limit
     */
    @Override
    public void setIterationLimit(int iterLimit) {
        this.kmeans.setIterationLimit(iterLimit);
    }

    /**
     * Sets the seed selection method on the wrapped k-means object.
     *
     * @param seedSelection the seed selection method to use
     */
    @Override
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        // kmeans can still be null here - presumably when this is invoked from the
        // superclass constructor before the field is assigned; TODO confirm.
        if (this.kmeans != null) {
            this.kmeans.setSeedSelection(seedSelection);
        }
    }

    /**
     * @return the seed selection method of the wrapped k-means object
     */
    @Override
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.kmeans.getSeedSelection();
    }

    /**
     * Delegates the fixed-k clustering step to the wrapped k-means object.
     */
    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError) {
        return this.kmeans.cluster(dataSet, accelCache, k, means, assignment, exactTotal, threadpool, returnError);
    }

    /**
     * @return a copy of this object made with the copy constructor
     */
    @Override
    public XMeans clone() {
        return new XMeans(this);
    }
}

