Skip to content

Commit

Permalink
Filter tensor arguments from traced model. (pytorch#5689)
Browse files Browse the repository at this point in the history
* Filter tensor arguments from traced model.

This PR filters tensor arguments from the list of arguments that would be given to the
model.

**Problem:** dynamo bridge assumed all arguments were tensors.
**Solution:** filter tensor arguments so that we correctly collect tensor information.

* Add test.

* Fix lint issues.

* Simplified test.

* Use `openxla` instead of `openxla_eval` backend.

* Rename variables for readability.

* Use `openxla_eval` instead of `openxla`.
  • Loading branch information
ysiraichi authored and ghpvnist committed Oct 31, 2023
1 parent 38dca6f commit 407ee76
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
18 changes: 18 additions & 0 deletions test/dynamo/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
test_training_linear = make_training_test(LinearModule)
test_training_maxpool = make_training_test(MaxPoolModule)

def test_non_tensor_args_for_partition(self):

class Emb(torch.nn.Embedding):

def __init__(self):
super().__init__(num_embeddings=10, embedding_dim=10, padding_idx=0)

device = xm.xla_device()
module = Emb()
module.to(device)

@torch.compile(backend="openxla_eval")
def foo(x):
return module(x)

x = torch.randint(0, 10, (10,), device=device)
foo(x)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
29 changes: 18 additions & 11 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,19 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
for xla_arg in xla_args
]

args_tensor_ids = [
torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
]
index_and_xla_tensor_args = [(i, xla_arg)
for i, xla_arg in enumerate(xla_args)
if isinstance(xla_arg, torch.Tensor)]

index_and_tensor_ids = [(index, torch_xla._XLAC._xla_get_tensor_id(xla_arg))
for index, xla_arg in index_and_xla_tensor_args]

if dynamo_debug:
print(f"Graph module:\n{xla_model.code}")
print(f"args_tensor_ids {args_tensor_ids}")
print(f"args_tensor_ids {index_and_tensor_ids}")

tensor_id_to_arg_idx = {
tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
tensor_id: index for index, tensor_id in index_and_tensor_ids
}

if xr.is_spmd():
Expand All @@ -258,15 +261,16 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):

# If a arg is being in place updated by model, we need to include arg as part of the graph result.
xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
xla_args)
[tensor for _, tensor in index_and_xla_tensor_args])
xla_args_need_update = []
arg_index_to_need_update_index = {}
for i, need_update in enumerate(xla_args_need_update_bool):
# Don't add inplace updated argument to the list if it's already
# being returned
if need_update and id(xla_args[i]) not in xla_out_ids:
arg_index_to_need_update_index[i] = len(xla_args_need_update)
xla_args_need_update.append(xla_args[i])
index, tensor = index_and_xla_tensor_args[i]
if need_update and id(tensor) not in xla_out_ids:
arg_index_to_need_update_index[index] = len(xla_args_need_update)
xla_args_need_update.append(tensor)

args_and_out = tuple(xla_args_need_update) + tuple(xla_out)

Expand Down Expand Up @@ -325,7 +329,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
def extract_internal(xla_model: torch.fx.GraphModule):
if dynamo_debug:
for xla_arg in xla_model.xla_args:
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
if isinstance(xla_arg, torch.Tensor):
print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
xm.mark_step()
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
Expand All @@ -347,7 +352,9 @@ def optimized_mod(*args):

# mark_step needs to be blocking since we want to access args's XLADatas
# and they can't be placeholder.
if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
if any(
torch_xla._XLAC._check_tensor_need_materialization(
[a for a in args if isinstance(a, torch.Tensor)])):
xm.mark_step(wait=True)

# If input sharding has changed from the previous program, dynamo current can
Expand Down

0 comments on commit 407ee76

Please sign in to comment.