/*
 * Decompiled with CFR 0.152.
 */
package boofcv.alg.flow;

import boofcv.abst.filter.derivative.ImageGradient;
import boofcv.alg.flow.DenseFlowPyramidBase;
import boofcv.alg.interpolate.InterpolatePixelS;
import boofcv.alg.misc.ImageMiscOps;
import boofcv.factory.filter.derivative.FactoryDerivative;
import boofcv.factory.flow.ConfigHornSchunckPyramid;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.ImageGray;
import boofcv.struct.pyramid.ImagePyramid;

/**
 * Dense optical flow in the style of Horn-Schunck, estimated coarse-to-fine over an image pyramid.
 *
 * <p>Layers are processed from the highest index down to layer 0 (presumably coarsest to finest). The
 * first layer processed starts from zero flow. At every layer:
 * <ul>
 *   <li>the second image and its gradient are warped by the current flow estimate;</li>
 *   <li>the linearized brightness constraint, regularized by the {@code alpha^2}-weighted neighborhood
 *       average of the flow, is solved in place with successive over-relaxation (SOR);</li>
 *   <li>the resulting flow is rescaled to seed the next layer.</li>
 * </ul>
 *
 * <p>The computed flow is read through {@link #getFlowX()} and {@link #getFlowY()}. All work images are
 * instance fields that {@link #process} reshapes and overwrites, so an instance should not be used by
 * more than one thread at a time.
 *
 * @param <T> image type parameter forwarded to {@link DenseFlowPyramidBase}
 */
