diff --git a/test/dynamo/test_bridge.py b/test/dynamo/test_bridge.py index 0d65e2ce430..3bc9772dfd2 100644 --- a/test/dynamo/test_bridge.py +++ b/test/dynamo/test_bridge.py @@ -208,11 +208,14 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase): test_training_maxpool = make_training_test(MaxPoolModule) def test_non_tensor_args_for_partition(self): + class Emb(torch.nn.Embedding): + def __init__(self): super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0) class Main(torch.nn.Module): + def __init__(self): super().__init__() self.embedding = Emb() @@ -231,6 +234,7 @@ def foo(x): x = torch.randint(0, 10, (10,), device=device) foo(x) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index 2eacfc684f2..f29547d7a32 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -230,13 +230,12 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): for xla_arg in xla_args ] - xla_tensor_args = [ - (i, xla_arg) for i, xla_arg in enumerate(xla_args) if isinstance(xla_arg, torch.Tensor) - ] + xla_tensor_args = [(i, xla_arg) + for i, xla_arg in enumerate(xla_args) + if isinstance(xla_arg, torch.Tensor)] - args_tensor_ids = [ - (index, torch_xla._XLAC._xla_get_tensor_id(xla_arg)) for index, xla_arg in xla_tensor_args - ] + args_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg)) + for index, xla_arg in xla_tensor_args] if dynamo_debug: print(f"Graph module:\n{xla_model.code}") @@ -262,8 +261,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule): # If a arg is being in place updated by model, we need to include arg as part of the graph result. xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( - [tensor for _, tensor in xla_tensor_args] - ) + [tensor for _, tensor in xla_tensor_args]) xla_args_need_update = [] arg_index_to_need_update_index = {} for i, need_update in enumerate(xla_args_need_update_bool): @@ -354,7 +352,9 @@ def optimized_mod(*args): # mark_step needs to be blocking since we want to access args's XLADatas # and they can't be placeholder. - if any(torch_xla._XLAC._check_tensor_need_materialization([a for a in args if isinstance(a, torch.Tensor)])): + if any( + torch_xla._XLAC._check_tensor_need_materialization( + [a for a in args if isinstance(a, torch.Tensor)])): xm.mark_step(wait=True) # If input sharding has changed from the previous program, dynamo current can