Fix lint issues.
ysiraichi committed Oct 11, 2023
1 parent c4c701b commit 8b00209
Showing 2 changed files with 13 additions and 9 deletions.
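
Note: the diff below only reformats existing code; behavior is unchanged. As a rough illustration (not part of the commit), the snippet below sketches the conventions the lint fix enforces, assuming the repository's yapf-style Python formatting with two-space indentation; the names Wrapper and collect_tensor_args are made up for this example.

import torch


class Wrapper(torch.nn.Module):

  def __init__(self):
    # Blank line after the class statement and two-space indentation,
    # mirroring the blank lines added in test_bridge.py.
    super().__init__()


def collect_tensor_args(args):
  # Long comprehensions are wrapped with continuation lines aligned under the
  # opening bracket instead of exceeding the column limit, mirroring the
  # reformatted xla_tensor_args/args_tensor_ids in dynamo_bridge.py.
  tensor_args = [(i, a)
                 for i, a in enumerate(args)
                 if isinstance(a, torch.Tensor)]
  return tensor_args
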
4 changes: 4 additions & 0 deletions test/dynamo/test_bridge.py
@@ -208,11 +208,14 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
   test_training_maxpool = make_training_test(MaxPoolModule)
 
   def test_non_tensor_args_for_partition(self):
+
     class Emb(torch.nn.Embedding):
+
       def __init__(self):
         super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0)
 
     class Main(torch.nn.Module):
+
       def __init__(self):
         super().__init__()
         self.embedding = Emb()
@@ -231,6 +234,7 @@ def foo(x):
     x = torch.randint(0, 10, (10,), device=device)
     foo(x)
 
+
 if __name__ == "__main__":
   from torch._dynamo.test_case import run_tests
 
18 changes: 9 additions & 9 deletions torch_xla/core/dynamo_bridge.py
@@ -230,13 +230,12 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
       for xla_arg in xla_args
   ]
 
-  xla_tensor_args = [
-      (i, xla_arg) for i, xla_arg in enumerate(xla_args) if isinstance(xla_arg, torch.Tensor)
-  ]
+  xla_tensor_args = [(i, xla_arg)
+                     for i, xla_arg in enumerate(xla_args)
+                     if isinstance(xla_arg, torch.Tensor)]
 
-  args_tensor_ids = [
-      (index, torch_xla._XLAC._xla_get_tensor_id(xla_arg)) for index, xla_arg in xla_tensor_args
-  ]
+  args_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg))
+                     for index, xla_arg in xla_tensor_args]
 
   if dynamo_debug:
     print(f"Graph module:\n{xla_model.code}")
@@ -262,8 +261,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 
   # If a arg is being in place updated by model, we need to include arg as part of the graph result.
   xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      [tensor for _, tensor in xla_tensor_args]
-  )
+      [tensor for _, tensor in xla_tensor_args])
   xla_args_need_update = []
   arg_index_to_need_update_index = {}
   for i, need_update in enumerate(xla_args_need_update_bool):
@@ -354,7 +352,9 @@ def optimized_mod(*args):
 
     # mark_step needs to be blocking since we want to access args's XLADatas
     # and they can't be placeholder.
-    if any(torch_xla._XLAC._check_tensor_need_materialization([a for a in args if isinstance(a, torch.Tensor)])):
+    if any(
+        torch_xla._XLAC._check_tensor_need_materialization(
+            [a for a in args if isinstance(a, torch.Tensor)])):
       xm.mark_step(wait=True)
 
     # If input sharding has changed from the previous program, dynamo current can