public class HornSchunckPyramid<T extends ImageGray<T>>
extends DenseFlowPyramidBase<T> {
    /** Weight applied to each of the four edge-connected neighbors in the flow average. */
    private static final float WEIGHT_EDGE = 1.0f / 6.0f;
    /** Weight applied to each of the four diagonal neighbors in the flow average. */
    private static final float WEIGHT_DIAG = 1.0f / 12.0f;

    /** Square of the configured regularization weight {@code alpha}. */
    private final float alpha2;
    /** SOR relaxation factor {@code w}; each update is {@code (1 - w) * old + w * solution}. */
    private final float SOR_RELAXATION;
    /** Number of warp / re-linearization passes per pyramid layer. */
    private final int numWarps;
    /** Maximum number of SOR sweeps per warp. */
    private final int maxInnerIterations;
    /** SOR stops once the summed squared flow change is at most {@code convergeTolerance * width * height}. */
    private final float convergeTolerance;
    /** Computes the x/y gradient of the second image at each layer. */
    private final ImageGradient<GrayF32, GrayF32> gradient = FactoryDerivative.three(GrayF32.class, GrayF32.class);
    /** Gradient of the second image at the current layer. */
    private final GrayF32 deriv2X = new GrayF32(1, 1);
    private final GrayF32 deriv2Y = new GrayF32(1, 1);
    /** Current flow estimate; updated in place by SOR. */
    protected GrayF32 flowX = new GrayF32(1, 1);
    protected GrayF32 flowY = new GrayF32(1, 1);
    /** Flow the constraint is linearized around at the start of each warp; also scratch space for rescaling. */
    protected GrayF32 initFlowX = new GrayF32(1, 1);
    protected GrayF32 initFlowY = new GrayF32(1, 1);
    /** Second image and its gradient, warped by {@link #initFlowX}/{@link #initFlowY}. */
    protected GrayF32 warpImage2 = new GrayF32(1, 1);
    protected GrayF32 warpDeriv2X = new GrayF32(1, 1);
    protected GrayF32 warpDeriv2Y = new GrayF32(1, 1);

    /**
     * Creates the estimator.
     *
     * @param config pyramid settings plus the Horn-Schunck parameters (alpha, SOR relaxation, number of
     *               warps, inner iteration limit and convergence tolerance)
     * @param interp interpolation used both to warp images and to rescale flow between layers
     */
    public HornSchunckPyramid(ConfigHornSchunckPyramid config, InterpolatePixelS<GrayF32> interp) {
        super(config.pyrScale, config.pyrSigma, config.pyrMaxLayers, interp);
        this.alpha2 = config.alpha * config.alpha;
        this.SOR_RELAXATION = config.SOR_RELAXATION;
        this.numWarps = config.numWarps;
        this.maxInnerIterations = config.maxInnerIterations;
        this.interp = interp;
        this.convergeTolerance = config.convergeTolerance;
    }

    /**
     * Estimates the flow that maps pixels of {@code image1} onto {@code image2}. The result has the size of
     * layer 0 and is read through {@link #getFlowX()} / {@link #getFlowY()}.
     *
     * @param image1 pyramid of the first image
     * @param image2 pyramid of the second image; assumed to have the same number and sizes of layers as
     *               {@code image1}
     */
    @Override
    public void process(ImagePyramid<GrayF32> image1, ImagePyramid<GrayF32> image2) {
        boolean first = true;
        for (int i = image1.getNumLayers() - 1; i >= 0; --i) {
            GrayF32 layer1 = image1.getLayer(i);
            GrayF32 layer2 = image2.getLayer(i);
            int width = layer1.width;
            int height = layer1.height;

            // Per-layer work images always match the current layer's size.
            this.deriv2X.reshape(width, height);
            this.deriv2Y.reshape(width, height);
            this.warpDeriv2X.reshape(width, height);
            this.warpDeriv2Y.reshape(width, height);
            this.warpImage2.reshape(width, height);
            this.gradient.process(layer2, this.deriv2X, this.deriv2Y);

            if (first) {
                // First layer processed: there is no previous estimate, so start from zero flow.
                first = false;
                this.initFlowX.reshape(width, height);
                this.initFlowY.reshape(width, height);
                this.flowX.reshape(width, height);
                this.flowY.reshape(width, height);
                ImageMiscOps.fill(this.flowX, 0.0f);
                ImageMiscOps.fill(this.flowY, 0.0f);
                ImageMiscOps.fill(this.initFlowX, 0.0f);
                ImageMiscOps.fill(this.initFlowY, 0.0f);
            } else {
                // Seed this layer with the estimate from the previously processed layer.
                this.interpolateFlowScale(width, height);
            }
            this.processLayer(layer1, layer2, this.deriv2X, this.deriv2Y);
        }
    }

    /**
     * Resizes the flow estimate to {@code widthNew x heightNew}. The flow is interpolated from its current
     * size and its magnitude is rescaled. {@link #initFlowX}/{@link #initFlowY} are used as scratch space and
     * end up holding the same values as the new flow.
     *
     * @param widthNew  width of the new flow images
     * @param heightNew height of the new flow images
     */
    protected void interpolateFlowScale(int widthNew, int heightNew) {
        this.initFlowX.reshape(widthNew, heightNew);
        this.initFlowY.reshape(widthNew, heightNew);
        this.interpolateFlowScale(this.flowX, this.initFlowX);
        this.interpolateFlowScale(this.flowY, this.initFlowY);
        this.flowX.reshape(widthNew, heightNew);
        this.flowY.reshape(widthNew, heightNew);
        this.flowX.setTo(this.initFlowX);
        this.flowY.setTo(this.initFlowY);
    }

    /**
     * Fills {@code curr} by interpolating the flow component {@code prev} over curr's pixel grid. Each value
     * is divided by {@code prev.width / curr.width}, since flow is measured in pixels of the layer it belongs
     * to. The width ratio is applied to both x and y components.
     *
     * <p>{@code curr} is written sequentially, so it is assumed not to be a sub-image (stride equals width).
     *
     * @param prev flow component at its current size (the interpolation source)
     * @param curr output flow component at the new size
     */
    @Override
    protected void interpolateFlowScale(GrayF32 prev, GrayF32 curr) {
        this.interp.setImage(prev);
        // Map curr's pixel grid onto prev's extent. The 0.999 factor keeps the last sample slightly inside
        // prev (presumably so the interpolator never reads past the final row/column). A layer that is one
        // pixel wide/tall has no extent to stretch over: sample coordinate 0 instead of dividing by zero,
        // which would otherwise produce NaN coordinates and NaN flow.
        float scaleX = curr.width > 1 ? (float)(prev.width - 1) / (float)(curr.width - 1) * 0.999f : 0.0f;
        float scaleY = curr.height > 1 ? (float)(prev.height - 1) / (float)(curr.height - 1) * 0.999f : 0.0f;
        float scale = (float)prev.width / (float)curr.width;
        int indexCurr = 0;
        for (int y = 0; y < curr.height; ++y) {
            for (int x = 0; x < curr.width; ++x) {
                curr.data[indexCurr++] = this.interp.get((float)x * scaleX, (float)y * scaleY) / scale;
            }
        }
    }

    /**
     * Samples {@code before} at every pixel displaced by ({@code flowX}, {@code flowY}) and writes the result
     * to {@code after}. Displaced locations outside {@code before} produce 0.
     *
     * <p>All four images are indexed with {@code before.width}, so they are assumed to have the same size and
     * not to be sub-images.
     *
     * @param before image to sample
     * @param flowX  x displacement per pixel
     * @param flowY  y displacement per pixel
     * @param after  output warped image
     */
    @Override
    protected void warpImageTaylor(GrayF32 before, GrayF32 flowX, GrayF32 flowY, GrayF32 after) {
        this.interp.setImage(before);
        float maxX = (float)(before.width - 1);
        float maxY = (float)(before.height - 1);
        for (int y = 0; y < before.height; ++y) {
            int pixelIndex = y * before.width;
            for (int x = 0; x < before.width; ++x, ++pixelIndex) {
                float wx = (float)x + flowX.data[pixelIndex];
                float wy = (float)y + flowY.data[pixelIndex];
                boolean outside = wx < 0.0f || wx > maxX || wy < 0.0f || wy > maxY;
                after.data[pixelIndex] = outside ? 0.0f : this.interp.get(wx, wy);
            }
        }
    }

    /**
     * Refines the flow estimate for one pyramid layer.
     *
     * <p>Each of {@code numWarps} passes does two things. First it re-linearizes the constraint around the
     * current flow by warping {@code image2} and its gradient. Then it runs SOR sweeps until the summed squared
     * change is at most {@code convergeTolerance * width * height}, or until {@code maxInnerIterations} sweeps
     * have run. At least one sweep is always performed. Interior pixels use the unchecked neighborhood average
     * {@link #A}. Border pixels use the clamped {@link #A_safe}.
     *
     * @param image1  first image at this layer
     * @param image2  second image at this layer
     * @param derivX2 x gradient of {@code image2}
     * @param derivY2 y gradient of {@code image2}
     */
    protected void processLayer(GrayF32 image1, GrayF32 image2, GrayF32 derivX2, GrayF32 derivY2) {
        final int width = image1.width;
        final int height = image1.height;
        // Loop-invariant convergence threshold; same expression the stopping test always used.
        final float threshold = this.convergeTolerance * (float)width * (float)height;

        for (int warp = 0; warp < this.numWarps; ++warp) {
            // Linearize around the current estimate.
            this.initFlowX.setTo(this.flowX);
            this.initFlowY.setTo(this.flowY);
            this.warpImageTaylor(derivX2, this.initFlowX, this.initFlowY, this.warpDeriv2X);
            this.warpImageTaylor(derivY2, this.initFlowX, this.initFlowY, this.warpDeriv2Y);
            this.warpImageTaylor(image2, this.initFlowX, this.initFlowY, this.warpImage2);

            float error;
            int iter = 0;
            do {
                error = 0.0f;

                // Interior pixels: all eight neighbors exist.
                for (int y = 1; y < height - 1; ++y) {
                    int pixelIndex = y * width + 1;
                    for (int x = 1; x < width - 1; ++x, ++pixelIndex) {
                        error += this.sorUpdate(image1, pixelIndex, A(x, y, this.flowX), A(x, y, this.flowY));
                    }
                }

                // Border pixels: the top and bottom rows...
                int pixelIndex0 = 0;
                int pixelIndex1 = (height - 1) * width;
                for (int x = 0; x < width; ++x) {
                    error += this.iterationSorSafe(image1, x, 0, pixelIndex0++);
                    error += this.iterationSorSafe(image1, x, height - 1, pixelIndex1++);
                }
                // ...then the left and right columns between them.
                pixelIndex0 = width;
                pixelIndex1 = width + width - 1;
                for (int y = 1; y < height - 1; ++y) {
                    error += this.iterationSorSafe(image1, 0, y, pixelIndex0);
                    error += this.iterationSorSafe(image1, width - 1, y, pixelIndex1);
                    pixelIndex0 += width;
                    pixelIndex1 += width;
                }
            } while (error > threshold && ++iter < this.maxInnerIterations);
        }
    }

    /**
     * SOR update for a pixel that may lie on the image border, using the clamped neighborhood average.
     *
     * @return squared change of the flow vector at this pixel
     */
    private float iterationSorSafe(GrayF32 image1, int x, int y, int pixelIndex) {
        return this.sorUpdate(image1, pixelIndex, A_safe(x, y, this.flowX), A_safe(x, y, this.flowY));
    }

    /**
     * Applies one SOR step to the flow at {@code pixelIndex} and writes the new values into
     * {@link #flowX}/{@link #flowY}. The freshly computed x component is used immediately when updating y.
     *
     * @param image1     first image at the current layer
     * @param pixelIndex linear index of the pixel (row-major, stride equal to width)
     * @param AU         neighborhood average of {@link #flowX} around the pixel
     * @param AV         neighborhood average of {@link #flowY} around the pixel
     * @return squared change of the flow vector at this pixel
     */
    private float sorUpdate(GrayF32 image1, int pixelIndex, float AU, float AV) {
        float w = this.SOR_RELAXATION;
        float ui = this.initFlowX.data[pixelIndex];
        float vi = this.initFlowY.data[pixelIndex];
        float u = this.flowX.data[pixelIndex];
        float v = this.flowY.data[pixelIndex];
        float I1 = image1.data[pixelIndex];
        float I2 = this.warpImage2.data[pixelIndex];
        float I2x = this.warpDeriv2X.data[pixelIndex];
        float I2y = this.warpDeriv2Y.data[pixelIndex];
        float uf = (1.0f - w) * u + w * ((I1 - I2 + I2x * ui - I2y * (v - vi)) * I2x + this.alpha2 * AU) / (I2x * I2x + this.alpha2);
        float vf = (1.0f - w) * v + w * ((I1 - I2 + I2y * vi - I2x * (uf - ui)) * I2y + this.alpha2 * AV) / (I2y * I2y + this.alpha2);
        this.flowX.data[pixelIndex] = uf;
        this.flowY.data[pixelIndex] = vf;
        return (uf - u) * (uf - u) + (vf - v) * (vf - v);
    }

    /**
     * Weighted 8-neighbor average of {@code flow} around (x, y). Edge neighbors weigh 1/6 and diagonal
     * neighbors 1/12. Out-of-bounds neighbors are clamped to the nearest valid pixel.
     */
    protected static float A_safe(int x, int y, GrayF32 flow) {
        float u0 = safe(x - 1, y, flow);
        float u1 = safe(x + 1, y, flow);
        float u2 = safe(x, y - 1, flow);
        float u3 = safe(x, y + 1, flow);
        float u4 = safe(x - 1, y - 1, flow);
        float u5 = safe(x + 1, y - 1, flow);
        float u6 = safe(x - 1, y + 1, flow);
        float u7 = safe(x + 1, y + 1, flow);
        return WEIGHT_EDGE * (u0 + u1 + u2 + u3) + WEIGHT_DIAG * (u4 + u5 + u6 + u7);
    }

    /**
     * Weighted 8-neighbor average of {@code flow} around (x, y) with no bounds checking. The caller must
     * ensure that (x, y) is not on the image border.
     */
    protected static float A(int x, int y, GrayF32 flow) {
        int index = flow.getIndex(x, y);
        int stride = flow.stride;
        float u0 = flow.data[index - 1];
        float u1 = flow.data[index + 1];
        float u2 = flow.data[index - stride];
        float u3 = flow.data[index + stride];
        float u4 = flow.data[index - 1 - stride];
        float u5 = flow.data[index + 1 - stride];
        float u6 = flow.data[index - 1 + stride];
        float u7 = flow.data[index + 1 + stride];
        return WEIGHT_EDGE * (u0 + u1 + u2 + u3) + WEIGHT_DIAG * (u4 + u5 + u6 + u7);
    }

    /** Returns the pixel of {@code image} at (x, y), with coordinates clamped into the image bounds. */
    protected static float safe(int x, int y, GrayF32 image) {
        if (x < 0) {
            x = 0;
        } else if (x >= image.width) {
            x = image.width - 1;
        }
        if (y < 0) {
            y = 0;
        } else if (y >= image.height) {
            y = image.height - 1;
        }
        return image.unsafe_get(x, y);
    }

    /** @return x component of the most recently computed flow (internal storage, not a copy) */
    public GrayF32 getFlowX() {
        return this.flowX;
    }

    /** @return y component of the most recently computed flow (internal storage, not a copy) */
    public GrayF32 getFlowY() {
        return this.flowY;
    }
}

