/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.PriorClassifier;
import jsat.classifiers.bayesian.MultivariateNormals;
import jsat.classifiers.neuralnetwork.LVQ;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.math.decayrates.DecayRate;

/**
 * An {@link LVQ} variant that, after the LVQ prototypes ("representatives") have been
 * learned, trains a separate local classifier for each prototype on the training points
 * that lie closest to it. Classification finds the nearest prototype(s) and delegates to
 * their local classifier(s).
 * <p>
 * Prototypes that won fewer than {@value #MIN_WINS_FOR_LOCAL_MODEL} times during LVQ
 * training get a {@link PriorClassifier} that always predicts the prototype's own class
 * instead of a trained local model. Prototypes that never won get no local model.
 */
public class LVQLLC
extends LVQ {
    private static final long serialVersionUID = 3602640001545233744L;
    /** Minimum number of LVQ wins before a full local classifier is trained for a prototype. */
    private static final int MIN_WINS_FOR_LOCAL_MODEL = 10;
    /** Template classifier; cloned once per prototype during training. */
    private Classifier localClassifier;
    /**
     * Per-prototype local classifiers, indexed like {@code weights}. Entries for prototypes
     * with zero wins are left {@code null}. The (misspelled) field name is kept unchanged so
     * the serialized form of this class stays compatible.
     */
    private Classifier[] localClassifeirs;

    /**
     * Creates an LVQ-LLC model that uses {@code new MultivariateNormals(true)} as the local
     * classifier.
     *
     * @param dm         distance metric used by the underlying LVQ
     * @param iterations number of LVQ training iterations
     */
    public LVQLLC(DistanceMetric dm, int iterations) {
        this(dm, iterations, new MultivariateNormals(true));
    }

    /**
     * @param dm             distance metric used by the underlying LVQ
     * @param iterations     number of LVQ training iterations
     * @param localClasifier template classifier cloned and trained for each prototype
     */
    public LVQLLC(DistanceMetric dm, int iterations, Classifier localClasifier) {
        super(dm, iterations);
        this.setLocalClassifier(localClasifier);
    }

    /**
     * @param dm                      distance metric used by the underlying LVQ
     * @param iterations              number of LVQ training iterations
     * @param localClasifier          template classifier cloned and trained for each prototype
     * @param learningRate            LVQ learning rate
     * @param representativesPerClass number of LVQ prototypes per class
     */
    public LVQLLC(DistanceMetric dm, int iterations, Classifier localClasifier, double learningRate, int representativesPerClass) {
        super(dm, iterations, learningRate, representativesPerClass);
        this.setLocalClassifier(localClasifier);
    }

    /**
     * @param dm                      distance metric used by the underlying LVQ
     * @param iterations              number of LVQ training iterations
     * @param localClasifier          template classifier cloned and trained for each prototype
     * @param learningRate            LVQ learning rate
     * @param representativesPerClass number of LVQ prototypes per class
     * @param lvqVersion              LVQ update rule variant
     * @param learningDecay           decay schedule for the learning rate
     */
    public LVQLLC(DistanceMetric dm, int iterations, Classifier localClasifier, double learningRate, int representativesPerClass, LVQ.LVQVersion lvqVersion, DecayRate learningDecay) {
        super(dm, iterations, learningRate, representativesPerClass, lvqVersion, learningDecay);
        this.setLocalClassifier(localClasifier);
    }

    /**
     * Copy constructor; deep-copies the template and all trained local classifiers.
     *
     * @param toCopy the model to copy
     */
    protected LVQLLC(LVQLLC toCopy) {
        super(toCopy);
        if (toCopy.localClassifier != null) {
            this.localClassifier = toCopy.localClassifier.clone();
        }
        if (toCopy.localClassifeirs != null) {
            this.localClassifeirs = new Classifier[toCopy.localClassifeirs.length];
            for (int i = 0; i < this.localClassifeirs.length; ++i) {
                // trainC leaves entries null for prototypes that never won; don't dereference them.
                if (toCopy.localClassifeirs[i] != null) {
                    this.localClassifeirs[i] = toCopy.localClassifeirs[i].clone();
                }
            }
        }
    }

    /**
     * Sets the template classifier that is cloned and trained for each prototype.
     *
     * @param localClassifier the template classifier
     */
    public void setLocalClassifier(Classifier localClassifier) {
        this.localClassifier = localClassifier;
    }

    /**
     * @return the template classifier that is cloned and trained for each prototype
     */
    public Classifier getLocalClassifier() {
        return this.localClassifier;
    }

    /**
     * Classifies a point with the local classifier of its nearest prototype. When the LVQ
     * version is LVQ2 or later and {@code epsClose} accepts the two nearest distances, the
     * predictions of the two nearest prototypes' local classifiers are blended. Each one is
     * weighted by the distance to the <em>other</em> prototype, so the closer one counts more.
     *
     * @param data the point to classify
     * @return the predicted class probabilities
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> nns = this.vc.search(data.getNumericalValues(), 2);
        VecPaired<VecPaired<Vec, Integer>, Double> nearest = nns.get(0);
        double d1 = nearest.getPair();
        int index1 = nearest.getVector().getPair();
        CategoricalResults r1 = this.localClassifeirs[index1].classify(data);
        // Only one prototype is available (e.g. a single trained representative): nothing to blend.
        if (nns.size() < 2) {
            return r1;
        }
        VecPaired<VecPaired<Vec, Integer>, Double> secondNearest = nns.get(1);
        double d2 = secondNearest.getPair();
        int index2 = secondNearest.getVector().getPair();
        if (this.getLVQMethod().ordinal() >= LVQ.LVQVersion.LVQ2.ordinal() && this.epsClose(d1, d2)) {
            CategoricalResults result = new CategoricalResults(r1.size());
            CategoricalResults r2 = this.localClassifeirs[index2].classify(data);
            double distSum = d1 + d2;
            for (int i = 0; i < r1.size(); ++i) {
                // (distSum - d1) is d2 and vice versa: inverse-distance style weighting.
                result.incProb(i, r1.getProb(i) * (distSum - d1));
                result.incProb(i, r2.getProb(i) * (distSum - d2));
            }
            result.normalize();
            return result;
        }
        return r1;
    }

    /**
     * Trains the underlying LVQ model, then trains one local classifier per prototype.
     * <p>
     * Every training point is assigned to its nearest prototype. It is additionally assigned
     * to the second-nearest prototype when the ratio of the two distances lies in
     * {@code (1 - eps, 1 + eps)}, with {@code eps = getEpsilonDistance()}. Prototypes with
     * zero wins get no local model. Prototypes with fewer than
     * {@value #MIN_WINS_FOR_LOCAL_MODEL} wins get a {@link PriorClassifier} for their own
     * class. All others get a clone of the template classifier, trained on their assigned
     * points.
     *
     * @param dataSet    the training data
     * @param threadPool thread pool passed to the LVQ training (may be {@code null})
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        super.trainC(dataSet, threadPool);
        List<List<DataPointPair<Integer>>> listOfLocalPoints = new ArrayList<List<DataPointPair<Integer>>>(this.weights.length);
        for (int i = 0; i < this.weights.length; ++i) {
            listOfLocalPoints.add(new ArrayList<DataPointPair<Integer>>(this.wins[i] * 3 / 2));
        }
        double tmpEps = this.getEpsilonDistance();
        for (DataPointPair<Integer> dpp : dataSet.getAsDPPList()) {
            Vec x = dpp.getVector();
            List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> closestWeightVecs = this.vc.search(x, 2);
            VecPaired<VecPaired<Vec, Integer>, Double> closest = closestWeightVecs.get(0);
            int minDistIndx = closest.getVector().getPair();
            listOfLocalPoints.get(minDistIndx).add(dpp);
            if (closestWeightVecs.size() < 2) {
                continue;
            }
            // Second-closest prototype (previously, and wrongly, read from index 0 again).
            VecPaired<VecPaired<Vec, Integer>, Double> closest2nd = closestWeightVecs.get(1);
            int minDistIndx2 = closest2nd.getVector().getPair();
            double minDist = closest.getPair();
            double minDist2 = closest2nd.getPair();
            double ratio = minDist / minDist2;
            double inverseRatio = minDist2 / minDist;
            // Share the point with the runner-up only when both distances are nearly equal.
            // A zero distance yields 0/Infinity/NaN here, which fails the test.
            if (Math.min(ratio, inverseRatio) > 1.0 - tmpEps && Math.max(ratio, inverseRatio) < 1.0 + tmpEps) {
                listOfLocalPoints.get(minDistIndx2).add(dpp);
            }
        }
        this.localClassifeirs = new Classifier[this.weights.length];
        for (int i = 0; i < this.weights.length; ++i) {
            if (this.wins[i] == 0) {
                continue;
            }
            if (this.wins[i] < MIN_WINS_FOR_LOCAL_MODEL) {
                // Too little evidence for a local model: always predict the prototype's class.
                CategoricalResults cr = new CategoricalResults(dataSet.getPredicting().getNumOfCategories());
                cr.setProb(this.weightClass[i], 1.0);
                this.localClassifeirs[i] = new PriorClassifier(cr);
                continue;
            }
            ClassificationDataSet localSet = new ClassificationDataSet(listOfLocalPoints.get(i), dataSet.getPredicting());
            this.localClassifeirs[i] = this.localClassifier.clone();
            this.localClassifeirs[i].trainC(localSet);
        }
    }

    /**
     * Single-threaded training; equivalent to {@code trainC(dataSet, null)}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        this.trainC(dataSet, null);
    }

    @Override
    public LVQLLC clone() {
        return new LVQLLC(this);
    }
}

