diff --git a/test/debug_tool/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py
index 14f0817a4c08..4c806d28036c 100644
--- a/test/debug_tool/test_pt_xla_debug.py
+++ b/test/debug_tool/test_pt_xla_debug.py
@@ -8,6 +8,7 @@ import torch_xla.utils.utils as xu
 import torch_xla.distributed.parallel_loader as pl
 import unittest
+from typing import NamedTuple
 
 
 def check_env_flag(name, default=''):
   return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE']
@@ -22,6 +23,34 @@ def extract_execution_cause(lines):
   return causes
 
 
+def extract_compilation_cause(lines):
+  causes = []
+  for i in range(len(lines)):
+    if 'Compilation Cause' in lines[i].decode():
+      causes.append(lines[i + 1].decode())
+  return causes
+
+
+class GraphInfo(NamedTuple):
+  hash: str
+  num_input: int
+  num_output: int
+
+
+def extract_graph_infos(lines):
+  infos = []
+  for i in range(len(lines)):
+    if 'Graph Info' in lines[i].decode():
+      hash = lines[i + 1].decode().split('Graph Hash: ')[1].strip()
+      num_input = lines[i +
+                        2].decode().split('Number of Graph Inputs:')[1].strip()
+      num_output = lines[i + 3].decode().split(
+          'Number of Graph Outputs:')[1].strip()
+      infos.append(GraphInfo(hash, int(num_input), int(num_output)))
+
+  return infos
+
+
 class PtXLADebugTest(unittest.TestCase):
 
   @classmethod
@@ -39,40 +68,88 @@ def test_user_mark_step(self):
     xm.mark_step()
     with open(self.debug_file_name, 'rb') as f:
       lines = f.readlines()
-      causes = extract_execution_cause(lines)
-    self.assertEqual(len(causes), 1)
-    self.assertIn('user mark_step', causes[0])
+      execution_causes = extract_execution_cause(lines)
+      compilation_causes = extract_compilation_cause(lines)
+      graph_infos = extract_graph_infos(lines)
+
+    self.assertEqual(len(execution_causes), 1)
+    self.assertIn('user mark_step', execution_causes[0])
+
+    self.assertEqual(len(compilation_causes), 1)
+    self.assertIn('user mark_step', compilation_causes[0])
+
+    self.assertEqual(len(graph_infos), 2)
+    # one graph info from compilation, one from execution, hashes should match
+    self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
+    # this graph has one input (random seed) and one output (t1)
+    self.assertEqual(graph_infos[1].num_input, 1)
+    self.assertEqual(graph_infos[1].num_output, 1)
     open(self.debug_file_name, 'w').close()
 
   def test_step_trace(self):
     device = xm.xla_device()
     with xp.StepTrace('train_pt_xla_debug'):
-      t1 = torch.randn(2, 2, device=device)
+      t1 = torch.randn(3, 3, device=device)
     with open(self.debug_file_name, 'rb') as f:
       lines = f.readlines()
       causes = extract_execution_cause(lines)
+      compilation_causes = extract_compilation_cause(lines)
+      graph_infos = extract_graph_infos(lines)
+
     self.assertEqual(len(causes), 1)
     self.assertIn('mark_step when exiting a profiler StepTrace region',
                   causes[0])
+
+    self.assertEqual(len(compilation_causes), 1)
+    self.assertIn('mark_step when exiting a profiler StepTrace region',
+                  compilation_causes[0])
+
+    self.assertEqual(len(graph_infos), 2)
+    # one graph info from compilation, one from execution, hashes should match
+    self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
+    # this graph has one input (random seed) and one output (t1)
+    self.assertEqual(graph_infos[1].num_input, 1)
+    self.assertEqual(graph_infos[1].num_output, 1)
     open(self.debug_file_name, 'w').close()
 
   def test_dynamo(self):
     device = xm.xla_device()
-    t1 = torch.randn(2, 2, device=device)
+    t1 = torch.randn(4, 4, device=device)
 
     def toy_program(t1):
-      return t1 + t1
+      return t1 * 100
 
     compiled = torch.compile(toy_program, backend="openxla")
     res = compiled(t1)
     with open(self.debug_file_name, 'rb') as f:
       lines = f.readlines()
