/*
 * Decompiled with CFR 0.152.
 */
package dataMining.classifiers.neuralnet.optimization.backpropagation;

import arrayTiTi.ArrayArithmetic;
import dataMining.classifiers.neuralnet.NeuralNetwork;
import dataMining.classifiers.neuralnet.functions.LossFunction;
import dataMining.classifiers.neuralnet.layers.BasicLayer;
import dataMining.classifiers.neuralnet.layers.InputLayer;
import dataMining.classifiers.neuralnet.layers.Layer;
import dataMining.classifiers.neuralnet.layers.convolutional.ConvolutionalLayer;
import dataMining.classifiers.neuralnet.layers.fullyconnected.FullyConnectedLayer;
import dataMining.classifiers.neuralnet.layers.fullyconnected.OutputLayer;
import dataMining.classifiers.neuralnet.layers.pooling.PoolingLayer;
import dataMining.classifiers.neuralnet.layers.pooling.SubSamplingLayer;
import dataMining.classifiers.neuralnet.neurons.BasicNeuron;
import dataMining.classifiers.neuralnet.neurons.Neuron;
import dataMining.classifiers.neuralnet.neurons.NeuronBP;
import dataMining.classifiers.neuralnet.neurons.NeuronBackProp;
import dataMining.classifiers.neuralnet.neurons.NeuronBackPropWeightsSharing;
import dataMining.classifiers.neuralnet.neurons.SimpleNeuron;
import dataMining.classifiers.neuralnet.optimization.NeuralNetworkTrainer;
import dataMining.classifiers.neuralnet.optimization.backpropagation.GradientBP;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import mathematics.functions.Differentiable;

/**
 * Back-propagation trainer for a {@link NeuralNetwork} made of fully connected, convolutional and
 * pooling layers. Sub-sampling layers pass {@link #Check(NeuralNetwork)}, but their error
 * retro-propagation is not implemented.
 *
 * <p>Usage constraints enforced by this class:
 * <ul>
 *   <li>A network must be attached with {@link #LinkNeuralNetwork(NeuralNetwork)} before
 *       {@link #StartTraining()} or {@link #InitializeWeights()}.</li>
 *   <li>{@link #OptimizeWeights(float[], int, int, int, float[], float)} refuses to run until
 *       {@link #StartTraining()} has been called.</li>
 *   <li>{@link #EpochOver(boolean)} clears {@link #isTraining()} once the convergence criterion or
 *       the epoch limit is met.</li>
 * </ul>
 *
 * <p>Learning rates are adaptive. Whenever {@link #EpochOverActivateSafety()} reports that a
 * monitored error increased, the learning rates of the trainable neurons and the momentums of
 * {@link GradientBP} are multiplied by the safety coefficient.
 */
