Filter tensor arguments from traced model.
This PR filters the tensor arguments out of the full list of arguments that will be given to the model.

**Problem:** the dynamo bridge assumed every argument was a tensor.
**Solution:** filter the tensor arguments (keeping their original positions) so that tensor information is collected only from actual tensors.
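
Illustrative sketch (not part of the commit): the pattern the change introduces is to pair each tensor argument with its original position instead of assuming every argument is a tensor. The helper name `collect_tensor_args` is made up, and plain `id()` stands in for torch_xla's internal tensor-ID helper:

```python
import torch

def collect_tensor_args(args):
    # Keep each tensor together with its position in the original argument list,
    # so non-tensor arguments (ints, strings, None, ...) are simply skipped.
    return [(i, a) for i, a in enumerate(args) if isinstance(a, torch.Tensor)]

args = (torch.ones(2), 3, torch.zeros(2), "pad")
tensor_args = collect_tensor_args(args)          # [(0, <tensor>), (2, <tensor>)]

# Map a (stand-in) tensor id back to the argument's original index; the real
# bridge uses torch_xla._XLAC._xla_get_tensor_id instead of Python's id().
tensor_id_to_arg_idx = {id(t): i for i, t in tensor_args}
assert tensor_id_to_arg_idx[id(args[2])] == 2
```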
ysiraichi committed Oct 9, 2023
1 parent c9a1324 commit f67a922
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions torch_xla/core/dynamo_bridge.py
```diff
@@ -230,16 +230,20 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
       for xla_arg in xla_args
   ]
 
+  xla_tensor_args = [
+      (i, xla_arg) for i, xla_arg in enumerate(xla_args) if isinstance(xla_arg, torch.Tensor)
+  ]
+
   args_tensor_ids = [
-      torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args
+      (index, torch_xla._XLAC._xla_get_tensor_id(xla_arg)) for index, xla_arg in xla_tensor_args
   ]
 
   if dynamo_debug:
     print(f"Graph module:\n{xla_model.code}")
     print(f"args_tensor_ids {args_tensor_ids}")
 
   tensor_id_to_arg_idx = {
-      tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
+      tensor_id: index for index, tensor_id in args_tensor_ids
   }
 
   if xr.is_spmd():
```
```diff
@@ -258,15 +262,17 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 
   # If a arg is being in place updated by model, we need to include arg as part of the graph result.
   xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      xla_args)
+      [tensor for _, tensor in xla_tensor_args]
+  )
   xla_args_need_update = []
   arg_index_to_need_update_index = {}
   for i, need_update in enumerate(xla_args_need_update_bool):
     # Don't add inplace updated argument to the list if it's already
     # being returned
-    if need_update and id(xla_args[i]) not in xla_out_ids:
-      arg_index_to_need_update_index[i] = len(xla_args_need_update)
-      xla_args_need_update.append(xla_args[i])
+    index, tensor = xla_tensor_args[i]
+    if need_update and id(tensor) not in xla_out_ids:
+      arg_index_to_need_update_index[index] = len(xla_args_need_update)
+      xla_args_need_update.append(tensor)
 
   args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
 
```
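Aside (illustrative, not part of the diff): `_check_tensor_need_materialization` now receives only the filtered tensors, so its result is indexed by filtered position, while the stored `(original_index, tensor)` pairs let the loop record each update slot against the argument's position in the full list. A small sketch of that index bookkeeping, with dummy arguments and a hypothetical `need_update` result:

```python
import torch

args = (torch.ones(2), 42, torch.zeros(2))
tensor_args = [(i, a) for i, a in enumerate(args) if isinstance(a, torch.Tensor)]
need_update = [False, True]          # one flag per *tensor* argument (filtered order)

arg_index_to_need_update_index = {}
args_need_update = []
for filtered_i, flag in enumerate(need_update):
    original_index, tensor = tensor_args[filtered_i]
    if flag:
        # Record where in the original argument list the updated tensor lives.
        arg_index_to_need_update_index[original_index] = len(args_need_update)
        args_need_update.append(tensor)

assert arg_index_to_need_update_index == {2: 0}
```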
```diff
@@ -325,7 +331,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 def extract_internal(xla_model: torch.fx.GraphModule):
   if dynamo_debug:
     for xla_arg in xla_model.xla_args:
-      print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
+      if isinstance(xla_arg, torch.Tensor):
+        print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
   xm.mark_step()
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
```
```diff
@@ -347,7 +354,7 @@ def optimized_mod(*args):
 
     # mark_step needs to be blocking since we want to access args's XLADatas
     # and they can't be placeholder.
-    if any(torch_xla._XLAC._check_tensor_need_materialization(args)):
+    if any(torch_xla._XLAC._check_tensor_need_materialization([a for a in args if isinstance(a, torch.Tensor)])):
       xm.mark_step(wait=True)
 
     # If input sharding has changed from the previous program, dynamo current can
```
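Usage context (illustrative only; the `torch.compile` call below is an assumption about the usual entry point, and the backend name depends on the installed torch_xla version): a function traced by dynamo can take non-tensor arguments such as plain Python ints, which is exactly what the call-site filter in `optimized_mod` has to tolerate:

```python
import torch

def scale(x: torch.Tensor, factor: int):
    # `factor` reaches the dynamo bridge as a plain int, not a tensor.
    return x * factor

args = (torch.ones(3), 2)
# Mirrors the filter inside optimized_mod: only tensors are handed to
# tensor-specific checks.
tensor_only = [a for a in args if isinstance(a, torch.Tensor)]
assert len(tensor_only) == 1

# With torch_xla installed, the usual entry point would look something like:
# compiled = torch.compile(scale, backend="openxla")  # backend name may vary by version
# out = compiled(torch.ones(3, device=xm.xla_device()), 2)
```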
