Skip to content

Commit

Permalink
Truncate python stack when outputting frame that cause the graph exec…
Browse files Browse the repository at this point in the history
…utation (#5933)

* Truncate the Python stack when outputting the frames that caused the graph execution

* add mp tests

* move tests to a new dir

---------

Co-authored-by: root <[email protected]>
  • Loading branch information
2 people authored and bhavya01 committed Apr 22, 2024
1 parent b65f8ec commit 86b0290
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
67 changes: 67 additions & 0 deletions test/debug_tool/test_mp_pt_xla_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def check_env_flag(name, default=''):
    """Return True when environment variable *name* holds a truthy value.

    The variable's value (or *default* when it is unset) is upper-cased and
    compared against ON, 1, YES, TRUE and Y.
    """
    truthy_values = ['ON', '1', 'YES', 'TRUE', 'Y']
    raw_value = os.getenv(name, default)
    return raw_value.upper() in truthy_values


def extract_execution_cause(lines):
    """Collect the line that follows each 'Execution Cause' marker.

    Args:
      lines: sequence of ``bytes`` lines, e.g. the result of ``readlines()``
        on the debug file opened in binary mode.

    Returns:
      A list of decoded strings, one per marker, in the order encountered.

    A marker found on the very last line has no following line; it is
    skipped instead of raising ``IndexError`` (this can happen when the
    debug file is truncated).
    """
    causes = []
    # Pair every line with its successor so the final line is never indexed
    # past the end of the sequence.
    for marker_line, following_line in zip(lines, lines[1:]):
        if 'Execution Cause' in marker_line.decode():
            causes.append(following_line.decode())
    return causes


def extract_python_frames(lines):
    """Return the text of every Python-frame section found in *lines*.

    A section starts at a line containing 'Python Frame Triggered Execution'
    (that line is included) and ends at the next line containing
    'Execution Analysis: ----------------' (that line is excluded). Each
    terminator line yields one entry, which is empty when no section was
    being recorded. A section that is never terminated is not returned.

    Args:
      lines: sequence of ``bytes`` lines read from the debug file.

    Returns:
      A list of strings, one per terminator line, in file order.
    """
    start_marker = 'Python Frame Triggered Execution'
    end_marker = 'Execution Analysis: ----------------'

    frames = []
    pieces = []
    recording = False
    for raw_line in lines:
        text = raw_line.decode()
        if start_marker in text:
            recording = True
        elif end_marker in text:
            recording = False
            frames.append(''.join(pieces))
            pieces = []
        if recording:
            pieces.append(text)
    return frames


def _mp_fn(index):
    """Per-process test body run by the multiprocessing spawner.

    Every process runs a small XLA computation followed by ``mark_step``.
    The process with ``index == 0`` additionally truncates the debug file
    beforehand and afterwards checks its content.

    Args:
      index: the process index passed in by the spawner.
    """
    assert check_env_flag(
        'PT_XLA_DEBUG'), "This test should be run with PT_XLA_DEBUG"
    debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
    assert debug_file_name, "This test should be run with PT_XLA_DEBUG_FILE"

    is_checker = index == 0
    if is_checker:
        # Start from an empty debug file so only this run's output is parsed.
        with open(debug_file_name, 'w'):
            pass

    device = xm.xla_device()
    source = torch.randn(10, 10, device=device)
    # Kept bound to a name; presumably this gives mark_step a pending
    # computation to execute -- TODO confirm.
    scaled = source * 100
    xm.mark_step()
    xm.wait_device_ops()

    if not is_checker:
        # All processes write to the same PT_XLA_DEBUG_FILE, so there is no
        # need to check it on every process.
        return

    with open(debug_file_name, 'rb') as debug_file:
        lines = debug_file.readlines()
    causes = extract_execution_cause(lines)
    frames = extract_python_frames(lines)

    # Only the local master process should dump the execution analysis.
    assert len(causes) == 1
    assert 'user mark_step' in causes[0]
    # Make sure the frames that spawn the process are skipped.
    assert len(frames) == 1
    assert '....' in frames[0]
    assert '_internal/pjrt.py' not in frames[0]


if __name__ == '__main__':
    # Run _mp_fn through the torch_xla multiprocessing spawner; no extra
    # arguments are forwarded besides the index the spawner supplies.
    xmp.spawn(_mp_fn, args=())
File renamed without changes.
3 changes: 2 additions & 1 deletion test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_grad_checkpoint.py"
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py"
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"
Expand Down Expand Up @@ -232,6 +232,7 @@ function run_mp_op_tests {
run_test "$CDIR/test_mp_save.py"
run_test "$CDIR/test_mp_mesh_reduce.py"
run_test "$CDIR/test_mp_sync_batch_norm.py"
run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
run_xla_backend_mp "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py"
Expand Down
13 changes: 10 additions & 3 deletions torch_xla/csrc/debug_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,18 @@ void DebugUtil::analyze_graph_execution_python_frame(
"mark_step\n";
}

// TODO(JackCaoG): make number of frames printed configurable
ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
for (auto& location : frames) {
ss << debug_output_prefix << " " << location.function << " ("
<< location.file << ":" << location.line << ")\n";
// If the current frame is `__call__` in pjrt.py, the frames below it are
// Python code that spawns the process, which is not very useful to the user.
if (location.function == "__call__" &&
endsWith(location.file, "_internal/pjrt.py")) {
ss << debug_output_prefix << " ..........\n";
break;
} else {
ss << debug_output_prefix << " " << location.function << " ("
<< location.file << ":" << location.line << ")\n";
}
}
ss << debug_output_prefix
<< "----------------------------------------------------------------------"
Expand Down

0 comments on commit 86b0290

Please sign in to comment.