From 6af550aa69a24f5361c881f8d8d0792fe5a8bf09 Mon Sep 17 00:00:00 2001
From: David Huang
Date: Mon, 23 Sep 2024 09:43:53 -0700
Subject: [PATCH] [torch_xla2] Fix nn.functional.conv2d and conv3d (#8048)

---
 experimental/torch_xla2/test/test_ops.py        | 2 --
 experimental/torch_xla2/torch_xla2/ops/jaten.py | 7 +++++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py
index 0e30d181c0d..8ddccccdcaf 100644
--- a/experimental/torch_xla2/test/test_ops.py
+++ b/experimental/torch_xla2/test/test_ops.py
@@ -93,8 +93,6 @@
     "nn.functional.avg_pool2d",
     "nn.functional.avg_pool3d",
     "nn.functional.bilinear",
-    "nn.functional.conv2d",
-    "nn.functional.conv3d",
     "nn.functional.conv_transpose1d",
     "nn.functional.conv_transpose2d",
     "nn.functional.conv_transpose3d",
diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py
index 291587329bb..54adbd30e65 100644
--- a/experimental/torch_xla2/torch_xla2/ops/jaten.py
+++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -799,7 +799,10 @@ def _aten_convolution(
   if transposed:
     raise NotImplementedError("Transposed convolution is not implemented.")
 
-  def make_padding(padding):
+  def make_padding(padding, num_spatial_dims):
+    # Expand single padding to pairs expected by jax
+    if len(padding) == 1 and len(padding) < num_spatial_dims:
+      padding *= num_spatial_dims
     return ((p, p) for p in padding)
 
   def create_default_conv_dimension_numbers(num_spatial_dims):
@@ -822,7 +825,7 @@ def create_default_conv_dimension_numbers(num_spatial_dims):
       input,
       weight,
       stride,
-      make_padding(padding),
+      make_padding(padding, len(stride)),
       lhs_dilation=(1,) * len(stride),
       rhs_dilation=dilation,
       dimension_numbers=create_default_conv_dimension_numbers(len(stride)),
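
For context, below is a minimal standalone sketch of the padding normalization this patch introduces. It is not the PR's code verbatim: the local make_padding variant returns a tuple instead of a generator for clarity, and the sample shapes and the ("NCHW", "OIHW", "NCHW") dimension numbers are illustrative assumptions chosen to mirror a typical conv2d call.

import jax
import jax.numpy as jnp


def make_padding(padding, num_spatial_dims):
  # Expand a single padding value to one (low, high) pair per spatial
  # dimension, which is the form jax.lax.conv_general_dilated expects.
  if len(padding) == 1 and len(padding) < num_spatial_dims:
    padding = list(padding) * num_spatial_dims
  return tuple((p, p) for p in padding)


# torch.nn.functional.conv2d(x, w, padding=1) reaches _aten_convolution with
# padding=[1]; for a 2-D conv (len(stride) == 2) that must become
# ((1, 1), (1, 1)) rather than a single pair.
x = jnp.ones((1, 3, 8, 8))   # NCHW input
w = jnp.ones((4, 3, 3, 3))   # OIHW weight
out = jax.lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding=make_padding([1], 2),
    dimension_numbers=("NCHW", "OIHW", "NCHW"),
)
print(out.shape)  # (1, 4, 8, 8): same spatial size thanks to the expanded padding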