/*
 * Decompiled with CFR 0.152.
 */
package dataMining.classifiers.neuralnet.optimization.backpropagation;

import arrayTiTi.ArrayArithmetic;
import dataMining.classifiers.neuralnet.NeuralNetwork;
import dataMining.classifiers.neuralnet.functions.LossFunction;
import dataMining.classifiers.neuralnet.layers.BasicLayer;
import dataMining.classifiers.neuralnet.layers.Layer;
import dataMining.classifiers.neuralnet.layers.fullyconnected.FullyConnectedLayer;
import dataMining.classifiers.neuralnet.layers.fullyconnected.OutputLayer;
import dataMining.classifiers.neuralnet.neurons.BasicNeuron;
import dataMining.classifiers.neuralnet.neurons.Neuron;
import dataMining.classifiers.neuralnet.neurons.NeuronBP;
import dataMining.classifiers.neuralnet.neurons.NeuronBackProp;
import dataMining.classifiers.neuralnet.optimization.backpropagation.BackPropagation;
import dataMining.classifiers.neuralnet.optimization.backpropagation.GradientBP;
import mathematics.functions.Differentiable;
import mathematics.functions.Function1D;

public class BackPropDropConnect
extends BackPropagation {
    protected float fcl = 0.5f;
    protected boolean[][] dropout = null;
    protected float[][] memweights = null;

    public BackPropDropConnect(LossFunction Loss, GradientBP gradbp, float LearningRate, float safety, float ConvergenceLimit, int MaxEpoch) {
        super(Loss, gradbp, LearningRate, safety, ConvergenceLimit, MaxEpoch, -4);
    }

    public BackPropDropConnect(LossFunction Loss, GradientBP gradbp, float LearningRate, float safety, float ConvergenceLimit, int MaxEpoch, float fcl) {
        super(Loss, gradbp, LearningRate, safety, ConvergenceLimit, MaxEpoch, -4);
        if (Float.compare(fcl, 0.0f) <= 0 || 0 < Float.compare(fcl, 100.0f)) {
            throw new IllegalArgumentException("The ilp coefficient must be into the range ]0,100[.");
        }
        this.fcl = fcl / 100.0f;
    }

    @Override
    public void EpochOver(boolean verbose) {
        int j;
        int i2;
        this.EpochOverBasic(verbose);
        if (0.0f < this.CurrentWeightsError - this.WeightsError || this.LastWeightsError < this.CurrentWeightsError) {
            BasicNeuron[] neurons = null;
            for (i2 = 1; i2 < this.Layers.size(); ++i2) {
                neurons = ((BasicLayer)this.Layers.get(i2)).Neurons();
                if (neurons[0] instanceof NeuronBackProp) {
                    for (j = 0; j < neurons.length; ++j) {
                        ((NeuronBackProp)((Object)neurons[j])).MultiplyLearningRates(this.safety);
                    }
                }
                neurons = null;
            }
        }
        ++this.EpochCounter;
        if (Math.abs(this.WeightsConvergence) < this.ConvergenceLimit || this.MaxEpoch <= this.EpochCounter) {
            this.isTraining = false;
        }
        if (!this.isTraining) {
            float c = 1.0f - this.fcl;
            BasicNeuron[] neurons = null;
            for (i2 = 1; i2 < this.Layers.size() - 1; ++i2) {
                if (!(this.Layers.get(i2) instanceof FullyConnectedLayer)) continue;
                neurons = ((BasicLayer)this.Layers.get(i2)).Neurons();
                for (j = 0; j < neurons.length; ++j) {
                    float[] weights = ((Neuron)neurons[j]).Combination.getCoefficients((float[])null);
                    ArrayArithmetic.Multiply((float[])weights, (float)c, (float[])weights);
                    weights = null;
                }
                neurons = null;
            }
        }
    }

    @Override
    protected void FeedForward() {
        int length = this.Layers.size();
        for (int i2 = 1; i2 < length; ++i2) {
            Layer layer = (Layer)this.Layers.get(i2);
            boolean[] drop = this.dropout[i2];
            float[] mem = this.memweights[i2];
            if (drop != null) {
                Neuron[] neurons = (Neuron[])layer.Neurons();
                int pos = 0;
                for (int x = 0; x < neurons.length; ++x) {
                    float[] weights = neurons[x].Combination.getCoefficients((float[])null);
                    int w = 0;
                    while (w < weights.length) {
                        if (Math.random() < (double)this.fcl) {
                            drop[pos] = true;
                            mem[pos] = weights[w];
                            weights[w] = 0.0f;
                        } else {
                            drop[pos] = false;
                        }
                        ++w;
                        ++pos;
                    }
                    weights = null;
                }
                Object var2_4 = null;
            }
            layer.Compute();
            layer = null;
            drop = null;
            mem = null;
        }
    }

    @Override
    protected void FullConnLERP(int layernum, boolean last) throws Exception {
        Layer layer = (Layer)this.Layers.get(layernum);
        Neuron[] neurons = (Neuron[])layer.Neurons();
        Function1D derivative = ((Differentiable)layer.Activation()).Derivative();
        if (layer.Activation().hasConstant()) {
            throw new Exception("Constant in the activation function not supported (yet).");
        }
        int length = neurons[0].Combination.Dimension();
        if (layer.hasVirtualBias()) {
            --length;
        }
        boolean[] drop = this.dropout[layernum];
        int pos = 0;
        for (int n = 0; n < neurons.length; ++n) {
            NeuronBP neuron = (NeuronBP)neurons[n];
            neuron.Error *= derivative.Compute(neuron.Combined);
            if (!last) {
                int ne;
                float[] weights = neuron.Combination.getCoefficients((float[])null);
                BasicNeuron[] connexions = neuron.Synapses();
                if (drop == null) {
                    ne = 0;
                    while (ne < length) {
                        ((NeuronBackProp)((Object)connexions[ne])).AddError(neuron.Error * weights[ne]);
                        ++ne;
                        ++pos;
                    }
                } else {
                    ne = 0;
                    while (ne < length) {
                        if (!drop[pos]) {
                            ((NeuronBackProp)((Object)connexions[ne])).AddError(neuron.Error * weights[ne]);
                        }
                        ++ne;
                        ++pos;
                    }
                }
                weights = null;
                connexions = null;
            }
            neuron = null;
        }
        layer = null;
        neurons = null;
        derivative = null;
        drop = null;
    }

    @Override
    protected void FullConnLWU(int layernum) {
        Layer layer = (Layer)this.Layers.get(layernum);
        Neuron[] neurons = (Neuron[])layer.Neurons();
        float[] mem = this.memweights[layernum];
        boolean[] drop = this.dropout[layernum];
        int pos = 0;
        for (int n = 0; n < neurons.length; ++n) {
            NeuronBP neuron = (NeuronBP)neurons[n];
            BasicNeuron[] connexions = neuron.Synapses();
            float[] weights = neuron.Combination.getCoefficients((float[])null);
            int ne = 0;
            while (ne < weights.length) {
                if (drop == null || !drop[pos]) {
                    float grad = this.gradbp.Compute(neuron.Error * connexions[ne].Output, ne, neuron.LearningRates, neuron.GradientsK1, neuron.GradientsK2);
                    int n2 = ne;
                    weights[n2] = weights[n2] - grad;
                    this.CurrentWeightsError += Math.abs(grad);
                } else {
                    weights[ne] = mem[pos];
                }
                ++ne;
                ++pos;
            }
            weights = null;
            neuron = null;
            connexions = null;
        }
        layer = null;
        neurons = null;
        mem = null;
        drop = null;
    }

    @Override
    public void LinkNeuralNetwork(NeuralNetwork nn) {
        this.Check(nn);
        this.neuralnet = nn;
        this.Layers = this.neuralnet.Layers();
        this.Allocate();
    }

    protected void Allocate() {
        this.dropout = new boolean[this.Layers.size()][];
        this.memweights = new float[this.dropout.length][];
        for (int i2 = 0; i2 < this.Layers.size(); ++i2) {
            if (!(this.Layers.get(i2) instanceof FullyConnectedLayer) || this.Layers.get(i2) instanceof OutputLayer) continue;
            this.dropout[i2] = new boolean[((BasicLayer)this.Layers.get(i2)).nbWeights()];
            this.memweights[i2] = new float[this.dropout[i2].length];
        }
    }
}

