/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.backward.DSpatialBatchNorm;
import deepboof.impl.backward.standard.BaseDBatchNorm_F64;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Spatial batch normalization for {@link Tensor_F64} with support for the backward pass.
 *
 * <p>Statistics are computed per channel: for every channel the mean and standard deviation are
 * accumulated over all mini-batch samples and all pixels of that channel. The input is traversed
 * as a dense array ordered (mini-batch, channel, pixel).</p>
 *
 * <p>When gamma/beta are used, they are read from {@code params} interleaved per channel as
 * {@code [gamma_0, beta_0, gamma_1, beta_1, ...]}, starting at {@code params.startIndex}.</p>
 */
public class DSpatialBatchNorm_F64 extends BaseDBatchNorm_F64
        implements DSpatialBatchNorm<Tensor_F64> {
    /** Length of input axis 1; refreshed on every call to {@link #_forward}. */
    int numChannels;
    /** Product of the input's axis lengths from axis 2 onward; refreshed on every forward call. */
    int numPixels;
    /** Number of elements averaged per channel: {@code miniBatchSize * numPixels}. */
    double M;
    /** Divisor used for the variance, {@code M - 1}. */
    double M_var;

    /**
     * @param requiresGammaBeta passed to the base class; selects whether gamma/beta parameters
     *                          are applied and differentiated
     */
    public DSpatialBatchNorm_F64(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Returns the shape of the per-channel statistics: a single axis whose length is the first
     * element of {@code shapeInput}.
     */
    @Override
    protected int[] createShapeVariables(int[] shapeInput) {
        return new int[]{shapeInput[0]};
    }

    /**
     * Runs the forward pass in learning or evaluation mode, depending on {@code learningMode}.
     *
     * @param input  input tensor; axis 0 is the mini-batch, axis 1 the channel
     * @param output tensor the normalized result is written to
     * @throws IllegalArgumentException if the length of axis 0 of {@code input} is 1 or less
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (input.length(0) <= 1) {
            throw new IllegalArgumentException("There must be more than 1 minibatch");
        }
        this.tensorDiffX.reshape(input.shape);
        this.tensorXhat.reshape(input.shape);

        this.numChannels = input.length(1);
        this.numPixels = TensorOps.outerLength(input.shape, 2);

        // Widen before multiplying so a very large batch*pixel count cannot overflow int.
        this.M = (double) this.miniBatchSize * this.numPixels;
        this.M_var = this.M - 1.0;

        if (this.learningMode) {
            this.forwardLearning(input, output);
        } else {
            this.forwardEvaluate(input, output);
        }
    }

    /** Learning-mode forward pass: computes batch statistics, then optionally applies gamma/beta. */
    private void forwardLearning(Tensor_F64 input, Tensor_F64 output) {
        this.computeStatisticsAndNormalize(input);
        if (this.requiresGammaBeta) {
            this.applyGammaBeta(output);
        } else {
            output.setTo(this.tensorXhat);
        }
    }

    /**
     * Evaluation-mode forward pass. Normalizes the input using the values currently stored in
     * {@code tensorMean} and {@code tensorStd}, without recomputing any statistics.
     *
     * @param input  input tensor; axes 2 and 3 are read to find the number of pixels per channel
     * @param output tensor the result is written to, starting at its {@code startIndex}
     */
    public void forwardEvaluate(Tensor_F64 input, Tensor_F64 output) {
        int C = input.length(1);
        int W = input.length(2);
        int H = input.length(3);
        int pixelsPerChannel = W * H;

        int indexIn = input.startIndex;
        int indexOut = output.startIndex;

        if (this.hasGammaBeta()) {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                int indexP = this.params.startIndex;
                for (int channel = 0; channel < C; ++channel) {
                    double mean = this.tensorMean.d[channel];
                    double stdev_eps = this.tensorStd.d[channel];
                    double gamma = this.params.d[indexP++];
                    double beta = this.params.d[indexP++];
                    // Loop invariant: the same quotient the per-pixel expression used to compute.
                    double scale = gamma / stdev_eps;

                    int end = indexIn + pixelsPerChannel;
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) * scale + beta;
                    }
                }
            }
        } else {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                for (int channel = 0; channel < C; ++channel) {
                    double mean = this.tensorMean.d[channel];
                    double stdev_eps = this.tensorStd.d[channel];

                    int end = indexIn + pixelsPerChannel;
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) / stdev_eps;
                    }
                }
            }
        }
    }

    /** Writes {@code gamma * xhat + beta} into {@code output} using each channel's gamma/beta. */
    private void applyGammaBeta(Tensor_F64 output) {
        // Honour params.startIndex, matching how forwardEvaluate reads the same tensor.
        final int paramStart = this.params.startIndex;
        int indexOut = output.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double gamma = this.params.d[paramStart + channel * 2];
                double beta = this.params.d[paramStart + channel * 2 + 1];
                for (int pixel = 0; pixel < this.numPixels; ++pixel) {
                    output.d[indexOut++] = gamma * this.tensorXhat.d[indexTensor++] + beta;
                }
            }
        }
    }

    /**
     * Computes the per-channel mean and standard deviation of {@code input} and fills
     * {@code tensorDiffX} (x - mean) and {@code tensorXhat} ((x - mean) / std).
     *
     * <p>The stored standard deviation is {@code sqrt(sumSq / M_var + EPS)}.</p>
     */
    private void computeStatisticsAndNormalize(Tensor_F64 input) {
        this.tensorMean.zero();
        this.tensorStd.zero();
        this.tensorXhat.zero();

        // Pass 1: per-channel sum, accumulated one (sample, channel) slab at a time.
        int indexIn = input.startIndex;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double sum = 0.0;
                for (int pixel = 0; pixel < this.numPixels; ++pixel) {
                    sum += input.d[indexIn++];
                }
                this.tensorMean.d[channel] += sum;
            }
        }
        for (int channel = 0; channel < this.numChannels; ++channel) {
            this.tensorMean.d[channel] /= this.M;
        }

        // Pass 2: deviation from the mean and the per-channel sum of squared deviations.
        indexIn = input.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double channelMean = this.tensorMean.d[channel];
                double sum = 0.0;
                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor) {
                    double d = input.d[indexIn++] - channelMean;
                    this.tensorDiffX.d[indexTensor] = d;
                    sum += d * d;
                }
                this.tensorStd.d[channel] += sum;
            }
        }
        for (int channel = 0; channel < this.numChannels; ++channel) {
            this.tensorStd.d[channel] =
                    Math.sqrt(this.tensorStd.d[channel] / this.M_var + this.EPS);
        }

        // Pass 3: normalized value.
        indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double channelStd = this.tensorStd.d[channel];
                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor) {
                    this.tensorXhat.d[indexTensor] = this.tensorDiffX.d[indexTensor] / channelStd;
                }
            }
        }
    }

    /**
     * Backward pass. Relies on {@code tensorDiffX}, {@code tensorXhat} and {@code tensorStd} as
     * filled in by the learning-mode forward pass.
     *
     * @param input              input tensor; only its shape is used here
     * @param dout               gradient with respect to the output
     * @param gradientInput      receives the gradient with respect to the input
     * @param gradientParameters when gamma/beta are required, element 0 receives the interleaved
     *                           gamma/beta gradients
     */
    @Override
    protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput,
                              List<Tensor_F64> gradientParameters) {
        this.tensorDXhat.reshape(input.shape);
        if (this.requiresGammaBeta) {
            this.partialXHat(dout);
        } else {
            this.tensorDXhat.setTo(dout);
        }

        // Order matters: partialMean reads tensorDVar, partialX reads both tensorDVar and tensorDMean.
        this.partialVariance();
        this.partialMean();
        this.partialX(gradientInput);

        if (this.requiresGammaBeta) {
            this.partialParameters(gradientParameters.get(0), dout);
        }
    }

    /**
     * Computes the gamma/beta gradients: per channel, dGamma is the sum of {@code dout * xhat}
     * and dBeta is the sum of {@code dout}. Results are stored interleaved
     * {@code [dGamma_0, dBeta_0, ...]} starting at {@code tensorDParam.startIndex}.
     */
    private void partialParameters(Tensor_F64 tensorDParam, Tensor_F64 dout) {
        tensorDParam.zero();
        int indexDOut = dout.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            // Restart at the first channel's slot for each sample, honouring the tensor's offset.
            int indexDParam = tensorDParam.startIndex;
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double sumDGamma = 0.0;
                double sumDBeta = 0.0;
                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor, ++indexDOut) {
                    double d = dout.d[indexDOut];
                    sumDGamma += d * this.tensorXhat.d[indexTensor];
                    sumDBeta += d;
                }
                tensorDParam.d[indexDParam++] += sumDGamma;
                tensorDParam.d[indexDParam++] += sumDBeta;
            }
        }
    }

    /** Fills {@code tensorDXhat} with {@code dout * gamma}, using each channel's gamma. */
    private void partialXHat(Tensor_F64 dout) {
        // Honour params.startIndex, matching how forwardEvaluate reads the same tensor.
        final int paramStart = this.params.startIndex;
        int indexDOut = dout.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double gamma = this.params.d[paramStart + channel * 2];
                for (int pixel = 0; pixel < this.numPixels; ++pixel) {
                    this.tensorDXhat.d[indexTensor++] = dout.d[indexDOut++] * gamma;
                }
            }
        }
    }

    /**
     * Writes the input gradient:
     * {@code dxhat / std + (2 * dvar * (x - mean) / M_var + dmean / M)}.
     */
    private void partialX(Tensor_F64 tensorDX) {
        int indexDX = tensorDX.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double stdev = this.tensorStd.d[channel];
                double dvar = this.tensorDVar.d[channel];
                double dmean = this.tensorDMean.d[channel];
                // Per-channel invariants, hoisted with the original evaluation order preserved.
                double dvarTimes2 = dvar * 2.0;
                double meanTerm = dmean / this.M;

                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor, ++indexDX) {
                    double val = this.tensorDXhat.d[indexTensor] / stdev;
                    val += dvarTimes2 * this.tensorDiffX.d[indexTensor] / this.M_var + meanTerm;
                    tensorDX.d[indexDX] = val;
                }
            }
        }
    }

    /**
     * Fills {@code tensorDMean}. Per channel:
     * {@code -sum(dxhat) / std - 2 * dvar * sum(x - mean) / M_var}.
     * Uses {@code tensorTmp} to hold the per-channel sum of {@code x - mean}.
     */
    private void partialMean() {
        this.tensorDMean.zero();
        this.tensorTmp.zero();

        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double sumTmp = 0.0;
                double sumDMean = 0.0;
                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor) {
                    sumTmp += this.tensorDiffX.d[indexTensor];
                    sumDMean -= this.tensorDXhat.d[indexTensor];
                }
                this.tensorTmp.d[channel] += sumTmp;
                this.tensorDMean.d[channel] += sumDMean;
            }
        }

        for (int channel = 0; channel < this.numChannels; ++channel) {
            this.tensorDMean.d[channel] /= this.tensorStd.d[channel];
            this.tensorDMean.d[channel] -=
                    2.0 * this.tensorDVar.d[channel] * this.tensorTmp.d[channel] / this.M_var;
        }
    }

    /**
     * Fills {@code tensorDVar}. Per channel: {@code sum(dxhat * (x - mean)) / (-2 * std^3)}.
     */
    private void partialVariance() {
        this.tensorDVar.zero();

        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int channel = 0; channel < this.numChannels; ++channel) {
                double sumDVar = 0.0;
                for (int pixel = 0; pixel < this.numPixels; ++pixel, ++indexTensor) {
                    sumDVar += this.tensorDXhat.d[indexTensor] * this.tensorDiffX.d[indexTensor];
                }
                this.tensorDVar.d[channel] += sumDVar;
            }
        }

        for (int channel = 0; channel < this.numChannels; ++channel) {
            double sigma = this.tensorStd.d[channel];
            double sigmaPow3 = sigma * sigma * sigma;
            this.tensorDVar.d[channel] /= -2.0 * sigmaPow3;
        }
    }
}

