Skip to content

Commit

Permalink
Add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Oct 11, 2023
1 parent f67a922 commit c4c701b
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,29 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
test_training_linear = make_training_test(LinearModule)
test_training_maxpool = make_training_test(MaxPoolModule)

def test_non_tensor_args_for_partition(self):
class Emb(torch.nn.Embedding):
def __init__(self):
super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0)

class Main(torch.nn.Module):
def __init__(self):
super().__init__()
self.embedding = Emb()

def forward(self, x):
return self.embedding(x)

device = xm.xla_device()
module = Main()
module.to(device)

@torch.compile(backend="openxla_eval")
def foo(x):
return module(x)

x = torch.randint(0, 10, (10,), device=device)
foo(x)

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down

0 comments on commit c4c701b

Please sign in to comment.