package deepboof.impl.backward.standard;

import deepboof.DeepBoofConstants;
import deepboof.backward.DBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/* loaded from: classes3.dex */
/**
 * Shared state and bookkeeping for {@link DBatchNorm} implementations that operate on
 * {@link Tensor_F64}.
 *
 * <p>This base class allocates the per-variable work tensors, declares the optional parameter
 * tensor, and exposes the stored mean and the variance derived from the stored standard
 * deviation. Subclasses decide which part of the input shape forms the normalized variables by
 * implementing {@link #createShapeVariables(int[])}.
 */
public abstract class BaseDBatchNorm_F64 extends BaseDFunction<Tensor_F64> implements DBatchNorm<Tensor_F64> {
    /** Number of elements in a tensor of shape {@link #shapeVariables}; set by {@link #_initialize()}. */
    protected int D;
    /** True when this function expects a parameter tensor (gamma and beta) to be supplied. */
    protected boolean requiresGammaBeta;
    /** Shape returned by {@link #createShapeVariables(int[])} for the current input shape. */
    protected int[] shapeVariables;
    /** Stored mean; copied out by {@link #getMean(Tensor_F64)}. */
    protected Tensor_F64 tensorMean = new Tensor_F64();
    /**
     * Stored standard deviation. {@link #getVariance(Tensor_F64)} recovers the variance as
     * {@code std*std - EPS}, so this presumably holds {@code sqrt(variance + EPS)} -- confirm
     * against the subclass that fills it in.
     */
    protected Tensor_F64 tensorStd = new Tensor_F64();
    /** Work tensor for subclasses; not resized by this class. */
    protected Tensor_F64 tensorXhat = new Tensor_F64();
    /** Work tensor for subclasses; resized to {@link #shapeVariables} by {@link #_initialize()}. */
    protected Tensor_F64 tensorDVar = new Tensor_F64();
    /** Work tensor for subclasses; resized to {@link #shapeVariables} by {@link #_initialize()}. */
    protected Tensor_F64 tensorDMean = new Tensor_F64();
    /** Work tensor for subclasses; not resized by this class. */
    protected Tensor_F64 tensorDXhat = new Tensor_F64();
    /** Work tensor for subclasses; not resized by this class. */
    protected Tensor_F64 tensorDiffX = new Tensor_F64();
    /** Scratch tensor; resized to {@link #shapeVariables} by {@link #_initialize()}. */
    protected Tensor_F64 tensorTmp = new Tensor_F64();
    /** Local copy of the parameter tensor handed to {@link #_setParameters(List)}. */
    protected Tensor_F64 params = new Tensor_F64(0);
    /** Value subtracted from {@code std*std} when reporting the variance; see {@link #setEPS(double)}. */
    protected double EPS = DeepBoofConstants.TEST_TOL_F64 * 0.1d;

    /**
     * @param z true if gamma and beta parameters are required, false if the function has no
     *          parameters
     */
    public BaseDBatchNorm_F64(boolean z) {
        this.requiresGammaBeta = z;
    }

    /**
     * Derives the variable shape from the input shape, resizes the per-variable tensors, sets the
     * output shape to a copy of the input shape and, when gamma/beta are required, declares the
     * parameter shape.
     */
    @Override // deepboof.impl.forward.standard.BaseFunction
    public void _initialize() {
        this.shapeVariables = createShapeVariables(this.shapeInput);

        this.tensorMean.reshape(this.shapeVariables);
        this.tensorStd.reshape(this.shapeVariables);
        this.tensorDVar.reshape(this.shapeVariables);
        this.tensorDMean.reshape(this.shapeVariables);
        this.tensorTmp.reshape(this.shapeVariables);

        // Output has the same shape as the input; clone so the two arrays are independent.
        this.shapeOutput = this.shapeInput.clone();

        if (this.requiresGammaBeta) {
            // NOTE(review): WI(shape, 2) presumably appends a trailing dimension of length 2
            // (gamma, beta) to the variable shape -- confirm against TensorOps.
            int[] shapeParams = TensorOps.WI(this.shapeVariables, 2);
            this.shapeParameters.add(shapeParams);
            this.params.reshape(shapeParams);
        }

        this.D = TensorOps.tensorLength(this.shapeVariables);
    }

    /**
     * Copies the first supplied tensor into {@link #params} when gamma/beta are required.
     *
     * @param list parameter tensors; only element 0 is read when gamma/beta are required, and the
     *             list must be empty otherwise
     * @throws IllegalArgumentException if gamma/beta are turned off and the list is not empty
     */
    @Override // deepboof.impl.forward.standard.BaseFunction
    public void _setParameters(List<Tensor_F64> list) {
        if (this.requiresGammaBeta) {
            this.params.setTo(list.get(0));
            return;
        }
        if (list.size() != 0) {
            throw new IllegalArgumentException("There are no parameters since gamma and beta have been turned off");
        }
    }

    /**
     * Computes the shape of the normalized variables for the given input shape.
     *
     * @param iArr shape of the input tensor
     * @return shape used for the mean, standard deviation and related work tensors
     */
    protected abstract int[] createShapeVariables(int[] iArr);

    /** @return the current value of {@link #EPS} */
    @Override // deepboof.forward.BatchNorm
    public double getEPS() {
        return this.EPS;
    }

    /**
     * Copies the stored mean into the provided tensor.
     *
     * @param tensor_F64 storage for the result, or null to have a new tensor created
     * @return the tensor holding the copy of the mean
     */
    @Override // deepboof.backward.DBatchNorm
    public Tensor_F64 getMean(Tensor_F64 tensor_F64) {
        Tensor_F64 output = tensor_F64;
        if (output == null) {
            output = this.tensorMean.createLike();
        }
        output.setTo(this.tensorMean);
        return output;
    }

    /** @return {@code Tensor_F64.class} */
    @Override // deepboof.Function
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }

    /**
     * Computes the variance from the stored standard deviation as {@code std*std - EPS}.
     *
     * @param tensor_F64 storage for the result, or null to have a new tensor created; it is
     *                   reshaped to match the stored standard deviation
     * @return the tensor holding the variance
     */
    @Override // deepboof.backward.DBatchNorm
    public Tensor_F64 getVariance(Tensor_F64 tensor_F64) {
        Tensor_F64 output = tensor_F64;
        if (output == null) {
            output = this.tensorStd.createLike();
        }
        output.reshape(this.tensorStd.getShape());

        final int length = this.tensorStd.length();
        // Honor the start offset of both tensors so the element mapping stays correct even if
        // either one is a view into a larger data array.
        int indexIn = this.tensorStd.startIndex;
        int indexOut = output.startIndex;
        for (int i = 0; i < length; i++) {
            double std = this.tensorStd.d[indexIn++];
            output.d[indexOut++] = (std * std) - this.EPS;
        }
        return output;
    }

    /** @return true if this function was configured to require gamma and beta parameters */
    @Override // deepboof.forward.BatchNorm
    public boolean hasGammaBeta() {
        return this.requiresGammaBeta;
    }

    /**
     * Replaces the value of {@link #EPS}.
     *
     * @param d new EPS value
     */
    @Override // deepboof.forward.BatchNorm
    public void setEPS(double d) {
        this.EPS = d;
    }
}
