Skip to content

Commit

Permalink
Update test_export_fx_passes.py
Browse files Browse the repository at this point in the history
pytorch/pytorch#124898 makes some changes to how `args` and `dynamic_shapes` are matched in `torch.export`, which will fail some XLA tests. This PR avoids those failures.
  • Loading branch information
avikchaudhuri authored Apr 25, 2024
1 parent abe090a commit 2fdf551
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/stablehlo/test_export_fx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ExportFxPassTest(unittest.TestCase):

def test_decompose_dynamic_shape_select(self):
args = (torch.rand((10, 197, 768)), 1, 0)
dynamic_shapes = ([{0: Dim("bs")}, None, None],)
dynamic_shapes = (({0: Dim("bs")}, None, None),)
m = wrap_func_as_nn_module(torch.ops.aten.select.int)
ep = export(m, args, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(*args)
Expand Down Expand Up @@ -55,7 +55,7 @@ def forward(self, x):
def test_embedding_indices_flatten(self):
args = (torch.rand((20, 768)), torch.randint(0, 15,
(3, 10)).to(torch.int64))
dynamic_shapes = ([None, {0: Dim("bs")}],)
dynamic_shapes = ((None, {0: Dim("bs")}),)
m = wrap_func_as_nn_module(torch.ops.aten.embedding.default)
ep = export(m, args, dynamic_shapes=dynamic_shapes)
print(ep)
Expand Down

0 comments on commit 2fdf551

Please sign in to comment.