From cef02c13f637c543fa2cb60807f17b1ee0317b06 Mon Sep 17 00:00:00 2001 From: zpcore Date: Fri, 17 May 2024 22:25:10 +0000 Subject: [PATCH] fix lax dependency --- experimental/torch_xla2/torch_xla2/ops/jaten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index e0acf954d88..6dd94c8e42d 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -1985,7 +1985,7 @@ def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, trai # Normalize xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch - ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root + ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root # Scale and shift out = xmu * ivar