@op(torch.ops.aten.native_batch_norm)
def _aten_native_batch_norm(input, weight, bias, running_mean, running_var,
                            training=False, momentum=0.1, eps=1e-5):
  """JAX implementation of batch normalization (aten.native_batch_norm).

  Normalizes over the batch dimension and all trailing spatial dimensions,
  keeping per-channel statistics (channel axis is 1), so inputs of any rank
  (N, C, ...) are supported — not just 4-D NCHW.

  Args:
    input: Input data of shape (N, C, ...).
    weight: Optional per-channel scale (gamma), shape (C,).
    bias: Optional per-channel shift (beta), shape (C,).
    running_mean: Optional running mean of input, shape (C,). May be None.
    running_var: Optional running variance of input, shape (C,). May be None.
    training: If True, compute batch statistics and update the running
      statistics; if False, normalize with the running statistics.
    momentum: Momentum factor for the running-statistics update.
    eps: Small constant added to the variance for numerical stability.

  Returns:
    Tuple of (normalized output, updated running mean, updated running var).
  """
  # Reduce over batch (axis 0) and every spatial axis (2..ndim-1);
  # statistics are kept per channel (axis 1).
  reduce_axes = (0,) + tuple(range(2, input.ndim))
  # Shape that broadcasts (C,) statistics across (N, C, ...) input.
  stat_shape = (1, -1) + (1,) * (input.ndim - 2)

  if training:
    # Training: normalize with the statistics of this batch.
    mean = jnp.mean(input, axis=reduce_axes)
    var = jnp.var(input, axis=reduce_axes)
    # ATen allows running stats to be None even in training mode; only
    # update them when the caller tracks them.
    if running_mean is not None:
      running_mean = momentum * mean + (1 - momentum) * running_mean
    if running_var is not None:
      running_var = momentum * var + (1 - momentum) * running_var
  else:
    # Inference: use the pre-computed running statistics.
    mean = running_mean
    var = running_var

  xmu = input - mean.reshape(stat_shape)
  # rsqrt fuses the 1/sqrt(var + eps) reciprocal into one primitive.
  ivar = lax.rsqrt(var + eps).reshape(stat_shape)

  # Scale and shift: gamma * (x - mean) / sqrt(var + eps) + beta.
  out = xmu * ivar
  if weight is not None:
    out *= weight.reshape(stat_shape)
  if bias is not None:
    out += bias.reshape(stat_shape)

  return out, running_mean, running_var