/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.kernel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.datatransform.DataTransform;
import jsat.datatransform.DataTransformFactoryParm;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.parameters.Parameter;
import jsat.utils.DoubleList;
import jsat.utils.IntSet;
import jsat.utils.random.XOR96;

public class Nystrom
implements DataTransform {
    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = -3227844260130709773L;
    /** Kernel evaluated between the basis vectors and each point being transformed. */
    private KernelTrick k;
    /** Vectors selected from the data set (or produced by k-means) that form the basis. */
    private List<Vec> basisVecs;
    /** Cache obtained from the kernel for {@code basisVecs}; the copy constructor tolerates {@code null}. */
    private List<Double> accelCache;
    /** Matrix that the vector of kernel evaluations is multiplied by in {@code transform(DataPoint)}. */
    private Matrix transform;

    /**
     * Creates a Nystrom transform with a ridge of {@code 0.0} and with
     * {@code sampleWithReplacment} set to {@code false}, by delegating to the
     * six-argument constructor.
     *
     * @param k the kernel to use
     * @param dataset the data set the basis vectors are obtained from
     * @param basisSize the number of basis vectors to use
     * @param method the method used to select the basis vectors
     */
    public Nystrom(KernelTrick k, DataSet dataset, int basisSize, SamplingMethod method) {
        this(k, dataset, basisSize, method, 0.0, false);
    }

    /**
     * Creates a Nystrom transform: selects basis vectors from the data set,
     * evaluates the kernel between every pair of them, eigen-decomposes the
     * resulting matrix and stores a matrix built from the eigenvectors scaled
     * by {@code 1 / sqrt(max(1e-7, eigenvalue + ridge))}.
     *
     * @param k the kernel to use
     * @param dataset the data set the basis vectors are obtained from
     * @param basisSize the number of basis vectors to request
     * @param method the method used to select the basis vectors
     * @param ridge non-negative, finite value added to every eigenvalue before
     *        the inverse square root is taken
     * @param sampleWithReplacment forwarded unchanged to
     *        {@link #sampleBasisVectors}
     * @throws IllegalArgumentException if {@code ridge} is negative, NaN or infinite
     */
    public Nystrom(KernelTrick k, DataSet dataset, int basisSize, SamplingMethod method, double ridge, boolean sampleWithReplacment) {
        // Same validity rule as NystromTransformFactory.setRidge: zero is allowed, NaN/Inf are not.
        if (ridge < 0.0 || Double.isNaN(ridge) || Double.isInfinite(ridge)) {
            throw new IllegalArgumentException("ridge must be non negative, not " + ridge);
        }
        XOR96 rand = new XOR96();
        List<Vec> X = dataset.getDataVectors();
        this.basisVecs = Nystrom.sampleBasisVectors(k, dataset, X, method, basisSize, sampleWithReplacment, rand);
        this.k = k;
        this.accelCache = k.getAccelerationCache(this.basisVecs);

        // Size the kernel matrix by the number of basis vectors actually obtained,
        // which is also what transform(DataPoint) iterates over.
        int m = this.basisVecs.size();
        DenseMatrix K = new DenseMatrix(m, m);
        for (int i = 0; i < m; ++i) {
            K.set(i, i, k.eval(i, i, this.basisVecs, this.accelCache));
            // The kernel matrix is symmetric, so each off-diagonal value is evaluated once.
            for (int j = i + 1; j < m; ++j) {
                double val = k.eval(i, j, this.basisVecs, this.accelCache);
                K.set(i, j, val);
                K.set(j, i, val);
            }
        }

        EigenValueDecomposition eig = new EigenValueDecomposition(K);
        double[] eigenVals = eig.getRealEigenvalues();
        DenseVector eigNorm = new DenseVector(eigenVals.length);
        for (int i = 0; i < eigenVals.length; ++i) {
            // The 1e-7 floor keeps the square root and the division well defined
            // for zero or slightly negative eigenvalues.
            eigNorm.set(i, 1.0 / Math.sqrt(Math.max(1.0E-7, eigenVals[i] + ridge)));
        }
        Matrix U = eig.getV();
        Matrix.diagMult(U, eigNorm);
        // NOTE(review): a symmetric inverse square root would normally be V * D * V^T;
        // confirm that getVRaw() is already in the orientation this product needs.
        this.transform = U.multiply(eig.getVRaw());
        this.transform.mutableTranspose();
    }

    protected Nystrom(Nystrom toCopy) {
        this.k = toCopy.k.clone();
        this.basisVecs = new ArrayList<Vec>(toCopy.basisVecs.size());
        for (Vec v : toCopy.basisVecs) {
            this.basisVecs.add(v.clone());
        }
        if (toCopy.accelCache != null) {
            this.accelCache = new DoubleList(toCopy.accelCache);
        }
        this.transform = toCopy.transform.clone();
    }

    /**
     * Selects the basis vectors used for the approximation.
     * <ul>
     * <li>{@code DIAGONAL}: points are drawn with weight {@code k(x_i, x_i)}.</li>
     * <li>{@code NORM}: the full kernel matrix over {@code X} is computed and
     *     points are drawn with weight equal to the 2-norm of their row.</li>
     * <li>{@code KMEANS}: the means found by k-means with {@code basisSize}
     *     clusters are used; {@code sampleWithReplacment} and {@code rand}
     *     are not used by this branch.</li>
     * <li>{@code UNIFORM} (and any other value): points are drawn uniformly
     *     at random.</li>
     * </ul>
     *
     * @param k the kernel used to compute sampling weights
     * @param dataset the data set; supplies the sample count and the k-means input
     * @param X the data vectors to select from
     * @param method the selection method
     * @param basisSize the number of basis vectors to select
     * @param sampleWithReplacment {@code true} if the same point may be selected
     *        more than once, {@code false} if all selected points must be distinct
     * @param rand the source of randomness
     * @return the list of selected basis vectors
     * @throws IllegalArgumentException if distinct points are requested from a
     *         sampling method and {@code basisSize} exceeds the number of points
     */
    public static List<Vec> sampleBasisVectors(KernelTrick k, DataSet dataset, List<Vec> X, SamplingMethod method, int basisSize, boolean sampleWithReplacment, Random rand) {
        List<Vec> basisVecs = new ArrayList<Vec>(basisSize);
        int N = dataset.getSampleSize();
        // Drawing more distinct points than exist can never finish; fail fast instead of looping forever.
        if (method != SamplingMethod.KMEANS && !sampleWithReplacment && basisSize > N) {
            throw new IllegalArgumentException("Can not sample " + basisSize + " distinct basis vectors from only " + N + " data points");
        }
        switch (method) {
            case DIAGONAL: {
                // Cumulative sum of the kernel diagonal, consumed by sample(...) via binary search.
                double[] diags = new double[N];
                diags[0] = k.eval(X.get(0), X.get(0));
                for (int i = 1; i < N; ++i) {
                    diags[i] = diags[i - 1] + k.eval(X.get(i), X.get(i));
                }
                Nystrom.sample(basisSize, rand, diags, X, sampleWithReplacment, basisVecs);
                break;
            }
            case NORM: {
                double[] norms = new double[N];
                // One dense row per data point: this holds the full N x N kernel matrix.
                List<Vec> gramVecs = new ArrayList<Vec>(N);
                for (int i = 0; i < N; ++i) {
                    gramVecs.add(new DenseVector(N));
                }
                List<Double> tmpCache = k.getAccelerationCache(X);
                for (int i = 0; i < N; ++i) {
                    gramVecs.get(i).set(i, k.eval(i, i, X, tmpCache));
                    // Symmetric matrix: evaluate each pair once and mirror it.
                    for (int j = i + 1; j < N; ++j) {
                        double val = k.eval(i, j, X, tmpCache);
                        gramVecs.get(i).set(j, val);
                        gramVecs.get(j).set(i, val);
                    }
                }
                // Cumulative sum of the row norms, consumed by sample(...) via binary search.
                norms[0] = gramVecs.get(0).pNorm(2.0);
                for (int i = 1; i < gramVecs.size(); ++i) {
                    norms[i] = norms[i - 1] + gramVecs.get(i).pNorm(2.0);
                }
                Nystrom.sample(basisSize, rand, norms, X, sampleWithReplacment, basisVecs);
                break;
            }
            case KMEANS: {
                HamerlyKMeans kMeans = new HamerlyKMeans(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
                kMeans.setStoreMeans(true);
                kMeans.cluster(dataset, basisSize);
                basisVecs.addAll(kMeans.getMeans());
                break;
            }
            case UNIFORM:
            default: {
                if (sampleWithReplacment) {
                    // With replacement: independent draws, duplicates are allowed.
                    for (int i = 0; i < basisSize; ++i) {
                        basisVecs.add(X.get(rand.nextInt(N)));
                    }
                } else {
                    // Without replacement: keep drawing until basisSize distinct indices were seen.
                    IntSet sampled = new IntSet(basisSize);
                    while (sampled.size() < basisSize) {
                        sampled.add(rand.nextInt(N));
                    }
                    Iterator<Integer> indices = sampled.iterator();
                    while (indices.hasNext()) {
                        int indx = indices.next();
                        basisVecs.add(X.get(indx));
                    }
                }
                break;
            }
        }
        return basisVecs;
    }

    /**
     * Draws {@code basisSize} vectors from {@code X} with probability
     * proportional to their weight and appends them to {@code basisVecs}.
     *
     * @param basisSize the number of vectors to draw; nothing is done if not positive
     * @param rand the source of randomness
     * @param weightSume cumulative sum of the per-point weights, so the last
     *        entry is the total weight
     * @param X the vectors to draw from, indexed like {@code weightSume}
     * @param sampleWithReplacment {@code true} if the same vector may be drawn
     *        more than once, {@code false} if every drawn vector must be distinct
     * @param basisVecs the list the drawn vectors are appended to
     * @throws IllegalArgumentException if there are no weights to draw from, or
     *         if more distinct vectors are requested than there are weights
     */
    private static void sample(int basisSize, Random rand, double[] weightSume, List<Vec> X, boolean sampleWithReplacment, List<Vec> basisVecs) {
        if (basisSize <= 0) {
            return;
        }
        int candidates = weightSume.length;
        if (candidates == 0) {
            throw new IllegalArgumentException("Can not sample " + basisSize + " basis vectors from an empty set of weights");
        }
        // Rejecting duplicates can never yield more distinct indices than exist.
        if (!sampleWithReplacment && basisSize > candidates) {
            throw new IllegalArgumentException("Can not sample " + basisSize + " distinct basis vectors from only " + candidates + " candidates");
        }
        // NOTE: without replacement this still retries indefinitely if fewer than
        // basisSize entries carry any weight, since zero-weight entries are
        // effectively never selected by the search below.
        IntSet sampled = new IntSet(basisSize);
        double max = weightSume[candidates - 1];
        int drawn = 0;
        while (drawn < basisSize) {
            double rndVal = rand.nextDouble() * max;
            int indx = Arrays.binarySearch(weightSume, rndVal);
            if (indx < 0) {
                // Not an exact hit: binarySearch returns -(insertion point) - 1.
                indx = -indx - 1;
            }
            // Guard against an insertion point one past the end of the array.
            indx = Math.min(indx, candidates - 1);
            if (!sampleWithReplacment) {
                int before = sampled.size();
                sampled.add(indx);
                if (sampled.size() == before) {
                    // Index was already chosen; draw again.
                    continue;
                }
            }
            basisVecs.add(X.get(indx));
            ++drawn;
        }
    }

    /**
     * Maps a data point into the approximate kernel feature space: the kernel
     * is evaluated between the point and every basis vector, and the resulting
     * vector is multiplied by the stored transform matrix. Categorical values,
     * categorical data and the weight are carried over from the input point.
     *
     * @param dp the data point to transform
     * @return a new data point holding the transformed numeric values
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        final Vec query = dp.getNumericalValues();
        final List<Double> queryInfo = this.k.getQueryInfo(query);
        final int basisCount = this.basisVecs.size();
        final DenseVector kernelRow = new DenseVector(basisCount);
        int idx = 0;
        while (idx < basisCount) {
            double similarity = this.k.eval(idx, query, queryInfo, this.basisVecs, this.accelCache);
            kernelRow.set(idx, similarity);
            idx++;
        }
        return new DataPoint(kernelRow.multiply(this.transform), dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
    }

    /**
     * Returns a copy of this transform created through the copy constructor.
     *
     * @return a new {@code Nystrom} instance copied from this one
     */
    @Override
    public Nystrom clone() {
        return new Nystrom(this);
    }

    /**
     * Factory that stores the settings needed to build a {@link Nystrom}
     * transform and creates one for any data set it is given.
     */
    public static class NystromTransformFactory
    extends DataTransformFactoryParm {
        /** Ridge value handed to the created transforms. */
        private double ridge;
        /** Kernel handed to the created transforms. */
        @Parameter.ParameterHolder
        private KernelTrick k;
        /** Number of basis vectors requested for the created transforms. */
        private int dimension;
        /** Method used to select the basis vectors. */
        private SamplingMethod method;
        /** Replacement flag handed to the created transforms. */
        private boolean sampleWithReplacment;

        /**
         * Creates a new factory. The dimension and ridge are validated by their setters.
         *
         * @param k the kernel to use
         * @param dimension the number of basis vectors, must be at least 1
         * @param method the method used to select the basis vectors
         * @param ridge the ridge value, must be finite and not negative
         * @param sampleWithReplacment the replacement flag passed on to the transform
         * @throws IllegalArgumentException if {@code dimension} or {@code ridge} is invalid
         */
        public NystromTransformFactory(KernelTrick k, int dimension, SamplingMethod method, double ridge, boolean sampleWithReplacment) {
            this.k = k;
            setDimension(dimension);
            setBasisSamplingMethod(method);
            setRidge(ridge);
            this.sampleWithReplacment = sampleWithReplacment;
        }

        /**
         * Copy constructor; the kernel is cloned, all other settings are copied as is.
         *
         * @param toCopy the factory to copy
         */
        public NystromTransformFactory(NystromTransformFactory toCopy) {
            this(toCopy.k.clone(), toCopy.dimension, toCopy.method, toCopy.ridge, toCopy.sampleWithReplacment);
        }

        /**
         * Sets the ridge value passed to the transforms this factory creates.
         *
         * @param ridge the new ridge value
         * @throws IllegalArgumentException if the value is NaN, infinite or negative
         */
        public void setRidge(double ridge) {
            final boolean notFinite = Double.isNaN(ridge) || Double.isInfinite(ridge);
            if (notFinite || ridge < 0.0) {
                throw new IllegalArgumentException("Ridge must be non negative, not " + ridge);
            }
            this.ridge = ridge;
        }

        /**
         * @return the ridge value currently in use
         */
        public double getRidge() {
            return ridge;
        }

        /**
         * Sets the number of basis vectors requested for created transforms.
         *
         * @param dimension the new number of dimensions
         * @throws IllegalArgumentException if the value is smaller than 1
         */
        public void setDimension(int dimension) {
            if (dimension <= 0) {
                throw new IllegalArgumentException("The number of dimensions must be positive, not " + dimension);
            }
            this.dimension = dimension;
        }

        /**
         * @return the number of basis vectors currently requested
         */
        public int getDimension() {
            return dimension;
        }

        /**
         * Sets the method used to select the basis vectors.
         *
         * @param method the selection method
         */
        public void setBasisSamplingMethod(SamplingMethod method) {
            this.method = method;
        }

        /**
         * @return the method used to select the basis vectors
         */
        public SamplingMethod getBasisSamplingMethod() {
            return method;
        }

        /**
         * Builds a new {@link Nystrom} transform for the given data set using
         * this factory's kernel and settings.
         *
         * @param dataset the data set to build the transform from
         * @return the new transform
         */
        @Override
        public DataTransform getTransform(DataSet dataset) {
            final Nystrom created = new Nystrom(k, dataset, dimension, method, ridge, sampleWithReplacment);
            return created;
        }

        /**
         * @return a copy of this factory holding a cloned kernel
         */
        @Override
        public NystromTransformFactory clone() {
            return new NystromTransformFactory(this);
        }
    }

    /**
     * Methods that {@code sampleBasisVectors} can use to choose the basis vectors.
     */
    public static enum SamplingMethod {
        /** Points are picked uniformly at random. */
        UNIFORM,
        /** Points are picked with weight given by the kernel evaluated between a point and itself. */
        DIAGONAL,
        /** Points are picked with weight given by the 2-norm of their row of the full kernel matrix. */
        NORM,
        /** The means found by k-means clustering of the data set are used as the basis. */
        KMEANS;

    }
}

