Move where clear pending IR is called to avoid crash (#5552)
* Move where clear pending IR is called to avoid crash

* fix CI

* fix CI and add some debugging messages
JackCaoG committed Sep 15, 2023
1 parent c55ef0b commit ef34a31
Showing 3 changed files with 31 additions and 21 deletions.
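
In short, the fix moves the in-place-argument restore and the `_clear_pending_irs` call in torch_xla/core/dynamo_bridge.py from after graph partitioning to right after fallback-op collection, so no stale pending IR survives into partitioning and later execution. The sketch below is a minimal illustration of the reordered portion of `extract_compiled_graph`, not the actual function body; the setup of `all_xla_args`, `cloned_args`, and the partitioner is elided, and the function name is hypothetical.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    def clear_pending_irs_early(all_xla_args, cloned_args):
      # 1. Restore any argument that was mutated in place during fallback-op
      #    collection (duplicate of the logic in extract_internal).
      args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
          all_xla_args)
      for i, need_update in enumerate(args_need_update_bool):
        if need_update and isinstance(all_xla_args[i], torch.Tensor):
          all_xla_args[i].copy_(cloned_args[i])

      # 2. Drop leftover pending IRs *before* the graph is partitioned and the
      #    submodules are compiled; previously this ran after partitioning,
      #    which could leave stale IR nodes around and crash later execution.
      torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))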
11 changes: 6 additions & 5 deletions test/dynamo/test_dynamo.py
@@ -273,21 +273,21 @@ def fn_fallback(t):
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 11)

     # Second tracing
     met.clear_counters()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 3)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 13)

     # Verify that dynamo can handle different inputs
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 15)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 16)


 class DynamoTrainingBasicTest(unittest.TestCase):
@@ -539,9 +539,10 @@ def test_all_cpu_tensor(self):
     # there should be 18 parameters + 1 input
     self.assertGreater(len(w), 15)
     self.assertIn('Found tensor with shape torch.Size', str(w[0].message))
-    # no XLA operation should happen. Partitioner should offload all CPU
+    # no XLA operation should happen except an empty mark_step. Partitioner should offload all CPU
     # ops to CPU.
-    self.assertEqual(len(met.counter_names()), 0)
+    self.assertEqual(len(met.counter_names()), 1)
+    self.assertIn('MarkStep', met.counter_names())


 if __name__ == '__main__':
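The assertion changes above follow directly from the new `xm.mark_step()` in `extract_internal`: each traced function now triggers one extra (empty) execution, and the empty mark_step registers a `MarkStep` counter even when no XLA op runs. The snippet below is not part of the commit; it is a small illustration of how the tests read those counters, assuming `fn` is a `torch.compile`d function and `t_xla` an input tensor already on an XLA device.

    import torch_xla.debug.metrics as met

    def summarize_dynamo_metrics(fn, t_xla):
      met.clear_all()  # reset all counters and metrics before measuring
      out = fn(t_xla)  # tracing, compilation and execution happen here
      # metric_data(name)[0] is the sample count, which is what the
      # assertions in test_dynamo.py compare against.
      print('CompileTime count:', met.metric_data('CompileTime')[0])
      print('ExecuteTime count:', met.metric_data('ExecuteTime')[0])
      print('counters:', met.counter_names())  # 'MarkStep' shows up here
      return out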
38 changes: 22 additions & 16 deletions torch_xla/core/dynamo_bridge.py
@@ -18,7 +18,7 @@
 import torch_xla.runtime as xr
 import torch_xla.utils.utils as xu

-debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
+debug = os.environ.get("XLA_DYNAMO_DEBUG") == "1"


 @dataclasses.dataclass
@@ -322,6 +322,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):


 def extract_internal(xla_model: torch.fx.GraphModule):
+  if debug:
+    for xla_arg in xla_model.xla_args:
+      print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
+  xm.mark_step()
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -471,6 +475,23 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   collector = FallBackNodeCollector(xla_model)
   collector.run(*xla_args)
   fallback_ops = collector.get_fallback_ops()
+  if debug and len(fallback_ops) > 0:
+    print('fallback ops are' + str(fallback_ops))
+
+  # This logic, needed for supporting in-place operations, is a duplicate of
+  # the one in the main `extract_internal` function above. We need to do this
+  # check for fetching fallback ops as well.
+  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      all_xla_args)
+
+  # Again, same logic in the `extract_internal` above to support in-place operations.
+  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(all_xla_args[i], torch.Tensor):
+      all_xla_args[i].copy_(cloned_args[i])
+
+  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

   class XlaOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):

@@ -493,21 +514,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)

-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args[i], torch.Tensor):
-      all_xla_args[i].copy_(cloned_args[i])
-
-  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:
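Besides moving the clear-pending-IRs call, the bridge's debug switch was renamed from TORCH_XLA_DEBUG to XLA_DYNAMO_DEBUG and now also dumps per-argument tensor debug info and the collected fallback ops. Below is a hypothetical way to turn it on; the flag is read at module import time, so it must be set before torch_xla is imported, and the `openxla` backend name is an assumption here rather than something stated in this commit.

    import os
    os.environ["XLA_DYNAMO_DEBUG"] = "1"  # must be set before torch_xla is imported

    import torch
    import torch_xla.core.xla_model as xm

    def add_mul(x):
      return x * 2 + 1

    t = torch.randn(4, device=xm.xla_device())
    compiled = torch.compile(add_mul, backend="openxla")  # backend name assumed
    print(compiled(t))  # with the flag set, tensor debug info and fallback ops are printed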
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -496,6 +496,9 @@ void XLAGraphExecutor::ClearPendingIrs(
           runtime::GetComputationClient()->CreateDataPlaceholder(
               device.toString(), std::move(shape)));
       tensors[i]->data()->handle = handle;
+      TF_VLOG(4) << "Replacing the IR " << ir_value.node.get()->ToString()
+                 << " of Tensor with ID " << tensors[i]->GetUniqueId()
+                 << " with placeholder";
     }
     tensors[i]->AssignIrValue(torch::lazy::Value());
     tensors[i]->data()->view = nullptr;
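The C++ change only adds a verbose log line: every IR value that ClearPendingIrs replaces with a data placeholder is now reported at TF_VLOG level 4 together with the tensor's unique ID. A hypothetical way to surface it from Python, assuming torch_xla's usual TF_CPP_VMODULE-controlled logging (exact variable handling may differ between builds):

    import os
    # Raise the vlog level for xla_graph_executor.cpp before torch_xla loads;
    # exporting the same variables in the shell works as well.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
    os.environ["TF_CPP_VMODULE"] = "xla_graph_executor=4"

    import torch_xla  # the log line appears whenever pending IRs get cleared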
