diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 778d77591e4..0d65e2ce430 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -207,6 +207,29 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase): test_training_linear = make_training_test(LinearModule) test_training_maxpool = make_training_test(MaxPoolModule) + def test_non_tensor_args_for_partition(self): + class Emb(torch.nn.Embedding): + def __init__(self): + super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) + + class Main(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = Emb() + + def forward(self, x): + return self.embedding(x) + + device = xm.xla_device() + module = Main() + module.to(device) + + @torch.compile(backend="openxla_eval") + def foo(x): + return module(x) + + x = torch.randint(0, 10, (10,), device=device) + foo(x) if __name__ == "__main__": from torch._dynamo.test_case import run_tests