From 86b0290f72773581be08f58d0d2429d0c350a8a0 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Wed, 29 Nov 2023 10:54:28 -0800 Subject: [PATCH] Truncate python stack when outputting frame that cause the graph executation (#5933) * Truncate python stack when outputting frame that cause the graph execution * add mp tests * move tests to a new dir --------- Co-authored-by: root --- test/debug_tool/test_mp_pt_xla_debug.py | 67 ++++++++++++++++++++++ test/{ => debug_tool}/test_pt_xla_debug.py | 0 test/run_tests.sh | 3 +- torch_xla/csrc/debug_util.cpp | 13 ++++- 4 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 test/debug_tool/test_mp_pt_xla_debug.py rename test/{ => debug_tool}/test_pt_xla_debug.py (100%) diff --git a/test/debug_tool/test_mp_pt_xla_debug.py b/test/debug_tool/test_mp_pt_xla_debug.py new file mode 100644 index 00000000000..3b9547bbdee --- /dev/null +++ b/test/debug_tool/test_mp_pt_xla_debug.py @@ -0,0 +1,67 @@ +import os + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp + + +def check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + + +def extract_execution_cause(lines): + causes = [] + for i in range(len(lines)): + if 'Execution Cause' in lines[i].decode(): + causes.append(lines[i + 1].decode()) + return causes + + +def extract_python_frames(lines): + frames = [] + current_frame = '' + record_frame = False + for i in range(len(lines)): + if 'Python Frame Triggered Execution' in lines[i].decode(): + record_frame = True + elif 'Execution Analysis: ----------------' in lines[i].decode(): + record_frame = False + frames.append(current_frame) + current_frame = '' + if record_frame: + current_frame += lines[i].decode() + return frames + + +def _mp_fn(index): + if not check_env_flag('PT_XLA_DEBUG'): + assert False, "This test should be run with PT_XLA_DEBUG" + debug_file_name = os.getenv('PT_XLA_DEBUG_FILE') + if not debug_file_name: + assert False, "This test should be run with PT_XLA_DEBUG_FILE" + if index == 0: + open(debug_file_name, 'w').close() + device = xm.xla_device() + t1 = torch.randn(10, 10, device=device) + t2 = t1 * 100 + xm.mark_step() + xm.wait_device_ops() + + if index == 0: + # All of the process will write to the same PT_XLA_DEBUG_FILE, but the + # no need to check this on all processes. + with open(debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + frames = extract_python_frames(lines) + # only the local master process should dump the executation analysis + assert (len(causes) == 1) + assert ('user mark_step' in causes[0]) + # make sure that frame that spawn up process is skipped + assert (len(frames) == 1) + assert ('....' in frames[0]) + assert ('_internal/pjrt.py' not in frames[0]) + + +if __name__ == '__main__': + xmp.spawn(_mp_fn, args=()) diff --git a/test/test_pt_xla_debug.py b/test/debug_tool/test_pt_xla_debug.py similarity index 100% rename from test/test_pt_xla_debug.py rename to test/debug_tool/test_pt_xla_debug.py diff --git a/test/run_tests.sh b/test/run_tests.sh index a4c82a6d4c7..453abb5e469 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -161,7 +161,7 @@ function run_xla_op_tests1 { run_test "$CDIR/test_grad_checkpoint.py" run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY - run_pt_xla_debug "$CDIR/test_pt_xla_debug.py" + run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" run_test "$CDIR/test_async_closures.py" run_test "$CDIR/test_hlo_metadata.py" run_test "$CDIR/test_profiler.py" @@ -232,6 +232,7 @@ function run_mp_op_tests { run_test "$CDIR/test_mp_save.py" run_test "$CDIR/test_mp_mesh_reduce.py" run_test "$CDIR/test_mp_sync_batch_norm.py" + run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py" run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py" run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py" run_xla_backend_mp "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py" diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 9959d46f8a2..06d839a5e28 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -272,11 +272,18 @@ void DebugUtil::analyze_graph_execution_python_frame( "mark_step\n"; } - // TODO(JackCaoG): make number of frames printed configurable ss << debug_output_prefix << "Python Frame Triggered Execution: \n"; for (auto& location : frames) { - ss << debug_output_prefix << " " << location.function << " (" - << location.file << ":" << location.line << ")\n"; + // if current frame `__call__` at pjrt.py, bleow stack will be python + // code to spawn up process, not very useful to the user. + if (location.function == "__call__" && + endsWith(location.file, "_internal/pjrt.py")) { + ss << debug_output_prefix << " ..........\n"; + break; + } else { + ss << debug_output_prefix << " " << location.function << " (" + << location.file << ":" << location.line << ")\n"; + } } ss << debug_output_prefix << "----------------------------------------------------------------------"