diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e0acf954d88..6dd94c8e42d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1985,7 +1985,7 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai # Normalize xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch - ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root + ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root # Scale and shift out = xmu * ivar