Move where clear pending IR is called to avoid crash
JackCaoG committed Sep 9, 2023
1 parent e51d28b commit e293596
Showing 2 changed files with 19 additions and 15 deletions.
31 changes: 16 additions & 15 deletions torch_xla/core/dynamo_bridge.py
@@ -322,6 +322,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
 
 
 def extract_internal(xla_model: torch.fx.GraphModule):
+  xm.mark_step()
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -473,6 +474,21 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   collector.run(*xla_args)
   fallback_ops = collector.get_fallback_ops()
 
+  # This logic, needed for supporting in-place operations, is a duplicate of
+  # the one in the main `extract_internal` function above. We need to do this
+  # check for fetching fallback ops as well.
+  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      all_xla_args)
+
+  # Again, same logic in the `extract_internal` above to support in-place operations.
+  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(all_xla_args[i], torch.Tensor):
+      all_xla_args[i].copy_(cloned_args[i])
+
+  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+
   class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
 
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
@@ -494,21 +510,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
 
-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args[i], torch.Tensor):
-      all_xla_args[i].copy_(cloned_args[i])
-
-  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
-
   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:
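
The block that moves in this file restores any dynamo arguments that were mutated in place during the fallback-op collection run, then discards the pending IRs before the partitioner re-traces the graph. Below is a minimal, self-contained sketch of that clone-run-restore pattern in plain PyTorch; the helper name and the torch.equal check are illustrative stand-ins (the real code uses torch_xla._XLAC._check_tensor_need_materialization), not the torch_xla implementation.

import torch

def run_and_restore_args(fn, args):
  # Keep pristine copies so in-place mutations made by the throwaway run
  # below can be undone afterwards (mirrors `cloned_args` in the diff).
  cloned_args = [
      a.clone() if isinstance(a, torch.Tensor) else a for a in args
  ]
  out = fn(*args)  # may mutate `args` in place (e.g. tensor.add_)
  # Restore every tensor argument the run modified, so later tracing
  # starts from the original values.
  for i, arg in enumerate(args):
    if isinstance(arg, torch.Tensor) and not torch.equal(arg, cloned_args[i]):
      arg.copy_(cloned_args[i])
  return out

x = torch.ones(3)
run_and_restore_args(lambda t: t.add_(1), [x])
assert torch.equal(x, torch.ones(3))  # the in-place add has been rolled back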
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -496,6 +496,9 @@ void XLAGraphExecutor::ClearPendingIrs(
           runtime::GetComputationClient()->CreateDataPlaceholder(
               device.toString(), std::move(shape)));
       tensors[i]->data()->handle = handle;
+      TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString()
+                 << " of Tensor with ID " << tensors[i]->GetUniqueId()
+                 << " with placeholder";
     }
     tensors[i]->AssignIrValue(torch::lazy::Value());
     tensors[i]->data()->view = nullptr;
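
The new TF_VLOG(4) line only shows up when verbose logging for this translation unit is enabled. A hedged sketch of how one might surface it from Python, assuming the tsl-style logging environment variables torch_xla honours (TF_CPP_MIN_LOG_LEVEL, TF_CPP_VMODULE) and that they are set before torch_xla is imported:

import os

# Assumption: tsl-style logging env vars; the module name matches the file
# xla_graph_executor.cpp. Set these before importing torch_xla.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0")
os.environ.setdefault("TF_CPP_VMODULE", "xla_graph_executor=4")

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())
t = t + 1  # leaves a pending IR node on the lazy tensor
# Same call the dynamo bridge now makes before partitioning; with VLOG(4)
# enabled it should print which IR node was replaced with a placeholder.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))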
