/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.ElkanKMeans;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;

/**
 * Mixture of multivariate Gaussians fitted with Expectation-Maximization.
 * <p>
 * A fit first invokes the superclass {@code cluster} routine and uses the
 * assignments and means it leaves behind to build initial per-cluster
 * covariances and mixing weights. It then alternates E and M steps until the
 * absolute change in total log-likelihood falls below {@code tolerance}.
 * The fitted model can be queried as a {@link MultivariateDistribution}.
 */
public class EMGaussianMixture
extends ElkanKMeans
implements MultivariateDistribution {
    private static final long serialVersionUID = 2606159910420221662L;
    private static final Logger LOG = Logger.getLogger(EMGaussianMixture.class.getName());

    /** One Gaussian per mixture component; {@code null} until a fit has been run. */
    private List<NormalM> gaussians;
    /** Mixing weight of each component, index-aligned with {@link #gaussians}. */
    private double[] a_k;
    /** EM stops once the absolute change in log-likelihood is below this value. */
    private double tolerance = 0.001;

    /**
     * Forwards all three arguments to the matching superclass constructor.
     *
     * @param dm distance metric handed to the superclass
     * @param rand source of randomness handed to the superclass
     * @param seedSelection seed selection method handed to the superclass
     */
    public EMGaussianMixture(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, rand, seedSelection);
    }

    /**
     * Forwards both arguments to the matching superclass constructor.
     *
     * @param dm distance metric handed to the superclass
     * @param rand source of randomness handed to the superclass
     */
    public EMGaussianMixture(DistanceMetric dm, Random rand) {
        super(dm, rand);
    }

    /**
     * Forwards the metric to the matching superclass constructor.
     *
     * @param dm distance metric handed to the superclass
     */
    public EMGaussianMixture(DistanceMetric dm) {
        super(dm);
    }

    /** Creates an unfitted mixture using the superclass no-argument constructor. */
    public EMGaussianMixture() {
    }

    /**
     * Copy constructor. Clones the fitted Gaussians and mixing weights (when
     * present) and copies the iteration limit and tolerance. The superclass
     * part is initialised through its no-argument constructor.
     *
     * @param gm the mixture to duplicate
     */
    public EMGaussianMixture(EMGaussianMixture gm) {
        if (gm.gaussians != null && !gm.gaussians.isEmpty()) {
            this.gaussians = new ArrayList<NormalM>(gm.gaussians.size());
            for (NormalM gaussian : gm.gaussians) {
                this.gaussians.add(gaussian.clone());
            }
        }
        if (gm.a_k != null) {
            this.a_k = Arrays.copyOf(gm.a_k, gm.a_k.length);
        }
        this.MaxIterLimit = gm.MaxIterLimit;
        this.tolerance = gm.tolerance;
    }

    /**
     * Builds a mixture directly from components and weights. The first
     * {@code a_k.length} Gaussians are cloned and the weights are copied.
     *
     * @param gaussians the component distributions
     * @param a_k the mixing weight of each component
     * @param tolerance the convergence tolerance to use for later fits
     */
    private EMGaussianMixture(List<NormalM> gaussians, double[] a_k, double tolerance) {
        this.gaussians = new ArrayList<NormalM>(a_k.length);
        for (int i = 0; i < a_k.length; ++i) {
            this.gaussians.add(gaussians.get(i).clone());
        }
        this.a_k = Arrays.copyOf(a_k, a_k.length);
        // Previously the parameter was accepted but silently dropped.
        this.tolerance = tolerance;
    }

    /**
     * Runs the superclass clustering, derives an initial covariance matrix and
     * mixing weight for every cluster from the resulting hard assignments, and
     * then refines the model with EM via {@link #clusterCompute}.
     *
     * @param returnError not consulted here; the superclass is always called with {@code false}
     * @return the value returned by {@link #clusterCompute}
     */
    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, int K, List<Vec> means, int[] assignment, boolean exactTotal, ExecutorService threadpool, boolean returnError) {
        super.cluster(dataSet, accelCache, K, means, assignment, exactTotal, threadpool, false);
        int dimension = dataSet.getNumNumericalVars();
        List<Matrix> covariances = new ArrayList<Matrix>(K);
        for (int k = 0; k < means.size(); ++k) {
            covariances.add(new DenseMatrix(dimension, dimension));
        }
        this.a_k = new double[K];
        double sum = dataSet.getSampleSize();
        DenseVector scratch = new DenseVector(dimension);
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            Vec x = dataSet.getDataPoint(i).getNumericalValues();
            int k = assignment[i];
            // a_k temporarily holds the member count of each cluster.
            this.a_k[k] += 1.0;
            x.copyTo(scratch);
            scratch.mutableSubtract(means.get(k));
            Matrix.OuterProductUpdate(covariances.get(k), scratch, scratch, 1.0);
        }
        for (int k = 0; k < means.size(); ++k) {
            // Scale the scatter sum by the member count, then turn the count into a weight.
            covariances.get(k).mutableMultiply(1.0 / this.a_k[k]);
            this.a_k[k] /= sum;
        }
        return this.clusterCompute(K, dataSet, assignment, means, covariances, threadpool);
    }

    /**
     * Runs the EM iterations starting from the given means and covariances,
     * then moves each point to the component with the highest responsibility.
     * <p>
     * Iteration stops when the absolute log-likelihood change drops below the
     * tolerance, when the log-likelihood stops being finite, or when a worker
     * fails or the calling thread is interrupted (both are logged).
     *
     * @param K number of mixture components
     * @param dataSet the data being fitted
     * @param assignment per-point component index; updated in place
     * @param means component means; updated in place
     * @param covs component covariances; updated in place
     * @param execServ executor for parallel steps, or {@code null} to run sequentially
     * @return the negated log-likelihood recorded before the last completed M step
     */
    protected double clusterCompute(int K, DataSet dataSet, int[] assignment, List<Vec> means, List<Matrix> covs, ExecutorService execServ) {
        List<DataPoint> dataPoints = dataSet.getDataPoints();
        int N = dataPoints.size();
        double currentLogLike = -Double.MAX_VALUE;
        this.gaussians = new ArrayList<NormalM>(K);
        for (int k = 0; k < means.size(); ++k) {
            this.gaussians.add(new NormalM(means.get(k), covs.get(k)));
        }
        double[][] p_ik = new double[N][K];
        try {
            while (true) {
                double logLike = this.eStep(N, dataPoints, K, p_ik, execServ);
                if (Double.isNaN(logLike) || Double.isInfinite(logLike)) {
                    // A non-finite likelihood can never satisfy the tolerance test,
                    // so continuing would loop forever on NaN values.
                    LOG.log(Level.WARNING, "EM stopped early: log-likelihood is not finite");
                    break;
                }
                if (Math.abs(currentLogLike - logLike) < this.tolerance) {
                    break;
                }
                currentLogLike = logLike;
                this.mStep(means, N, dataPoints, K, p_ik, covs, execServ);
            }
        }
        catch (ExecutionException ex) {
            // Retrying a failed step would repeat the same failure indefinitely.
            LOG.log(Level.SEVERE, null, ex);
        }
        catch (InterruptedException ex) {
            LOG.log(Level.SEVERE, null, ex);
            // Keep the interrupt visible to whoever owns this thread.
            Thread.currentThread().interrupt();
        }
        for (int i = 0; i < p_ik.length; ++i) {
            for (int k = 0; k < K; ++k) {
                if (p_ik[i][k] > p_ik[i][assignment[i]]) {
                    assignment[i] = k;
                }
            }
        }
        return -currentLogLike;
    }

    /**
     * M step: recomputes means, covariances and mixing weights from the
     * responsibilities in {@code p_ik} and pushes them into the Gaussians.
     *
     * @throws InterruptedException if interrupted while waiting for workers
     * @throws ExecutionException if a worker task failed
     */
    private void mStep(final List<Vec> means, int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, final List<Matrix> covs, ExecutorService execServ) throws InterruptedException, ExecutionException {
        for (Vec mean : means) {
            mean.zeroOut();
        }
        Arrays.fill(this.a_k, 0.0);
        this.accumulateMeans(means, N, dataPoints, K, p_ik, execServ);
        for (int k = 0; k < this.a_k.length; ++k) {
            means.get(k).mutableDivide(this.a_k[k]);
        }

        for (Matrix cov : covs) {
            cov.zeroOut();
        }
        this.accumulateCovariances(means, N, dataPoints, K, p_ik, covs, execServ);
        for (int k = 0; k < K; ++k) {
            covs.get(k).mutableMultiply(1.0 / this.a_k[k]);
        }

        // a_k held responsibility totals until now; convert them to mixing weights.
        for (int k = 0; k < K; ++k) {
            this.a_k[k] /= (double)N;
        }
        for (int k = 0; k < means.size(); ++k) {
            this.gaussians.get(k).setMeanCovariance(means.get(k), covs.get(k));
        }
    }

    /**
     * Adds the responsibility-weighted points into {@code means} and the
     * responsibility totals into {@code a_k}. Both must be zeroed by the caller.
     */
    private void accumulateMeans(final List<Vec> means, int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, ExecutorService execServ) throws InterruptedException, ExecutionException {
        if (execServ == null) {
            for (int i = 0; i < N; ++i) {
                Vec x_i = dataPoints.get(i).getNumericalValues();
                for (int k = 0; k < K; ++k) {
                    this.a_k[k] += p_ik[i][k];
                    means.get(k).mutableAdd(p_ik[i][k], x_i);
                }
            }
            return;
        }
        List<Future<?>> pending = new ArrayList<Future<?>>(SystemInfo.LogicalCores);
        int start = 0;
        int step = N / SystemInfo.LogicalCores;
        int remainder = N % SystemInfo.LogicalCores;
        while (start < N) {
            final int from = start;
            // The first `remainder` chunks take one extra point each.
            final int to = Math.min((remainder-- > 0 ? 1 : 0) + start + step, N);
            start = to;
            pending.add(execServ.submit(new Runnable(){

                @Override
                public void run() {
                    Vec[] partialMean = new Vec[means.size()];
                    for (int i = 0; i < partialMean.length; ++i) {
                        partialMean[i] = new DenseVector(means.get(i).length());
                    }
                    double[] partial_a_k = new double[EMGaussianMixture.this.a_k.length];
                    for (int i = from; i < to; ++i) {
                        Vec x_i = dataPoints.get(i).getNumericalValues();
                        for (int k = 0; k < K; ++k) {
                            partial_a_k[k] += p_ik[i][k];
                            partialMean[k].mutableAdd(p_ik[i][k], x_i);
                        }
                    }
                    // Merge the chunk-local sums into the shared totals one task at a time.
                    synchronized (means) {
                        for (int k = 0; k < EMGaussianMixture.this.a_k.length; ++k) {
                            EMGaussianMixture.this.a_k[k] += partial_a_k[k];
                            means.get(k).mutableAdd(partialMean[k]);
                        }
                    }
                }
            }));
        }
        EMGaussianMixture.awaitAll(pending);
    }

    /**
     * Adds the responsibility-weighted outer products of the centred points
     * into {@code covs}, which must be zeroed by the caller. No scaling is applied.
     */
    private void accumulateCovariances(final List<Vec> means, int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, final List<Matrix> covs, ExecutorService execServ) throws InterruptedException, ExecutionException {
        final int D = means.get(0).length();
        if (execServ == null) {
            for (int k = 0; k < K; ++k) {
                Matrix covariance = covs.get(k);
                Vec mean = means.get(k);
                DenseVector scratch = new DenseVector(mean.length());
                for (int i = 0; i < dataPoints.size(); ++i) {
                    Vec x = dataPoints.get(i).getNumericalValues();
                    x.copyTo(scratch);
                    scratch.mutableSubtract(mean);
                    Matrix.OuterProductUpdate(covariance, scratch, scratch, p_ik[i][k]);
                }
            }
            return;
        }
        List<Future<?>> pending = new ArrayList<Future<?>>(SystemInfo.LogicalCores);
        int start = 0;
        int step = N / SystemInfo.LogicalCores;
        int remainder = N % SystemInfo.LogicalCores;
        while (start < N) {
            final int from = start;
            final int to = Math.min((remainder-- > 0 ? 1 : 0) + start + step, N);
            start = to;
            pending.add(execServ.submit(new Runnable(){

                @Override
                public void run() {
                    Matrix[] partialCovs = new Matrix[K];
                    for (int k = 0; k < partialCovs.length; ++k) {
                        partialCovs[k] = new DenseMatrix(D, D);
                    }
                    // One scratch vector per task instead of one per data point.
                    DenseVector scratch = new DenseVector(D);
                    for (int i = from; i < to; ++i) {
                        Vec x = dataPoints.get(i).getNumericalValues();
                        for (int k = 0; k < K; ++k) {
                            x.copyTo(scratch);
                            scratch.mutableSubtract(means.get(k));
                            Matrix.OuterProductUpdate(partialCovs[k], scratch, scratch, p_ik[i][k]);
                        }
                    }
                    synchronized (covs) {
                        for (int k = 0; k < K; ++k) {
                            covs.get(k).mutableAdd(partialCovs[k]);
                        }
                    }
                }
            }));
        }
        EMGaussianMixture.awaitAll(pending);
    }

    /**
     * Waits for every task to finish, then rethrows the first worker failure.
     * All tasks are awaited before failing so that no worker is still writing
     * to shared state when the caller resumes.
     */
    private static void awaitAll(List<Future<?>> pending) throws InterruptedException, ExecutionException {
        ExecutionException failure = null;
        for (Future<?> task : pending) {
            try {
                task.get();
            }
            catch (ExecutionException ex) {
                if (failure == null) {
                    failure = ex;
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * E step: fills {@code p_ik} with the normalised responsibility of every
     * component for every point.
     *
     * @return the sum over all points of the log of the mixture density
     * @throws InterruptedException if interrupted while waiting for workers
     * @throws ExecutionException if a worker task failed
     */
    private double eStep(int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik, ExecutorService execServ) throws InterruptedException, ExecutionException {
        double logLike = 0.0;
        if (execServ == null) {
            for (int i = 0; i < N; ++i) {
                Vec x_i = dataPoints.get(i).getNumericalValues();
                logLike += Math.log(this.fillResponsibilities(x_i, K, p_ik[i]));
            }
            return logLike;
        }
        List<Future<Double>> partialLogLikes = new ArrayList<Future<Double>>(SystemInfo.LogicalCores);
        int start = 0;
        int step = N / SystemInfo.LogicalCores;
        int remainder = N % SystemInfo.LogicalCores;
        while (start < N) {
            final int from = start;
            final int to = Math.min((remainder-- > 0 ? 1 : 0) + start + step, N);
            start = to;
            partialLogLikes.add(execServ.submit(new Callable<Double>(){

                @Override
                public Double call() throws Exception {
                    double partialLog = 0.0;
                    for (int i = from; i < to; ++i) {
                        Vec x_i = dataPoints.get(i).getNumericalValues();
                        partialLog += Math.log(EMGaussianMixture.this.fillResponsibilities(x_i, K, p_ik[i]));
                    }
                    return partialLog;
                }
            }));
        }
        for (double partialLogLike : ListUtils.collectFutures(partialLogLikes)) {
            logLike += partialLogLike;
        }
        return logLike;
    }

    /**
     * Writes the normalised responsibilities for one point into {@code row}.
     *
     * @param x_i the point being evaluated
     * @param K number of components
     * @param row destination of length at least {@code K}
     * @return the normaliser, i.e. the weighted sum of component densities at {@code x_i}
     */
    private double fillResponsibilities(Vec x_i, int K, double[] row) {
        double normalizer = 0.0;
        for (int k = 0; k < K; ++k) {
            double weighted = this.a_k[k] * this.gaussians.get(k).pdf(x_i);
            row[k] = weighted;
            normalizer += weighted;
        }
        for (int k = 0; k < K; ++k) {
            row[k] /= normalizer;
        }
        return normalizer;
    }

    /** Array convenience overload of {@link #logPdf(Vec)}. */
    @Override
    public double logPdf(double ... x) {
        return this.logPdf(DenseVector.toDenseVec(x));
    }

    /**
     * Returns the natural log of {@link #pdf(Vec)}, or {@code -Double.MAX_VALUE}
     * when the density is exactly zero.
     */
    @Override
    public double logPdf(Vec x) {
        double pdf = this.pdf(x);
        if (pdf == 0.0) {
            return -Double.MAX_VALUE;
        }
        return Math.log(pdf);
    }

    /** Array convenience overload of {@link #pdf(Vec)}. */
    @Override
    public double pdf(double ... x) {
        return this.pdf(DenseVector.toDenseVec(x));
    }

    /** Returns the weighted sum of the component densities at {@code x}. */
    @Override
    public double pdf(Vec x) {
        double PDF = 0.0;
        for (int i = 0; i < this.a_k.length; ++i) {
            PDF += this.a_k[i] * this.gaussians.get(i).pdf(x);
        }
        return PDF;
    }

    /**
     * Wraps every vector in a {@link DataPoint} without categorical values and
     * delegates to {@link #setUsingDataList(List)}.
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet) {
        List<DataPoint> dataPoints = new ArrayList<DataPoint>(dataSet.size());
        for (Vec x : dataSet) {
            dataPoints.add(new DataPoint(x, new int[0], new CategoricalData[0]));
        }
        return this.setUsingDataList(dataPoints);
    }

    /** Wraps the points in a {@link SimpleDataSet} and delegates to {@link #setUsingData(DataSet)}. */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoint) {
        return this.setUsingData(new SimpleDataSet(dataPoint));
    }

    /**
     * Fits the mixture by clustering the data set.
     *
     * @return {@code true} on success, {@code false} if an {@link ArithmeticException} was thrown
     */
    @Override
    public boolean setUsingData(DataSet dataSet) {
        try {
            this.cluster(dataSet);
            return true;
        }
        catch (ArithmeticException ex) {
            return false;
        }
    }

    /** Delegates to {@link #setUsingData(DataSet)}; the thread pool is not used. */
    @Override
    public boolean setUsingData(DataSet dataSet, ExecutorService threadpool) {
        return this.setUsingData(dataSet);
    }

    /** Delegates to {@link #setUsingData(List)}; the thread pool is not used. */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, ExecutorService threadpool) {
        return this.setUsingData(dataSet);
    }

    /** Delegates to {@link #setUsingDataList(List)}; the thread pool is not used. */
    @Override
    public boolean setUsingDataList(List<DataPoint> dataPoints, ExecutorService threadpool) {
        return this.setUsingDataList(dataPoints);
    }

    /** Returns a copy made with the copy constructor. */
    @Override
    public EMGaussianMixture clone() {
        return new EMGaussianMixture(this);
    }

    /**
     * Draws {@code count} samples from the mixture. Sorted uniform draws are
     * walked against the running sum of the mixing weights to decide how many
     * samples each component contributes; each component then generates its share.
     *
     * @param count number of samples to draw
     * @param rand source of randomness
     * @return exactly {@code count} samples, grouped by component
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        List<Vec> samples = new ArrayList<Vec>(count);
        double[] priorTargets = new double[count];
        for (int i = 0; i < count; ++i) {
            priorTargets[i] = rand.nextDouble();
        }
        Arrays.sort(priorTargets);
        int pos = 0;
        double a_kSum = 0.0;
        int lastGaussian = this.a_k.length - 1;
        for (int currentGaussian = 0; currentGaussian <= lastGaussian; ++currentGaussian) {
            a_kSum += this.a_k[currentGaussian];
            // Counted afresh for every component; a running total would over-sample.
            int subSampleSize = 0;
            if (currentGaussian == lastGaussian) {
                // Rounding can leave the weight sum just under 1, so the last
                // component takes every remaining target.
                subSampleSize = count - pos;
                pos = count;
            } else {
                // Only consume a target once it is known to fall in this component's range.
                while (pos < count && priorTargets[pos] < a_kSum) {
                    ++pos;
                    ++subSampleSize;
                }
            }
            samples.addAll(this.gaussians.get(currentGaussian).sample(subSampleSize, rand));
        }
        return samples;
    }
}