-      causes = extract_execution_cause(lines)
-    self.assertEqual(len(causes), 4)
-    self.assertIn('mark_step when dynamo processing input graphs', causes[0])
-    self.assertIn('mark_step when dynamo processing input graphs', causes[1])
-    self.assertIn('dynamo is compiling a FX graph to HLO', causes[2])
-    self.assertIn('dynamo is executing a compiled program', causes[3])
+      execution_causes = extract_execution_cause(lines)
+      compilation_causes = extract_compilation_cause(lines)
+      graph_infos = extract_graph_infos(lines)
+
+    self.assertEqual(len(execution_causes), 2)
+    self.assertIn('mark_step when dynamo processing input graphs',
+                  execution_causes[0])
+    self.assertIn('dynamo is executing a compiled program',
+                  execution_causes[1])
+
+    self.assertEqual(len(compilation_causes), 2)
+    self.assertIn('mark_step when dynamo processing input graphs',
+                  compilation_causes[0])
+    self.assertIn('dynamo is compiling a FX graph to HLO',
+                  compilation_causes[1])
+
+    # one graph info from compilation, one from execution, hashes should match
+    self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
+    # this graph has one input (random seed) and one output (t1)
+    self.assertEqual(graph_infos[1].num_input, 1)
+    self.assertEqual(graph_infos[1].num_output, 1)
+
+    # one graph info from dynamo compilation, one from dynamo execution, hashes should match
+    self.assertEqual(graph_infos[2].hash, graph_infos[3].hash)
+    # this graph has two inputs (t1, 100) and one output
+    self.assertEqual(graph_infos[3].num_input, 2)
+    self.assertEqual(graph_infos[3].num_output, 1)
     open(self.debug_file_name, 'w').close()
 
   def test_parallel_loader(self):
@@ -93,25 +170,55 @@ def test_parallel_loader(self):
         host_to_device_transfer_threads=1)
 
     for step, (data, target) in enumerate(train_device_loader):
-      pass
+      dummy = data + 100
 
     with open(self.debug_file_name, 'rb') as f:
       lines = f.readlines()
-      causes = extract_execution_cause(lines)
-    self.assertEqual(len(causes), batch_size + 2)
-    for cause in causes:
+      execution_causes = extract_execution_cause(lines)
+      compilation_causes = extract_compilation_cause(lines)
+      graph_infos = extract_graph_infos(lines)
+
+    self.assertEqual(len(execution_causes), batch_size)
+    for cause in execution_causes:
       self.assertIn('mark_step in parallel loader at step end', cause)
+
+    # We should only compile once.
+    self.assertEqual(len(compilation_causes), 1)
+    self.assertIn('mark_step in parallel loader at step end',
+                  compilation_causes[0])
+
+    self.assertEqual(len(graph_infos), batch_size + 1)
+    # one graph info from compilation, batch_size from execution, hashes should match
+    for i in range(batch_size + 1):
+      self.assertEqual(graph_infos[0].hash, graph_infos[i].hash)
+      # this graph has two inputs (data, 100) and one output (dummy)
+      self.assertEqual(graph_infos[i].num_input, 2)
+      self.assertEqual(graph_infos[i].num_output, 1)
     open(self.debug_file_name, 'w').close()
 
   def test_print(self):
     device = xm.xla_device()
-    t1 = torch.randn(2, 2, device=device)
+    t1 = torch.randn(5, 5, device=device)
     print(t1)
     with open(self.debug_file_name, 'rb') as f:
       lines = f.readlines()
-      causes = extract_execution_cause(lines)
-    self.assertEqual(len(causes), 1)
-    self.assertIn('user code trying to access tensor value', causes[0])
+      execution_causes = extract_execution_cause(lines)
+      compilation_causes = extract_compilation_cause(lines)
+      graph_infos = extract_graph_infos(lines)
+
+    self.assertEqual(len(execution_causes), 1)
+    self.assertIn('user code trying to access tensor value',
+                  execution_causes[0])
+
+    self.assertEqual(len(compilation_causes), 1)
+    self.assertIn('user code trying to access tensor value',
+                  compilation_causes[0])
+
+    # one graph info from compilation, one from execution, hashes should match
+    self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
+    # this graph has one input (random seed) and one output (t1)
+    self.assertEqual(graph_infos[1].num_input, 1)
+    self.assertEqual(graph_infos[1].num_output, 1)
     open(self.debug_file_name, 'w').close()
 
 
diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp
index 06d839a5e28a..6af3c1b36084 100644
--- a/torch_xla/csrc/debug_util.cpp
+++ b/torch_xla/csrc/debug_util.cpp
@@ -217,12 +217,25 @@ static bool endsWith(const std::string& str, const std::string& suffix) {
 }
 void DebugUtil::analyze_graph_execution_python_frame(
-    bool from_dynamo_executation) {
-  static bool is_master_process =
+    GraphAnalysisSource source, torch::lazy::hash_t graph_hash,
+    const xla::ProgramShape* program_shape) {
+  static const bool pt_xla_debug_enabled =
+      runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
+  static const bool is_master_process =
       (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
-  static std::string debug_file_name =
+  static const std::string debug_file_name =
       runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
-  static std::string debug_output_prefix = "Execution Analysis: ";
+
+  static const std::string execution_output_prefix = "Execution Analysis: ";
+  static const std::string compilation_output_prefix = "Compilation Analysis: ";
+
+  if (!pt_xla_debug_enabled) {
+    return;
+  }
+
+  std::string debug_output_prefix = (source == GraphAnalysisSource::Compilation)
+                                        ? compilation_output_prefix
+                                        : execution_output_prefix;
   // TODO: Make this configurable.
   if (!is_master_process) {
     return;
   }
@@ -237,8 +250,10 @@ void DebugUtil::analyze_graph_execution_python_frame(
      << "======================================================================"
         "=========="
      << "\n";
-  ss << debug_output_prefix << "Execution Cause\n";
-  if (from_dynamo_executation) {
+  ss << debug_output_prefix
+     << ((source == GraphAnalysisSource::Compilation) ? "Compilation Cause\n"
"Compilation Cause\n" + : "Execution Cause\n"); + if (source == GraphAnalysisSource::DynamoExecution) { // when executation is from dynamo compiled graph, the python stack will not // show any dynamo related python file since frame is already replaced. We // can either analyze the C++ call stack or rely on caller to pass a boolean @@ -272,6 +287,18 @@ void DebugUtil::analyze_graph_execution_python_frame( "mark_step\n"; } + ss << debug_output_prefix << "Graph Info: \n"; + ss << debug_output_prefix + << " Graph Hash: " << torch::lazy::HashToString(graph_hash) << "\n"; + ss << debug_output_prefix + << " Number of Graph Inputs: " << program_shape->parameters().size() + << "\n"; + ss << debug_output_prefix << " Number of Graph Outputs: " + << (program_shape->result().IsTuple() + ? program_shape->result().tuple_shapes_size() + : 1) + << "\n"; + ss << debug_output_prefix << "Python Frame Triggered Execution: \n"; for (auto& location : frames) { // if current frame `__call__` at pjrt.py, bleow stack will be python diff --git a/torch_xla/csrc/debug_util.h b/torch_xla/csrc/debug_util.h index 530a45fc83a7..74d2ffa3ffcf 100644 --- a/torch_xla/csrc/debug_util.h +++ b/torch_xla/csrc/debug_util.h @@ -19,6 +19,12 @@ class DebugUtil { kStableHlo, }; + enum GraphAnalysisSource { + Compilation, + Execution, + DynamoExecution, + }; + static GraphFormat GetDefaultGraphFormat(); // Return HLO/StableHLO gragh of the index selected tensors in string format. @@ -50,7 +56,8 @@ class DebugUtil { // warning, this function should only be called when a graph execution is // about to happen. static void analyze_graph_execution_python_frame( - bool from_dynamo_executation = false); + GraphAnalysisSource source, torch::lazy::hash_t graph_hash = 0, + const xla::ProgramShape* program_shape = nullptr); }; } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 2b193bdbb960..49f4639ff9e1 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -611,10 +611,6 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier", tsl::profiler::TraceMeLevel::kInfo); MaybeDumpGraph("dynamo", hash); - if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) { - DebugUtil::analyze_graph_execution_python_frame( - /*from_dynamo_executation=*/true); - } auto cachedComputation = XLAGraphExecutor::Get()->GetComputationCache()->Get(hash); TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash) @@ -627,6 +623,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( << ". Maybe the entry get " "kicked out of the LRU cache"; + DebugUtil::analyze_graph_execution_python_frame( + DebugUtil::GraphAnalysisSource::DynamoExecution, + /*graph_hash=*/hash, + /*program_shape=*/&(cachedComputation->computation->program_shape())); + // Create DataPlaceHolder that will get filled in async executions. 
   std::vector<xla::Shape>* output_shapes =
       DeviceContextArena::Get()->GetOutputShapesByHash(hash);
@@ -956,6 +957,10 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph(
     std::vector<torch::lazy::BackendDataPtr> tensors_data,
     std::vector<XLATensor::ShardingSpecPtr> sharding_specs,
     ComputationCache::TypePtr cached_computation) {
+  DebugUtil::analyze_graph_execution_python_frame(
+      DebugUtil::GraphAnalysisSource::Execution,
+      /*graph_hash=*/coll->hash,
+      /*program_shape=*/&(cached_computation->computation->program_shape()));
   tsl::profiler::TraceMe activity("ScheduleSyncTensorsGraph",
                                   tsl::profiler::TraceMeLevel::kInfo);
   TensorCollectionBarrier(coll);
@@ -1261,6 +1266,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
            coll.device.toString(), devices),
        &shape, should_wrap_parameter, is_sharded});
 
+  DebugUtil::analyze_graph_execution_python_frame(
+      DebugUtil::GraphAnalysisSource::Compilation,
+      /*graph_hash=*/coll.hash, /*program_shape=*/&program_shape);
+
   TF_VLOG(3) << "Compiling IR graph hash "
              << torch::lazy::HashToString(coll.hash) << " on device "
              << coll.device << " ...";
@@ -1300,9 +1309,6 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
     const SyncTensorsConfig& config, bool warm_up_cache_only) {
   tsl::profiler::TraceMe activity("SyncTensorsGraphInternal",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
-    DebugUtil::analyze_graph_execution_python_frame();
-  }
   SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
   if (coll.indices.empty()) {
     // Enure previous execution is complete before exiting this