Commit

fix core aten ops
zpcore committed May 22, 2024
1 parent e60a646 commit 7d0495e
Showing 3 changed files with 177 additions and 64 deletions.
114 changes: 114 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -41,6 +41,8 @@ def run_export_and_compare(testcase,
with testcase.env:
res2 = func(*args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
print(res)
print(res2)
# import pdb; pdb.set_trace()
with testcase.subTest("torch_xla2_diff:" + str(atol)):
if ignore_indices and isinstance(res, tuple) and len(res) == 2:
@@ -2697,6 +2699,118 @@ def test_aten_native_layer_norm_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

def test_aten_native_batch_norm_legit(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,2,2)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,4)).to(torch.float32),
None,
None,
torch.ones(channel),
torch.zeros(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_no_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)

def test_aten_native_batch_norm_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_eval(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_ne_Scalar_0(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
1 change: 0 additions & 1 deletion experimental/torch_xla2/test/test_ops.py
@@ -13,7 +13,6 @@
"__getitem__",
"__rmatmul__",
"__rpow__",
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
"argsort",
126 changes: 63 additions & 63 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -546,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
def _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
)
"""JAX implementation of batch normalization with optional parameters.
Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.
Args:
input (DeviceArray): Input data (N, C, H, W).
running_mean ([DeviceArray]): Running mean of input (C,).
running_var ([DeviceArray]): Running variance of input (C,).
weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None.
bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None.
training (bool): If True, use batch statistics for normalization.
If False, use running statistics.
momentum (float): Momentum factor for updating running statistics.
eps (float): Small constant for numerical stability.
Returns:
DeviceArray: Normalized output
DeviceArray: Batch mean (C,) or empty if training is False
DeviceArray: Reversed batch variance (C,) or empty if training is False
"""
reduction_dims = [0] + list(range(2, input.ndim))
reshape_dims = [1, -1] + [1]*(input.ndim-2)

if training:
# Calculate batch mean and variance
mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
saved_mean = jnp.squeeze(mean, reduction_dims)
var = jnp.var(input, axis=reduction_dims)
rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
# Update running statistics using momentum
running_mean = (1 - momentum) * running_mean + momentum * saved_mean
running_var = (1 - momentum) * running_var + momentum * var
saved_rstd = jnp.squeeze(rstd, reduction_dims)
else:
rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
saved_mean = jnp.array([]) # No need to calculate batch statistics in inference mode
saved_rstd = jnp.array([])

# Normalize
if training:
# use batch statistics if training
x_hat = (input - mean) * rstd
else:
# Use running statistics in inference mode
x_hat = (input - running_mean.reshape(reshape_dims)) * rstd

# Scale and shift
if weight is not None:
x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
if bias is not None:
x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting

return x_hat, saved_mean, saved_rstd
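
As a quick standalone sanity check of the training-branch math above (an illustrative sketch, not part of this commit; shapes and tolerances are arbitrary assumptions), the normalization can be exercised directly in JAX:

import jax
import jax.numpy as jnp

# Toy (N, C, H, W) input, for illustration only.
x = jnp.arange(24, dtype=jnp.float32).reshape(2, 3, 2, 2)
eps = 1e-5
reduction_dims = [0] + list(range(2, x.ndim))  # reduce over N and spatial dims
reshape_dims = [1, -1] + [1] * (x.ndim - 2)    # (1, C, 1, 1) for broadcasting

mean = jnp.mean(x, axis=reduction_dims, keepdims=True)
rstd = jax.lax.rsqrt(jnp.var(x, axis=reduction_dims).reshape(reshape_dims) + eps)
x_hat = (x - mean) * rstd

# After normalization each channel should have mean ~0 and variance ~1.
assert jnp.allclose(jnp.mean(x_hat, axis=reduction_dims), 0.0, atol=1e-5)
assert jnp.allclose(jnp.var(x_hat, axis=reduction_dims), 1.0, atol=1e-3)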



@op(torch.ops.aten._native_batch_norm_legit_no_training)
def _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
):
if weight is None:
weight = jnp.ones_like(running_mean)
if bias is None:
bias = jnp.zeros_like(running_mean)

def broadcast(t):
return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,))

if running_mean is not None:
a = input - broadcast(running_mean)
else:
a = input
if running_var is not None:
b = broadcast(jnp.sqrt(running_var + eps))
else:
b = broadcast(jnp.sqrt(eps))
return (
a / b * broadcast(weight) + broadcast(bias),
jnp.array([]),
jnp.array([]),
return _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, False, momentum, eps
)


@@ -1953,45 +1985,13 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):

@op(torch.ops.aten.native_batch_norm)
def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):
"""JAX implementation of batch normalization.
Args:
input: Input data (N, C, ...)
weight: Scaling factor (gamma) (C,), can be None
bias: Shift factor (beta) (C,), can be None
running_mean: Running mean of input (C,)
running_var: Running variance of input (C,)
training: Whether to perform training-time or inference-time batch norm
momentum: Momentum factor for updating running mean and variance
eps: Small constant added to the variance to avoid division by zero
Returns:
Output data, updated running mean, updated running var
"""


if running_mean is None:
running_mean = jnp.zeros(input.shape[1]) # Initialize running mean if None
if running_var is None:
running_var = jnp.ones(input.shape[1]) # Initialize running variance if None

if training:
# Training-time batch norm: compute statistics across the batch
mean = jnp.mean(input, axis=(0, 2, 3))
var = jnp.var(input, axis=(0, 2, 3))

# Update running statistics
running_mean = momentum * mean + (1 - momentum) * running_mean
running_var = momentum * var + (1 - momentum) * running_var

return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
else:
# Inference-time batch norm: use pre-computed running statistics
mean = running_mean
var = running_var

# Normalize
xmu = input - mean.reshape(1, -1, 1, 1) # Broadcast mean across batch
ivar = jax.lax.rsqrt(var + eps).reshape(1, -1, 1, 1) # Reciprocal of square root

# Scale and shift
out = xmu * ivar
if weight is not None:
out *= weight.reshape(1, -1, 1, 1)
if bias is not None:
out += bias.reshape(1, -1, 1, 1)

return out, running_mean, running_var
return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
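
For reference, the relationship between the three batch-norm ops wired together in this hunk can be checked in eager PyTorch (an illustrative sketch, not part of this commit; shapes, momentum, and eps are arbitrary assumptions):

import torch

x = torch.randn(3, 2, 4, 4)
w, b = torch.ones(2), torch.zeros(2)
rm, rv = torch.zeros(2), torch.ones(2)

# In eval mode all three variants normalize with the running statistics,
# so their outputs should agree.
out1, _, _ = torch.ops.aten.native_batch_norm(x, w, b, rm, rv, False, 0.1, 1e-5)
out2, _, _ = torch.ops.aten._native_batch_norm_legit(x, w, b, rm, rv, False, 0.1, 1e-5)
out3, _, _ = torch.ops.aten._native_batch_norm_legit_no_training(x, w, b, rm, rv, 0.1, 1e-5)
torch.testing.assert_close(out1, out2)
torch.testing.assert_close(out1, out3)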
