diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index d1e731abd6e..82650997316 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -18,7 +18,7 @@ class ExportFxPassTest(unittest.TestCase): def test_decompose_dynamic_shape_select(self): args = (torch.rand((10, 197, 768)), 1, 0) - dynamic_shapes = ([{0: Dim("bs")}, None, None],) + dynamic_shapes = (({0: Dim("bs")}, None, None),) m = wrap_func_as_nn_module(torch.ops.aten.select.int) ep = export(m, args, dynamic_shapes=dynamic_shapes) out1 = ep.module()(*args) @@ -55,7 +55,7 @@ def forward(self, x): def test_embedding_indices_flatten(self): args = (torch.rand((20, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64)) - dynamic_shapes = ([None, {0: Dim("bs")}],) + dynamic_shapes = ((None, {0: Dim("bs")}),) m = wrap_func_as_nn_module(torch.ops.aten.embedding.default) ep = export(m, args, dynamic_shapes=dynamic_shapes) print(ep)