/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicIntegerArray;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.DCDs;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.datatransform.DataTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.random.XORWOW;

/**
 * A Radial Basis Function (RBF) network. Fitting happens in three stages:
 * <ol>
 * <li>a {@link Phase1Learner} selects the neuron centroids;</li>
 * <li>a {@link Phase2Learner} estimates one Gaussian bandwidth per centroid;</li>
 * <li>the training data is mapped to neuron activations by
 * {@link #transform(DataPoint)}, and the base classifier / regressor is
 * trained on that representation.</li>
 * </ol>
 * Because the activation mapping is exposed through {@link DataTransform},
 * a trained network can also be used purely as a feature transform.
 */
public class RBFNet
implements Classifier,
Regressor,
DataTransform,
Parameterized {
    private static final long serialVersionUID = 5418896646203518062L;
    /** Activations not strictly greater than this are left out of the sparse output. */
    private static final double MIN_ACTIVATION = 1.0E-16;

    /** Requested number of RBF neurons. */
    private int numCentroids;
    private Phase1Learner p1l;
    private Phase2Learner p2l;
    /** Bandwidth scale parameter passed to the {@link Phase2Learner}. */
    private double alpha;
    /** Neighbor count passed to the {@link Phase2Learner}. */
    private int p;
    private DistanceMetric dm;
    /** When true, activations are divided by their sum in {@link #transform(DataPoint)}. */
    private boolean normalize = true;
    /** Model trained on the activations; also set to the regressor if it implements Classifier. */
    private Classifier baseClassifier;
    /** Model trained on the activations; also set to the classifier if it implements Regressor. */
    private Regressor baseRegressor;
    /** Distance-metric acceleration cache for {@link #centroids}; may be null. */
    private List<Double> centroidDistCache;
    /** Trained neuron centers; null until trained. */
    private List<Vec> centroids;
    /** Trained per-neuron Gaussian bandwidths, aligned with {@link #centroids}. */
    private double[] bandwidths;

    /**
     * Creates a network with k-means centroids,
     * {@link Phase2Learner#NEAREST_OTHER_CENTROID_AVERAGE} bandwidths
     * ({@code alpha = 3}, {@code p = 3}), the Euclidean distance and a
     * {@link DCDs} base model.
     *
     * @param numCentroids the number of RBF neurons
     */
    public RBFNet(int numCentroids) {
        this(numCentroids, Phase1Learner.K_MEANS, Phase2Learner.NEAREST_OTHER_CENTROID_AVERAGE, 3.0, 3, (DistanceMetric)new EuclideanDistance(), new DCDs());
    }

    /**
     * Creates a network whose output model is a classifier. If the classifier
     * also implements {@link Regressor}, the network can be used for regression too.
     *
     * @param numCentroids   the number of RBF neurons
     * @param cl             the centroid selection method
     * @param bl             the bandwidth estimation method
     * @param alpha          the bandwidth scale parameter
     * @param p              the neighbor count used by some bandwidth methods
     * @param dm             the distance metric
     * @param baseClassifier the model trained on the neuron activations
     */
    public RBFNet(int numCentroids, Phase1Learner cl, Phase2Learner bl, double alpha, int p, DistanceMetric dm, Classifier baseClassifier) {
        this.setNumCentroids(numCentroids);
        this.setPhase1Learner(cl);
        this.setPhase2Learner(bl);
        this.setAlpha(alpha);
        this.setP(p);
        this.setDistanceMetric(dm);
        this.baseClassifier = baseClassifier;
        if (baseClassifier instanceof Regressor) {
            this.baseRegressor = (Regressor) baseClassifier;
        }
    }

    /**
     * Creates a network whose output model is a regressor. If the regressor
     * also implements {@link Classifier}, the network can be used for classification too.
     *
     * @param numCentroids  the number of RBF neurons
     * @param cl            the centroid selection method
     * @param bl            the bandwidth estimation method
     * @param alpha         the bandwidth scale parameter
     * @param p             the neighbor count used by some bandwidth methods
     * @param dm            the distance metric
     * @param baseRegressor the model trained on the neuron activations
     */
    public RBFNet(int numCentroids, Phase1Learner cl, Phase2Learner bl, double alpha, int p, DistanceMetric dm, Regressor baseRegressor) {
        this.setNumCentroids(numCentroids);
        this.setPhase1Learner(cl);
        this.setPhase2Learner(bl);
        this.setAlpha(alpha);
        this.setP(p);
        this.setDistanceMetric(dm);
        this.baseRegressor = baseRegressor;
        if (baseRegressor instanceof Classifier) {
            this.baseClassifier = (Classifier) baseRegressor;
        }
    }

    /**
     * Copy constructor. Deep-copies the distance metric, the base model,
     * the trained centroids, the acceleration cache and the bandwidths.
     *
     * @param toCopy the network to copy
     */
    public RBFNet(RBFNet toCopy) {
        this.setNumCentroids(toCopy.getNumCentroids());
        this.setPhase1Learner(toCopy.getPhase1Learner());
        this.setPhase2Learner(toCopy.getPhase2Learner());
        this.setAlpha(toCopy.getAlpha());
        this.setP(toCopy.getP());
        this.setDistanceMetric(toCopy.getDistanceMetric().clone());
        // Previously omitted, so a clone of a non-normalizing network silently normalized.
        this.setNormalize(toCopy.isNormalize());
        if (toCopy.baseRegressor != null) {
            this.baseRegressor = toCopy.baseRegressor.clone();
            if (this.baseRegressor instanceof Classifier) {
                this.baseClassifier = (Classifier) this.baseRegressor;
            }
        } else if (toCopy.baseClassifier != null) {
            this.baseClassifier = toCopy.baseClassifier.clone();
            if (this.baseClassifier instanceof Regressor) {
                this.baseRegressor = (Regressor) this.baseClassifier;
            }
        }
        if (toCopy.centroids != null) {
            this.centroids = new ArrayList<Vec>(toCopy.centroids.size());
            for (Vec v : toCopy.centroids) {
                this.centroids.add(v.clone());
            }
            if (toCopy.centroidDistCache != null) {
                this.centroidDistCache = new DoubleList(toCopy.centroidDistCache);
            }
        }
        if (toCopy.bandwidths != null) {
            this.bandwidths = Arrays.copyOf(toCopy.bandwidths, toCopy.bandwidths.length);
        }
    }

    /**
     * Maps a data point to its vector of Gaussian neuron activations
     * {@code exp(-d^2 / (2 * sigma^2))}, one entry per trained centroid.
     * <p>
     * Activations not above {@value #MIN_ACTIVATION} are dropped. If every
     * activation is dropped, the single largest one is kept. When
     * {@link #isNormalize()} is true and the sum is non-zero, the activations
     * are divided by their sum. The result is dense when more than half of
     * its entries are non-zero, and sparse otherwise. Categorical values and
     * the weight are carried over unchanged.
     *
     * @param dp the point to transform; the network must already be trained
     * @return the activation representation of {@code dp}
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        List<Double> qi = this.dm.getQueryInfo(x);
        // Size by the trained centroids rather than numCentroids. The setter may
        // have changed numCentroids after training, which would make them disagree.
        int neurons = this.centroids.size();
        Vec sv = new SparseVector(neurons);
        double sum = 0.0;
        double maxActivation = Double.NEGATIVE_INFINITY;
        int highestNeuron = -1;
        for (int i = 0; i < neurons; ++i) {
            double dist = this.dm.dist(i, x, qi, this.centroids, this.centroidDistCache);
            double sig = this.bandwidths[i];
            double activation = Math.exp(-(dist * dist) / (sig * sig * 2.0));
            if (activation > maxActivation) {
                maxActivation = activation;
                highestNeuron = i;
            }
            if (activation > MIN_ACTIVATION) {
                sv.set(i, activation);
                sum += activation;
            }
        }
        // Never emit an all-zero vector if some neuron responded at all. The guard
        // avoids indexing -1 when no activation beat -Infinity (e.g. NaN distances).
        if (sv.nnz() == 0 && highestNeuron >= 0) {
            sv.set(highestNeuron, maxActivation);
            sum = maxActivation;
        }
        if (this.normalize && sum != 0.0) {
            sv.mutableDivide(sum);
        }
        if (sv.nnz() > sv.length() / 2) {
            sv = new DenseVector(sv);
        }
        return new DataPoint(sv, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
    }

    /**
     * Sets the bandwidth scale parameter used by the {@link Phase2Learner}.
     *
     * @param alpha a finite, non-negative value
     * @throws IllegalArgumentException if alpha is negative, infinite or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || Double.isInfinite(alpha) || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("Alpha must be a positive value, not " + alpha);
        }
        this.alpha = alpha;
    }

    /** @return the bandwidth scale parameter */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the neighbor count passed to the {@link Phase2Learner}. Only
     * {@link Phase2Learner#NEAREST_OTHER_CENTROID_AVERAGE} uses it.
     *
     * @param p a value of at least 1
     * @throws IllegalArgumentException if {@code p < 1}
     */
    public void setP(int p) {
        if (p < 1) {
            throw new IllegalArgumentException("neighbors parameter must be positive, not " + p);
        }
        this.p = p;
    }

    /** @return the neighbor count passed to the {@link Phase2Learner} */
    public int getP() {
        return this.p;
    }

    /**
     * Sets the number of RBF neurons to create during training.
     *
     * @param numCentroids a value of at least 1
     * @throws IllegalArgumentException if {@code numCentroids < 1}
     */
    public void setNumCentroids(int numCentroids) {
        if (numCentroids < 1) {
            throw new IllegalArgumentException("Number of centroids must be positive, not " + numCentroids);
        }
        this.numCentroids = numCentroids;
    }

    /** @return the number of RBF neurons created during training */
    public int getNumCentroids() {
        return this.numCentroids;
    }

    /** @param dm the distance metric used for centroid selection, bandwidths and activations */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric in use */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** @param p1l the centroid selection method */
    public void setPhase1Learner(Phase1Learner p1l) {
        this.p1l = p1l;
    }

    /** @return the centroid selection method */
    public Phase1Learner getPhase1Learner() {
        return this.p1l;
    }

    /** @param p2l the bandwidth estimation method */
    public void setPhase2Learner(Phase2Learner p2l) {
        this.p2l = p2l;
    }

    /** @return the bandwidth estimation method */
    public Phase2Learner getPhase2Learner() {
        return this.p2l;
    }

    /** @param normalize whether {@link #transform(DataPoint)} divides activations by their sum */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** @return whether {@link #transform(DataPoint)} divides activations by their sum */
    public boolean isNormalize() {
        return this.normalize;
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.baseClassifier.classify(this.transform(data));
    }

    /**
     * Fits the RBF layer on {@code dataSet}, then trains the base classifier
     * on the transformed data.
     *
     * @param dataSet    the training data
     * @param threadPool the executor to use; {@code null} runs on the calling thread
     * @throws FailedToFitException if no base classifier was given or a stage fails
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        if (this.baseClassifier == null) {
            throw new FailedToFitException("RBFNet was not given a base classifier");
        }
        if (threadPool == null) {
            threadPool = new FakeExecutor();
        }
        this.fitHiddenLayer(dataSet, threadPool);
        ClassificationDataSet transformedData = dataSet.shallowClone();
        transformedData.applyTransform((DataTransform)this, threadPool);
        this.baseClassifier.trainC(transformedData, threadPool);
    }

    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    /**
     * Delegates to the base classifier if one is present, otherwise to the
     * base regressor.
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.baseClassifier != null) {
            return this.baseClassifier.supportsWeightedData();
        }
        return this.baseRegressor.supportsWeightedData();
    }

    @Override
    public double regress(DataPoint data) {
        return this.baseRegressor.regress(this.transform(data));
    }

    /**
     * Fits the RBF layer on {@code dataSet}, then trains the base regressor
     * on the transformed data.
     *
     * @param dataSet    the training data
     * @param threadPool the executor to use; {@code null} runs on the calling thread
     * @throws FailedToFitException if no base regressor was given or a stage fails
     */
    @Override
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        if (this.baseRegressor == null) {
            throw new FailedToFitException("RBFNet was not given a base regressor");
        }
        if (threadPool == null) {
            threadPool = new FakeExecutor();
        }
        this.fitHiddenLayer(dataSet, threadPool);
        RegressionDataSet transformedData = dataSet.shallowClone();
        transformedData.applyTransform((DataTransform)this, threadPool);
        this.baseRegressor.train(transformedData, threadPool);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        this.train(dataSet, null);
    }

    /**
     * Runs phases 1 and 2: selects centroids, builds the distance acceleration
     * cache, and estimates per-centroid bandwidths. Shared by both training paths.
     */
    private void fitHiddenLayer(DataSet dataSet, ExecutorService threadPool) {
        this.centroids = this.p1l.getCentroids(dataSet, this.numCentroids, this.dm, threadPool);
        this.centroidDistCache = this.dm.getAccelerationCache(this.centroids, threadPool);
        this.bandwidths = this.p2l.estimateBandwidths(this.alpha, this.p, dataSet, this.centroids, this.centroidDistCache, this.dm, threadPool);
    }

    @Override
    public RBFNet clone() {
        return new RBFNet(this);
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(this.getParameters()).get(paramName);
    }

    /**
     * Methods for estimating the Gaussian bandwidth of each RBF neuron once
     * the centroids are known.
     */
    public static enum Phase2Learner {
        /**
         * Assigns every data point to its nearest centroid. Each bandwidth is
         * the mean plus {@code alpha} standard deviations of the distances of
         * the points assigned to that centroid.
         */
        CENTROID_DISTANCE{

            @Override
            protected double[] estimateBandwidths(double alpha, int p, DataSet data, final List<Vec> centroids, final List<Double> centroidDistCache, final DistanceMetric dm, ExecutorService threadpool) {
                final int k = centroids.size();
                OnLineStatistics[] averages = new OnLineStatistics[k];
                for (int i = 0; i < k; ++i) {
                    averages[i] = new OnLineStatistics();
                }
                List<Future<OnLineStatistics[]>> futures = new ArrayList<Future<OnLineStatistics[]>>(SystemInfo.LogicalCores);
                for (final List<Vec> subList : ListUtils.splitList(data.getDataVectors(), SystemInfo.LogicalCores)) {
                    futures.add(threadpool.submit(new Callable<OnLineStatistics[]>(){

                        @Override
                        public OnLineStatistics[] call() {
                            OnLineStatistics[] localAverages = new OnLineStatistics[k];
                            for (int i = 0; i < k; ++i) {
                                localAverages[i] = new OnLineStatistics();
                            }
                            for (Vec x : subList) {
                                double minDist = Double.POSITIVE_INFINITY;
                                int minI = 0;
                                for (int i = 0; i < k; ++i) {
                                    double dist = dm.dist(i, x, centroids, centroidDistCache);
                                    if (dist < minDist) {
                                        minDist = dist;
                                        minI = i;
                                    }
                                }
                                localAverages[minI].add(minDist);
                            }
                            return localAverages;
                        }
                    }));
                }
                try {
                    for (OnLineStatistics[] localAverages : ListUtils.collectFutures(futures)) {
                        for (int i = 0; i < localAverages.length; ++i) {
                            if (localAverages[i].getSumOfWeights() == 0.0) continue;
                            averages[i] = OnLineStatistics.add(averages[i], localAverages[i]);
                        }
                    }
                }
                catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    throw new FailedToFitException(ex);
                }
                catch (ExecutionException ex) {
                    throw new FailedToFitException(ex);
                }
                // NOTE(review): a centroid with no assigned points keeps an empty
                // OnLineStatistics; what its mean/std evaluate to is not visible here.
                double[] bandwidths = new double[k];
                for (int i = 0; i < k; ++i) {
                    bandwidths[i] = averages[i].getMean() + averages[i].getStandardDeviation() * alpha;
                }
                return bandwidths;
            }
        }
        ,
        /**
         * Classification only. Each centroid is labeled with the majority
         * class of the points nearest to it. Its bandwidth is {@code alpha}
         * times the distance to the closest centroid with a different label,
         * or to the closest other centroid when all labels agree.
         */
        CLOSEST_OPPOSITE_CENTROID{

            @Override
            protected double[] estimateBandwidths(final double alpha, int p, DataSet data, final List<Vec> centroids, final List<Double> centroidDistCache, final DistanceMetric dm, ExecutorService threadpool) {
                if (!(data instanceof ClassificationDataSet)) {
                    throw new FailedToFitException("CLOSEST_OPPOSITE_CENTROID only works for classification data sets");
                }
                final ClassificationDataSet cds = (ClassificationDataSet)data;
                final int k = centroids.size();
                // classLabels[c][y] = number of points of class y whose nearest centroid is c
                final AtomicIntegerArray[] classLabels = new AtomicIntegerArray[k];
                for (int i = 0; i < k; ++i) {
                    classLabels[i] = new AtomicIntegerArray(cds.getClassSize());
                }
                IntList indices = new IntList(data.getSampleSize());
                ListUtils.addRange(indices, 0, data.getSampleSize(), 1);
                // Futures rather than a latch sized to LogicalCores: there is no need
                // to match the number of sublists, and task failures are propagated.
                List<Future<Void>> labelTasks = new ArrayList<Future<Void>>(SystemInfo.LogicalCores);
                for (final List<Integer> subList : ListUtils.splitList(indices, SystemInfo.LogicalCores)) {
                    labelTasks.add(threadpool.submit(new Callable<Void>(){

                        @Override
                        public Void call() {
                            for (int id : subList) {
                                Vec x = cds.getDataPoint(id).getNumericalValues();
                                double minDist = Double.POSITIVE_INFINITY;
                                int minI = 0;
                                for (int i = 0; i < k; ++i) {
                                    double dist = dm.dist(i, x, centroids, centroidDistCache);
                                    if (dist < minDist) {
                                        minDist = dist;
                                        minI = i;
                                    }
                                }
                                classLabels[minI].incrementAndGet(cds.getDataPointCategory(id));
                            }
                            return null;
                        }
                    }));
                }
                Phase2Learner.awaitAll(labelTasks);

                // Majority class per centroid; ties go to the lowest class index.
                final int[] neuronClass = new int[k];
                for (int c = 0; c < k; ++c) {
                    int maxVal = -1;
                    int maxClass = 0;
                    for (int j = 0; j < classLabels[c].length(); ++j) {
                        if (classLabels[c].get(j) > maxVal) {
                            maxClass = j;
                            maxVal = classLabels[c].get(j);
                        }
                    }
                    neuronClass[c] = maxClass;
                }

                final double[] bandwidths = new double[k];
                List<Future<Void>> widthTasks = new ArrayList<Future<Void>>(k);
                for (int c = 0; c < k; ++c) {
                    final int center = c;
                    widthTasks.add(threadpool.submit(new Callable<Void>(){

                        @Override
                        public Void call() {
                            double minDist = Double.POSITIVE_INFINITY;
                            for (int i = 0; i < k; ++i) {
                                if (neuronClass[center] != neuronClass[i]) {
                                    minDist = Math.min(minDist, dm.dist(i, center, centroids, centroidDistCache));
                                }
                            }
                            // Every centroid shares this label: fall back to the nearest other centroid.
                            if (Double.isInfinite(minDist)) {
                                for (int i = 0; i < k; ++i) {
                                    if (i != center) {
                                        minDist = Math.min(minDist, dm.dist(i, center, centroids, centroidDistCache));
                                    }
                                }
                            }
                            bandwidths[center] = alpha * minDist;
                            return null;
                        }
                    }));
                }
                Phase2Learner.awaitAll(widthTasks);
                return bandwidths;
            }
        }
        ,
        /**
         * Each bandwidth is the mean plus {@code alpha} standard deviations of
         * the distances from a centroid to the other centroids kept in a
         * {@link BoundedSortedList} of capacity {@code p} (presumably the
         * {@code p} nearest).
         */
        NEAREST_OTHER_CENTROID_AVERAGE{

            @Override
            protected double[] estimateBandwidths(final double alpha, final int p, DataSet data, final List<Vec> centroids, final List<Double> centroidDistCache, final DistanceMetric dm, ExecutorService threadpool) {
                final int k = centroids.size();
                final double[] bandwidths = new double[k];
                List<Future<Void>> tasks = new ArrayList<Future<Void>>(k);
                for (int c = 0; c < k; ++c) {
                    final int center = c;
                    tasks.add(threadpool.submit(new Callable<Void>(){

                        @Override
                        public Void call() {
                            BoundedSortedList<Double> closestDistances = new BoundedSortedList<Double>(p);
                            for (int i = 0; i < k; ++i) {
                                if (i != center) {
                                    closestDistances.add(dm.dist(i, center, centroids, centroidDistCache));
                                }
                            }
                            OnLineStatistics stats = new OnLineStatistics();
                            for (double dist : closestDistances) {
                                stats.add(dist);
                            }
                            bandwidths[center] = stats.getMean() + alpha * stats.getStandardDeviation();
                            return null;
                        }
                    }));
                }
                // The original never waited here, so with a real thread pool the
                // array could be returned before any task had filled it in.
                Phase2Learner.awaitAll(tasks);
                return bandwidths;
            }
        };


        /**
         * Estimates one bandwidth per centroid.
         *
         * @param alpha             the bandwidth scale parameter
         * @param p                 the neighbor count (used only by some methods)
         * @param data              the training data
         * @param centroids         the centroids chosen in phase 1
         * @param centroidDistCache the distance metric's acceleration cache for {@code centroids}
         * @param dm                the distance metric
         * @param threadpool        the executor used for parallel work
         * @return an array of bandwidths aligned with {@code centroids}
         */
        protected abstract double[] estimateBandwidths(double alpha, int p, DataSet data, List<Vec> centroids, List<Double> centroidDistCache, DistanceMetric dm, ExecutorService threadpool);

        /**
         * Blocks until every task has finished. {@link Future#get()} also makes
         * the tasks' array writes visible to the caller.
         *
         * @throws FailedToFitException if a task failed or the wait was interrupted
         */
        private static void awaitAll(List<Future<Void>> futures) {
            try {
                for (Future<Void> future : futures) {
                    future.get();
                }
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new FailedToFitException(ex);
            }
            catch (ExecutionException ex) {
                throw new FailedToFitException(ex);
            }
        }
    }

    /** Methods for selecting the RBF neuron centroids. */
    public static enum Phase1Learner {
        /**
         * Uses the numeric vectors of distinct, randomly chosen data points as
         * centroids. The vectors are not copied.
         */
        RANDOM{

            @Override
            protected List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, ExecutorService ex) {
                int sampleSize = data.getSampleSize();
                // Drawing more distinct indices than exist would never terminate.
                if (centroids > sampleSize) {
                    throw new FailedToFitException("Can not select " + centroids + " distinct random centroids from only " + sampleSize + " data points");
                }
                XORWOW rand = new XORWOW();
                IntSet points = new IntSet();
                while (points.size() < centroids) {
                    points.add(rand.nextInt(sampleSize));
                }
                List<Vec> toRet = new ArrayList<Vec>(centroids);
                for (int i : points) {
                    toRet.add(data.getDataPoint(i).getNumericalValues());
                }
                return toRet;
            }
        }
        ,
        /** Uses the means found by {@link HamerlyKMeans} with k-means++ seeding. */
        K_MEANS{

            @Override
            protected List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, ExecutorService ex) {
                HamerlyKMeans kmeans = new HamerlyKMeans(dm, SeedSelectionMethods.SeedSelection.KPP);
                if (ex == null || ex instanceof FakeExecutor) {
                    kmeans.cluster(data, centroids);
                } else {
                    kmeans.cluster(data, centroids, ex);
                }
                return kmeans.getMeans();
            }
        };


        /**
         * Selects the centroids.
         *
         * @param data      the training data
         * @param centroids the number of centroids requested
         * @param dm        the distance metric
         * @param ex        the executor to use; may be null or a FakeExecutor
         * @return the selected centroids
         */
        protected abstract List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, ExecutorService ex);
    }
}

