
add missing aten op
zpcore committed May 17, 2024
1 parent aeed89e commit 74c47e8
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -30,6 +30,7 @@
torch.ops.aten.eq_: torch.ops.aten.eq,
torch.ops.aten.ne_: torch.ops.aten.ne,
torch.ops.aten.uniform_: torch.ops.aten.uniform,
torch.ops.aten.relu_: torch.ops.aten.relu,
}
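Side note (not part of the diff): the table above maps mutating aten ops to their functional counterparts, so the new relu_ entry lets an in-place ReLU reuse the existing relu lowering. A minimal, hypothetical sketch of how such a table could be consumed follows; the helper name and the copy-back step are assumptions for illustration, not jaten's actual dispatch code:

def _run_inplace(op, table, tensor, *args, **kwargs):
  # Hypothetical helper: resolve e.g. aten.relu_ to aten.relu via the table,
  # compute the functional result, then copy it back to emulate mutation.
  functional_op = table.get(op, op)
  result = functional_op(tensor, *args, **kwargs)
  tensor.copy_(result)
  return tensor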


@@ -1948,3 +1949,47 @@ def _aten_outer(a, b):
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
return jnp.allclose(input, other, rtol, atol, equal_nan)

@op(torch.ops.aten.native_batch_norm)
def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
  """JAX implementation of batch normalization.

  Args:
    input: Input data (N, C, ...)
    weight: Optional scaling factor (gamma) (C,)
    bias: Optional shift factor (beta) (C,)
    running_mean: Running mean of input (C,)
    running_var: Running variance of input (C,)
    training: Whether to perform training-time or inference-time batch norm
    momentum: Momentum factor for updating the running mean and variance
    eps: Small constant added to the variance to avoid division by zero

  Returns:
    Output data, updated running mean, updated running variance
  """

  if training:
    # Training-time batch norm: compute per-channel statistics across the batch.
    # NOTE: the reduction axes (0, 2, 3) assume a 4D (N, C, H, W) input.
    mean = jnp.mean(input, axis=(0, 2, 3))
    var = jnp.var(input, axis=(0, 2, 3))

    # Update running statistics
    running_mean = momentum * mean + (1 - momentum) * running_mean
    running_var = momentum * var + (1 - momentum) * running_var

  else:
    # Inference-time batch norm: use pre-computed running statistics
    mean = running_mean
    var = running_var

  # Normalize
  xmu = input - mean.reshape(1, -1, 1, 1)  # Broadcast mean across the batch
  ivar = lax.rsqrt(var + eps).reshape(1, -1, 1, 1)  # Reciprocal of square root

  # Scale and shift
  out = xmu * ivar
  if weight is not None:
    out *= weight.reshape(1, -1, 1, 1)
  if bias is not None:
    out += bias.reshape(1, -1, 1, 1)

  return out, running_mean, running_var
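As a quick sanity check (a sketch, not part of the commit; it assumes jnp and lax are already imported in jaten.py, as the function body requires, and a 4D NCHW input):

x = jnp.ones((2, 3, 4, 4))      # batch of 2, 3 channels, 4x4 spatial
weight = jnp.ones((3,))         # gamma, one entry per channel
bias = jnp.zeros((3,))          # beta
running_mean = jnp.zeros((3,))
running_var = jnp.ones((3,))

out, new_mean, new_var = _aten_native_batch_norm(
    x, weight, bias, running_mean, running_var, training=True)
# out keeps the (2, 3, 4, 4) shape; new_mean and new_var are the updated (3,) running stats.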