public class BackPropagation
implements NeuralNetworkTrainer {
    /** Convergence mode: the safety reduction is triggered by an increase of the output error. */
    public static final int OUTPUTS_CONVERGENCE = -1;
    /** Convergence mode: the safety reduction is triggered by an increase of the weights error. */
    public static final int WEIGHTS_CONVERGENCE = -2;
    /** Convergence mode: the safety reduction is triggered by an increase of either error. */
    public static final int BOTH_CONVERGENCES = -3;
    /**
     * Convergence mode for subclasses that redefine {@link #EpochOver(boolean)}.
     * With this mode, the base {@link #EpochOverActivateSafety()} throws.
     */
    public static final int REDEFINED_CONVERGENCE = -4;

    /** Active convergence mode, one of the constants above. */
    protected int Convergence = BOTH_CONVERGENCES;
    /** Network attached by {@link #LinkNeuralNetwork(NeuralNetwork)}. */
    protected NeuralNetwork neuralnet = null;
    /** Layers of {@link #neuralnet}. Index 0 is the input layer, which every loop here skips. */
    protected List<BasicLayer> Layers = null;
    /** First layer of the linked network, recorded by {@link #Check(NeuralNetwork)}. */
    protected InputLayer inputlayer = null;
    /** Last added layer of the linked network, recorded by {@link #Check(NeuralNetwork)}. */
    protected OutputLayer outputlayer = null;
    /** Loss function evaluated on the output layer for every sample. */
    protected LossFunction Loss = null;
    /** Number of completed epochs. It is -1 until {@link #StartTraining()} is called. */
    protected int EpochCounter = -1;
    /** Epoch limit after which {@link #EpochOver(boolean)} stops the training. */
    protected int MaxEpoch;
    /** Learning rate given to every trainable neuron by {@link #StartTraining()}. */
    protected float LearningRate = 0.1f;
    /** Multiplier applied to learning rates and momentums when a monitored error increases. */
    protected float safety = 0.75f;
    /** Scratch buffer used by {@link #ConvLWU(int)} to average a feature map's gradient. */
    protected float[] convgrad = null;
    /** Loss accumulated over the epoch in progress. */
    protected float CurrentOutputError;
    /** Loss of the last completed epoch. */
    protected float OutputError;
    /** Loss of the epoch completed before the last one. */
    protected float LastOutputError;
    /** {@code OutputError - LastOutputError}. A positive value means the loss went up. */
    protected float OutputConvergence;
    /** Weights error accumulated over the epoch in progress. */
    protected float CurrentWeightsError;
    /** Weights error of the last completed epoch. */
    protected float WeightsError;
    /** Weights error of the epoch completed before the last one. */
    protected float LastWeightsError;
    /** {@code WeightsError - LastWeightsError}. A positive value means the weights error went up. */
    protected float WeightsConvergence;
    /** Training stops once {@code |OutputConvergence| + |WeightsConvergence|} falls below this. */
    protected float ConvergenceLimit;
    /** True between {@link #StartTraining()} and the epoch that ends the training. */
    protected boolean isTraining = false;
    /** Gradient step rule (it holds the momentums) used for every weight update. */
    GradientBP gradbp = null;

    /**
     * Creates a trainer with the default learning rate (0.1), safety (0.75) and convergence mode
     * ({@link #BOTH_CONVERGENCES}).
     *
     * <p>The epoch limit and the convergence limit stay at 0. With {@code MaxEpoch == 0},
     * {@link #EpochOver(boolean)} ends the training after the first epoch.
     *
     * @param Loss   loss function evaluated on the output layer
     * @param gradbp gradient step rule used to update the weights
     */
    public BackPropagation(LossFunction Loss, GradientBP gradbp) {
        this.Loss = Loss;
        this.gradbp = gradbp;
    }

    /**
     * Creates a fully configured trainer. Every numeric argument is validated by its setter.
     *
     * @param Loss             loss function evaluated on the output layer
     * @param gradbp           gradient step rule used to update the weights
     * @param LearningRate     initial learning rate, in ]0,1]
     * @param safety           learning-rate/momentum multiplier, in ]0,1]
     * @param ConvergenceLimit strictly positive convergence threshold
     * @param MaxEpoch         epoch limit, at least 1
     * @param ConvergenceMode  one of the convergence-mode constants (not validated here; an unknown
     *                         mode makes {@link #EpochOver(boolean)} throw)
     * @throws IllegalArgumentException if a numeric argument is out of range
     */
    public BackPropagation(LossFunction Loss, GradientBP gradbp, float LearningRate, float safety, float ConvergenceLimit, int MaxEpoch, int ConvergenceMode) {
        this(Loss, gradbp);
        this.LearningRate(LearningRate);
        this.Safety(safety);
        this.ConvergenceLimit(ConvergenceLimit);
        this.MaximumEpoch(MaxEpoch);
        this.Convergence = ConvergenceMode;
    }

    /**
     * Throws unless a network has been attached with {@link #LinkNeuralNetwork(NeuralNetwork)}.
     *
     * @throws IllegalStateException if no network is linked
     */
    private void requireLinked() {
        if (this.Layers == null) {
            throw new IllegalStateException("No neural network linked: call LinkNeuralNetwork first.");
        }
    }

    /**
     * Allocates the optimizer buffers of one neuron. It sets every learning rate to
     * {@link #LearningRate} and zeroes the GradientsK1/GradientsK2 buffers.
     */
    private void resetOptimizerState(NeuronBackProp neuron) {
        neuron.Allocate();
        Arrays.fill(neuron.LearningRates(), this.LearningRate);
        Arrays.fill(neuron.GradientsK1(), 0.0f);
        Arrays.fill(neuron.GradientsK2(), 0.0f);
    }

    /**
     * Resets the epoch bookkeeping, marks the trainer as training and prepares the optimizer
     * buffers of every fully connected and convolutional layer.
     *
     * <p>In a convolutional layer, only the first neuron of each feature map allocates buffers.
     * The remaining neurons of that map are pointed at the same learning-rate and gradient arrays,
     * so the whole map shares them.
     *
     * @throws IllegalStateException if no network is linked
     */
    @Override
    public void StartTraining() {
        this.requireLinked();
        this.EpochCounter = 0;
        this.CurrentOutputError = 0.0f;
        // MAX_VALUE makes the first epoch's convergence strongly negative, so it never triggers safety
        this.LastOutputError = Float.MAX_VALUE;
        this.OutputError = Float.MAX_VALUE;
        this.CurrentWeightsError = 0.0f;
        this.LastWeightsError = Float.MAX_VALUE;
        this.WeightsError = Float.MAX_VALUE;
        this.isTraining = true;
        for (int l = 1; l < this.Layers.size(); ++l) {
            BasicLayer layer = this.Layers.get(l);
            BasicNeuron[] neurons = layer.Neurons();
            if (!(neurons[0] instanceof NeuronBackProp)) {
                continue;
            }
            if (layer instanceof FullyConnectedLayer) {
                for (int j = 0; j < neurons.length; ++j) {
                    this.resetOptimizerState((NeuronBackProp)((Object)neurons[j]));
                }
            } else if (layer instanceof ConvolutionalLayer) {
                int nbfm = layer.nbFeaturesMaps();
                int length = layer.Width() * layer.Height();
                int pos = 0;
                for (int z = 0; z < nbfm; ++z) {
                    NeuronBackProp owner = (NeuronBackProp)((Object)neurons[pos++]);
                    this.resetOptimizerState(owner);
                    for (int j = 1; j < length; ++j) {
                        NeuronBP shared = (NeuronBP)neurons[pos++];
                        shared.LearningRates = owner.LearningRates();
                        shared.GradientsK1 = owner.GradientsK1();
                        shared.GradientsK2 = owner.GradientsK2();
                    }
                }
            }
        }
    }

    /**
     * Closes the current epoch. It rolls the error statistics forward, applies the safety reduction
     * when {@link #EpochOverActivateSafety()} asks for it, and increments the epoch counter. It stops
     * the training when {@code |OutputConvergence| + |WeightsConvergence| < ConvergenceLimit} or when
     * {@link #MaxEpoch} epochs have been completed.
     *
     * @param verbose when true, prints a per-epoch summary to {@code System.out}
     */
    @Override
    public void EpochOver(boolean verbose) {
        this.EpochOverBasic(verbose);
        if (this.EpochOverActivateSafety()) {
            this.EpochOverSafety();
        }
        ++this.EpochCounter;
        if (Math.abs(this.OutputConvergence) + Math.abs(this.WeightsConvergence) < this.ConvergenceLimit || this.MaxEpoch <= this.EpochCounter) {
            this.isTraining = false;
        }
    }

    /**
     * Shifts the output and weights errors one epoch back and recomputes both convergence deltas.
     * It then clears the per-epoch accumulators and optionally prints a summary.
     *
     * @param verbose when true, prints the errors and convergences to {@code System.out}
     */
    protected void EpochOverBasic(boolean verbose) {
        this.LastOutputError = this.OutputError;
        this.OutputError = this.CurrentOutputError;
        this.OutputConvergence = this.OutputError - this.LastOutputError;
        this.LastWeightsError = this.WeightsError;
        this.WeightsError = this.CurrentWeightsError;
        this.WeightsConvergence = this.WeightsError - this.LastWeightsError;
        this.CurrentWeightsError = 0.0f;
        this.CurrentOutputError = 0.0f;
        if (verbose) {
            System.out.println("Epoch " + this.EpochCounter + " over.");
            System.out.println(" - Output:  error=" + this.OutputError + ",\t convergence=" + this.OutputConvergence + ";");
            System.out.println(" - Weights: error=" + this.WeightsError + ",\t convergence=" + this.WeightsConvergence + ";");
        }
    }

    /**
     * Decides whether the safety reduction must be applied. It returns true when an error monitored
     * by the current convergence mode increased during the last epoch.
     *
     * @return true if {@link #EpochOverSafety()} should run
     * @throws IllegalArgumentException for {@link #REDEFINED_CONVERGENCE} (the subclass must redefine
     *                                  {@link #EpochOver(boolean)}) or for an unknown mode
     */
    protected boolean EpochOverActivateSafety() {
        switch (this.Convergence) {
            case OUTPUTS_CONVERGENCE:
                return 0.0f < this.OutputConvergence;
            case WEIGHTS_CONVERGENCE:
                return 0.0f < this.WeightsConvergence;
            case BOTH_CONVERGENCES:
                return 0.0f < this.OutputConvergence || 0.0f < this.WeightsConvergence;
            case REDEFINED_CONVERGENCE:
                throw new IllegalArgumentException("The EpochOver method must be redefined into the extended class.");
            default:
                throw new IllegalArgumentException("Unknown convergence mode: " + this.Convergence);
        }
    }

    /**
     * Multiplies the learning rates and the momentums by {@link #safety}.
     *
     * <p>Only fully connected and convolutional layers are scaled; those are the layers whose
     * learning rates {@link #StartTraining()} initializes. In a convolutional layer the
     * learning-rate array is shared by the whole feature map (see {@link #StartTraining()}), so only
     * the first neuron of each map is scaled. Scaling every neuron would apply the factor once per
     * neuron to the same array.
     * NOTE(review): this assumes {@code MultiplyLearningRates} scales the neuron's learning-rate
     * array in place; confirm against NeuronBP.
     */
    protected void EpochOverSafety() {
        for (int l = 1; l < this.Layers.size(); ++l) {
            BasicLayer layer = this.Layers.get(l);
            BasicNeuron[] neurons = layer.Neurons();
            if (!(neurons[0] instanceof NeuronBackProp)) {
                continue;
            }
            if (layer instanceof FullyConnectedLayer) {
                for (int j = 0; j < neurons.length; ++j) {
                    ((NeuronBackProp)((Object)neurons[j])).MultiplyLearningRates(this.safety);
                }
            } else if (layer instanceof ConvolutionalLayer) {
                int nbfm = layer.nbFeaturesMaps();
                int length = layer.Width() * layer.Height();
                for (int z = 0; z < nbfm; ++z) {
                    ((NeuronBackProp)((Object)neurons[z * length])).MultiplyLearningRates(this.safety);
                }
            }
        }
        this.gradbp.MultipyMomentums(this.safety);
    }

    /**
     * Randomizes the weights with zero-mean Gaussian draws.
     *
     * <p>The standard deviation is {@code 0.5 / sqrt(Combination.Dimension())} of the first neuron
     * (fully connected layers) or of each feature map's first neuron (convolutional layers).
     * Pooling and sub-sampling layers are skipped.
     * NOTE(review): this relies on {@code getCoefficients(null)} returning the live coefficient array.
     *
     * @throws IllegalStateException    if no network is linked
     * @throws IllegalArgumentException if a layer type is not supported
     */
    @Override
    public void InitializeWeights() {
        this.requireLinked();
        Random rand = new Random();
        for (int l = 1; l < this.Layers.size(); ++l) {
            Layer layer = (Layer)this.Layers.get(l);
            SimpleNeuron[] neurons = layer.Neurons();
            if (layer instanceof FullyConnectedLayer) {
                double sd = 0.5 / Math.sqrt(((Neuron)neurons[0]).Combination.Dimension());
                for (int x = 0; x < neurons.length; ++x) {
                    BackPropagation.fillGaussian(rand, ((Neuron)neurons[x]).Combination.getCoefficients((float[])null), sd);
                }
            } else if (layer instanceof ConvolutionalLayer) {
                int length = layer.Width() * layer.Height();
                // only the first neuron of each feature map is initialized
                for (int map = 0, pos = 0; map < layer.nbFeaturesMaps(); ++map, pos += length) {
                    double sd = 0.5 / Math.sqrt(((Neuron)neurons[pos]).Combination.Dimension());
                    BackPropagation.fillGaussian(rand, ((Neuron)neurons[pos]).Combination.getCoefficients((float[])null), sd);
                }
            } else if (!(layer instanceof PoolingLayer) && !(layer instanceof SubSamplingLayer)) {
                throw new IllegalArgumentException("Layer " + l + ", type not supported (yet).");
            }
        }
    }

    /** Overwrites each entry of {@code weights} with {@code rand.nextGaussian() * sd}. */
    private static void fillGaussian(Random rand, float[] weights, double sd) {
        for (int w = 0; w < weights.length; ++w) {
            weights[w] = (float)(rand.nextGaussian() * sd);
        }
    }

    /**
     * Runs one back-propagation step on a single weighted sample. The step:
     * <ol>
     *   <li>loads the input and feeds it forward;</li>
     *   <li>resets every neuron's error;</li>
     *   <li>adds the sample loss to the current epoch's output error;</li>
     *   <li>retro-propagates the errors from the last layer down to layer 1;</li>
     *   <li>updates the weights of the fully connected and convolutional layers.</li>
     * </ol>
     *
     * @param input         input sample, forwarded to {@code InputLayer.SetInputValues}
     * @param nbfeaturesmap number of input feature maps, forwarded with {@code input}
     * @param width         input width, forwarded with {@code input}
     * @param height        input height, forwarded with {@code input}
     * @param output        expected outputs, one per output-layer neuron
     * @param weight        strictly positive sample weight passed to the loss function
     * @throws Exception                if {@link #StartTraining()} has not been called
     * @throws IllegalArgumentException if {@code weight} is not strictly positive (NaN included),
     *                                  if {@code output} does not match the output layer, or if a
     *                                  layer type is not supported
     */
    @Override
    public void OptimizeWeights(float[] input, int nbfeaturesmap, int width, int height, float[] output, float weight) throws Exception {
        if (!this.isTraining) {
            throw new Exception("The NN is not currently training.");
        }
        // negated form so that NaN is rejected as well
        if (!(weight > 0.0f)) {
            throw new IllegalArgumentException("weight <= 0f.");
        }
        int nbout = this.outputlayer.Neurons().length;
        if (nbout != output.length) {
            throw new IllegalArgumentException("The given output does not match with the NN output layer: " + nbout + " (NN) vs " + output.length + " (output).");
        }
        this.inputlayer.SetInputValues(input, nbfeaturesmap, width, height);
        this.FeedForward();
        int nblayers = this.Layers.size();
        for (int l = 1; l < nblayers; ++l) {
            NeuronBackProp[] neurons = (NeuronBackProp[])this.Layers.get(l).Neurons();
            for (int j = 0; j < neurons.length; ++j) {
                neurons[j].ResetError();
            }
        }
        this.CurrentOutputError += this.Loss.Compute((Neuron[])this.outputlayer.Neurons(), output, weight);
        for (int l = nblayers - 1; l > 0; --l) {
            Layer layer = (Layer)this.Layers.get(l);
            // layer 1 sits right after the input layer, so it has nothing below to propagate into
            boolean last = l == 1;
            if (layer instanceof FullyConnectedLayer) {
                this.FullConnLERP(l, last);
            } else if (layer instanceof ConvolutionalLayer) {
                this.ConvLERP(l, last);
            } else if (layer instanceof PoolingLayer) {
                this.PooLERP(l, last);
            } else if (layer instanceof SubSamplingLayer) {
                this.SubSampLERP(l, last);
            } else {
                throw new IllegalArgumentException("Layer " + l + ": type not supported (yet).");
            }
        }
        for (int l = nblayers - 1; l > 0; --l) {
            Layer layer = (Layer)this.Layers.get(l);
            if (layer instanceof FullyConnectedLayer) {
                this.FullConnLWU(l);
            } else if (layer instanceof ConvolutionalLayer) {
                this.ConvLWU(l);
            } else if (!(layer instanceof PoolingLayer) && !(layer instanceof SubSamplingLayer)) {
                throw new IllegalArgumentException("Layer " + l + ": type not supported (yet).");
            }
        }
    }

    /** Computes every non-input layer in order, from layer 1 to the last layer. */
    protected void FeedForward() {
        for (int l = 1; l < this.Layers.size(); ++l) {
            ((Layer)this.Layers.get(l)).Compute();
        }
    }

    /** Calls {@code RetropropagateError(last)} on every neuron of layer {@code layernum}. */
    private void retropropagateLayer(int layernum, boolean last) throws Exception {
        NeuronBackProp[] neurons = (NeuronBackProp[])((Layer)this.Layers.get(layernum)).Neurons();
        for (int ne = 0; ne < neurons.length; ++ne) {
            neurons[ne].RetropropagateError(last);
        }
    }

    /**
     * Retro-propagates the error of a fully connected layer.
     *
     * @param layernum index of the layer in {@link #Layers}
     * @param last     true for layer 1, the layer right after the input layer
     */
    protected void FullConnLERP(int layernum, boolean last) throws Exception {
        this.retropropagateLayer(layernum, last);
    }

    /**
     * Retro-propagates the error of a convolutional layer.
     *
     * @param layernum index of the layer in {@link #Layers}
     * @param last     true for layer 1, the layer right after the input layer
     */
    protected void ConvLERP(int layernum, boolean last) throws Exception {
        this.retropropagateLayer(layernum, last);
    }

    /**
     * Retro-propagates the error of a pooling layer. Nothing is done when the layer directly
     * follows the input layer.
     *
     * @param layernum index of the layer in {@link #Layers}
     * @param last     true for layer 1, the layer right after the input layer
     */
    protected void PooLERP(int layernum, boolean last) throws Exception {
        if (last) {
            return;
        }
        this.retropropagateLayer(layernum, last);
    }

    /**
     * Retro-propagation for sub-sampling layers. Not implemented.
     *
     * @throws Error always
     */
    protected void SubSampLERP(int layernum, boolean last) throws Exception {
        throw new Error("Method not implemented (yet)");
    }

    /**
     * Updates the weights of every neuron of a fully connected layer. The value returned by each
     * update is added to the current epoch's weights error.
     *
     * @param layernum index of the layer in {@link #Layers}
     */
    protected void FullConnLWU(int layernum) {
        NeuronBackProp[] neurons = (NeuronBackProp[])((Layer)this.Layers.get(layernum)).Neurons();
        for (int n = 0; n < neurons.length; ++n) {
            this.CurrentWeightsError += neurons[n].UpdateWeights(this.gradbp);
        }
    }

    /**
     * Updates the shared weights of each convolutional feature map. The steps are:
     * <ol>
     *   <li>accumulate every neuron's gradient of the map into {@link #convgrad};</li>
     *   <li>divide it by the map size;</li>
     *   <li>apply the average through the map's last neuron.</li>
     * </ol>
     * NOTE(review): unlike {@link #FullConnLWU(int)}, nothing is added to
     * {@link #CurrentWeightsError} here; confirm whether that is intended.
     *
     * @param layernum index of the layer in {@link #Layers}
     */
    protected void ConvLWU(int layernum) {
        Layer layer = (Layer)this.Layers.get(layernum);
        NeuronBackPropWeightsSharing[] neurons = (NeuronBackPropWeightsSharing[])layer.Neurons();
        int nbfeaturesmap = layer.nbFeaturesMaps();
        int length = layer.Width() * layer.Height();
        int pos = 0;
        for (int z = 0; z < nbfeaturesmap; ++z) {
            Arrays.fill(this.convgrad, 0.0f);
            for (int end = pos + length; pos < end; ++pos) {
                neurons[pos].UpdateWeights(this.convgrad);
            }
            ArrayArithmetic.Divide(this.convgrad, (float)length, this.convgrad);
            neurons[pos - 1].UpdateWeights(this.convgrad, this.gradbp);
        }
    }

    /**
     * Validates a network and records some of its properties.
     *
     * <p>Checks performed:
     * <ul>
     *   <li>every non-input layer has NeuronBackProp neurons;</li>
     *   <li>fully connected and convolutional layers have a differentiable activation without
     *       constant;</li>
     *   <li>the loss function is compatible with the network.</li>
     * </ul>
     *
     * <p>As a side effect, it records the input and output layers. It also sizes {@link #convgrad}
     * to the largest {@code Combination.Dimension()} found on the first neuron of any convolutional
     * feature map.
     *
     * @param nn network to validate
     * @return true when the network is accepted
     * @throws IllegalArgumentException describing the first incompatibility found
     */
    @Override
    public boolean Check(NeuralNetwork nn) throws IllegalArgumentException {
        int max = 0;
        int numlayer = 0;
        Iterator<BasicLayer> iter = nn.Layers().iterator();
        this.inputlayer = (InputLayer)iter.next();
        while (iter.hasNext()) {
            BasicLayer layer = iter.next();
            ++numlayer;
            if (!(layer.NeuronSample() instanceof NeuronBackProp)) {
                throw new IllegalArgumentException("The layer " + numlayer + " requires NeuronBackProp neurons during training.");
            }
            if (layer instanceof FullyConnectedLayer || layer instanceof ConvolutionalLayer) {
                Layer trainable = (Layer)layer;
                if (trainable.Activation() == null) {
                    throw new IllegalArgumentException("The layer " + numlayer + " has no activation function defined.");
                }
                if (trainable.Activation().hasConstant()) {
                    throw new IllegalArgumentException("The layer " + numlayer + ": constant in the activation function not supported (yet).");
                }
                if (!(trainable.Activation() instanceof Differentiable)) {
                    throw new IllegalArgumentException("The layer " + numlayer + " has a non differentiable activation function.");
                }
            }
            if (layer instanceof ConvolutionalLayer) {
                SimpleNeuron[] neurons = ((Layer)layer).Neurons();
                int length = layer.Width() * layer.Height();
                for (int z = 0, pos = 0; z < layer.nbFeaturesMaps(); ++z, pos += length) {
                    max = Math.max(max, ((Neuron)neurons[pos]).Combination.Dimension());
                }
            }
        }
        this.outputlayer = (OutputLayer)nn.LastAddedLayer();
        if (max != 0) {
            this.convgrad = new float[max];
        }
        if (!this.Loss.Compatible(nn)) {
            throw new IllegalArgumentException("The loss function does not support the output layer neurons.");
        }
        return true;
    }

    /**
     * Validates a network with {@link #Check(NeuralNetwork)} and attaches it to this trainer.
     *
     * @param nn network to train
     * @throws IllegalArgumentException if the network is not compatible
     */
    @Override
    public void LinkNeuralNetwork(NeuralNetwork nn) {
        if (!this.Check(nn)) {
            throw new IllegalArgumentException("NeuralNetwork not compatible.");
        }
        this.neuralnet = nn;
        this.Layers = this.neuralnet.Layers();
    }

    /**
     * Sets the epoch limit.
     *
     * @param limitmax maximum number of epochs, at least 1
     * @throws IllegalArgumentException if {@code limitmax < 1}
     */
    @Override
    public void MaximumEpoch(int limitmax) {
        if (limitmax < 1) {
            throw new IllegalArgumentException("limitmax < 1.");
        }
        this.MaxEpoch = limitmax;
    }

    /** @return the epoch limit */
    @Override
    public int MaximumEpoch() {
        return this.MaxEpoch;
    }

    /** @return the number of completed epochs, or -1 before {@link #StartTraining()} */
    @Override
    public int Epoch() {
        return this.EpochCounter;
    }

    /**
     * Sets the initial learning rate applied by {@link #StartTraining()}.
     *
     * @param rate learning rate in ]0,1]
     * @throws IllegalArgumentException if {@code rate} is outside ]0,1] or NaN
     */
    @Override
    public void LearningRate(float rate) {
        // validate the argument (not the current field value); the negated form also rejects NaN
        if (!(rate > 0.0f && rate <= 1.0f)) {
            throw new IllegalArgumentException("The learning rate must be in the range ]0,1].");
        }
        this.LearningRate = rate;
    }

    /** @return the initial learning rate */
    @Override
    public float LearningRate() {
        return this.LearningRate;
    }

    /** @return the momentums held by the gradient step rule */
    @Override
    public float[] Momentums() {
        return this.gradbp.Momentums();
    }

    /**
     * Sets the convergence threshold used by {@link #EpochOver(boolean)}.
     *
     * @param limit strictly positive threshold
     * @throws IllegalArgumentException if {@code limit} is not strictly positive (NaN included)
     */
    @Override
    public void ConvergenceLimit(float limit) {
        if (!(limit > 0.0f)) {
            throw new IllegalArgumentException("The convergence limit must be positive.");
        }
        this.ConvergenceLimit = limit;
    }

    /** @return the convergence threshold */
    @Override
    public float ConvergenceLimit() {
        return this.ConvergenceLimit;
    }

    /**
     * Sets the multiplier applied to learning rates and momentums when a monitored error increases.
     * A value of 1 disables the reduction.
     *
     * @param safety coefficient in ]0,1]
     * @throws IllegalArgumentException if {@code safety} is outside ]0,1] or NaN
     */
    @Override
    public void Safety(float safety) {
        if (Float.compare(safety, 0.0f) <= 0 || 0 < Float.compare(safety, 1.0f)) {
            throw new IllegalArgumentException("The safety coefficient must be into the range ]0,1].");
        }
        this.safety = safety;
    }

    /** @return the safety coefficient */
    @Override
    public float Safety() {
        return this.safety;
    }

    /** @return the output-error change of the last completed epoch */
    @Override
    public float OutputConvergence() {
        return this.OutputConvergence;
    }

    /** @return the output error of the last completed epoch */
    @Override
    public float OutputError() {
        return this.OutputError;
    }

    /** @return the weights-error change of the last completed epoch */
    @Override
    public float WeightsConvergence() {
        return this.WeightsConvergence;
    }

    /** @return the weights error of the last completed epoch */
    @Override
    public float WeightsError() {
        return this.WeightsError;
    }

    /** @return true while training is in progress */
    @Override
    public boolean isTraining() {
        return this.isTraining;
    }

    /**
     * Describes the trainer: class, loss function, epoch limit, learning rate, momentums
     * (comma-separated; empty when there are none) and convergence limit.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.getClass().getName()).append(":\n");
        sb.append(" - Loss function = ").append(this.Loss.getClass().getName()).append(":\n");
        sb.append(" - MaxEpoch = ").append(this.MaxEpoch).append("\n");
        sb.append(" - LearningRate = ").append(this.LearningRate).append("\n");
        sb.append(" - Momentums = ");
        float[] momentums = this.gradbp.Momentums();
        for (int i = 0; i < momentums.length; ++i) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(momentums[i]);
        }
        sb.append("\n");
        sb.append(" - ConvergenceLimit = ").append(this.ConvergenceLimit).append("\n");
        return sb.toString();
    }
}

