Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing aten op #7078

Merged
merged 9 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions experimental/torch_xla2/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,117 @@ def test_aten_native_layer_norm_0(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_layer_norm, args, kwargs)

def test_aten_native_batch_norm_legit(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,2,2)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,4)).to(torch.float32),
None,
None,
torch.ones(channel),
torch.zeros(channel),
False,
0.5,
1,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit, args, kwargs)

def test_aten_native_batch_norm_legit_no_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten._native_batch_norm_legit_no_training, args, kwargs)

def test_aten_native_batch_norm_training(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_training_none(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
None,
None,
torch.zeros(channel),
torch.ones(channel),
True,
0.1,
1e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_native_batch_norm_eval(self):
batch = 3
channel = 2
args = (
torch.randn((batch,channel,4,3)).to(torch.float32),
torch.ones(channel),
torch.zeros(channel),
torch.zeros(channel),
torch.ones(channel),
False,
0.2,
2e-5,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.native_batch_norm, args, kwargs)

def test_aten_ne_Scalar_0(self):
args = (
torch.randint(0, 10, (10, 10)).to(torch.int32),
Expand Down
2 changes: 0 additions & 2 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"__getitem__",
"__rmatmul__",
"__rpow__",
"_native_batch_norm_legit",
"_segment_reduce",
"_upsample_bilinear2d_aa",
"argsort",
Expand Down Expand Up @@ -198,7 +197,6 @@
"nansum",
"narrow_copy",
"narrow",
"native_batch_norm",
"native_layer_norm",
"new_empty",
"new_empty_strided",
Expand Down
91 changes: 68 additions & 23 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
torch.ops.aten.eq_: torch.ops.aten.eq,
torch.ops.aten.ne_: torch.ops.aten.ne,
torch.ops.aten.uniform_: torch.ops.aten.uniform,
torch.ops.aten.relu_: torch.ops.aten.relu,
}


Expand Down Expand Up @@ -545,35 +546,67 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
def _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
)
"""JAX implementation of batch normalization with optional parameters.
Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713.

Args:
input (DeviceArray): Input data (N, C, H, W).
running_mean ([DeviceArray]): Running mean of input (C,).
running_var ([DeviceArray]): Running variance of input (C,).
weight (Optional[DeviceArray]): Scaling factor (gamma) (C,). Can be None.
bias (Optional[DeviceArray]): Shift factor (beta) (C,). Can be None.
training (bool): If True, use batch statistics for normalization.
If False, use running statistics.
momentum (float): Momentum factor for updating running statistics.
eps (float): Small constant for numerical stability.

Returns:
DeviceArray: Normalized output
DeviceArray: Batch mean (C,) or empty if training is False
DeviceArray: Reversed batch variance (C,) or empty if training is False
"""
reduction_dims = [0] + list(range(2, input.ndim))
reshape_dims = [1, -1] + [1]*(input.ndim-2)

if training:
# Calculate batch mean and variance
mean = jnp.mean(input, axis=reduction_dims, keepdims=True)
saved_mean = jnp.squeeze(mean, reduction_dims)
var = jnp.var(input, axis=reduction_dims)
rstd = jax.lax.rsqrt(var.reshape(reshape_dims) + eps)
# Update running statistics using momentum
running_mean = (1 - momentum) * running_mean + momentum * saved_mean
running_var = (1 - momentum) * running_var + momentum * var
saved_rstd = jnp.squeeze(rstd, reduction_dims)
else:
rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps)
saved_mean = jnp.array([]) # No need to calculate batch statistics in inference mode
saved_rstd = jnp.array([])

# Normalize
if training:
# use batch statistics if training
x_hat = (input - mean) * rstd
else:
# Use running statistics in inference mode
x_hat = (input - running_mean.reshape(reshape_dims)) * rstd

# Scale and shift
if weight is not None:
x_hat *= weight.reshape(reshape_dims) # Reshape weight for broadcasting
if bias is not None:
x_hat += bias.reshape(reshape_dims) # Reshape bias for broadcasting

return x_hat, saved_mean, saved_rstd



@op(torch.ops.aten._native_batch_norm_legit_no_training)
def _aten__native_batch_norm_legit_no_training(
input, weight, bias, running_mean, running_var, momentum, eps
):
if weight is None:
weight = jnp.ones_like(running_mean)
if bias is None:
bias = jnp.zeros_like(running_mean)

def broadcast(t):
return jax.lax.broadcast_in_dim(t, input.shape, broadcast_dimensions=(1,))

if running_mean is not None:
a = input - broadcast(running_mean)
else:
a = input
if running_var is not None:
b = broadcast(jnp.sqrt(running_var + eps))
else:
b = broadcast(jnp.sqrt(eps))
return (
a / b * broadcast(weight) + broadcast(bias),
jnp.array([]),
jnp.array([]),
return _aten__native_batch_norm_legit(
input, weight, bias, running_mean, running_var, False, momentum, eps
)


Expand Down Expand Up @@ -1950,3 +1983,15 @@ def _aten_outer(a, b):
def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False):
return jnp.allclose(input, other, rtol, atol, equal_nan)

@op(torch.ops.aten.native_batch_norm)
def _aten_native_batch_norm(input, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=1e-5):

if running_mean is None:
running_mean = jnp.zeros(input.shape[1]) # Initialize running mean if None
if running_var is None:
running_var = jnp.ones(input.shape[1]) # Initialize running variance if None

if training:
return torch.ops.aten._native_batch_norm_legit(input, weight, bias, running_mean, running_var, training, momentum, eps)
else:
return torch.ops.aten._native_batch_norm_legit_no_training(input, weight, bias, running_mean, running_var, momentum, eps)
Loading