Add graph hash and num input/output to PT_XLA_DEBUG (pytorch#5947)
* Add graph hash and num input/output to PT_XLA_DEBUG

* Remove unnecessary checks

* fix typo

* static const
JackCaoG authored and ManfeiBai committed Dec 1, 2023
1 parent 1ab6895 commit 07e75fd
Showing 4 changed files with 181 additions and 34 deletions.
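As context for the diff below, a minimal sketch of how the new output is exercised (illustrative only; it assumes the process is started with PT_XLA_DEBUG=1 and with PT_XLA_DEBUG_FILE pointing at a writable file, mirroring how the updated test drives it):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
# mark_step triggers one compilation and one execution of the pending graph;
# with PT_XLA_DEBUG enabled, both analyses (now including the graph hash and
# the number of graph inputs/outputs) are appended to PT_XLA_DEBUG_FILE.
xm.mark_step()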
147 changes: 127 additions & 20 deletions test/debug_tool/test_pt_xla_debug.py
@@ -8,6 +8,7 @@
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
import unittest
from typing import NamedTuple


def check_env_flag(name, default=''):
@@ -22,6 +23,34 @@ def extract_execution_cause(lines):
return causes


def extract_compilation_cause(lines):
causes = []
for i in range(len(lines)):
if 'Compilation Cause' in lines[i].decode():
causes.append(lines[i + 1].decode())
return causes


class GraphInfo(NamedTuple):
hash: str
num_input: int
num_output: int


def extract_graph_infos(lines):
infos = []
for i in range(len(lines)):
if 'Graph Info' in lines[i].decode():
hash = lines[i + 1].decode().split('Graph Hash: ')[1].strip()
num_input = lines[i +
2].decode().split('Number of Graph Inputs:')[1].strip()
num_output = lines[i + 3].decode().split(
'Number of Graph Outputs:')[1].strip()
infos.append(GraphInfo(hash, int(num_input), int(num_output)))

return infos


class PtXLADebugTest(unittest.TestCase):

@classmethod
@@ -39,40 +68,88 @@ def test_user_mark_step(self):
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user mark_step', causes[0])
executation_causes = extract_execution_cause(lines)
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), 1)
self.assertIn('user mark_step', executation_causes[0])

self.assertEqual(len(compilation_causes), 1)
self.assertIn('user mark_step', compilation_causes[0])

self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_step_trace(self):
device = xm.xla_device()
with xp.StepTrace('train_pt_xla_debug'):
t1 = torch.randn(2, 2, device=device)
t1 = torch.randn(3, 3, device=device)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
causes[0])

self.assertEqual(len(compilation_causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
compilation_causes[0])

self.assertEqual(len(graph_infos), 2)
# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_dynamo(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
t1 = torch.randn(4, 4, device=device)

def toy_program(t1):
return t1 + t1
return t1 * 100

compiled = torch.compile(toy_program, backend="openxla")
res = compiled(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 4)
self.assertIn('mark_step when dynamo processing input graphs', causes[0])
self.assertIn('mark_step when dynamo processing input graphs', causes[1])
self.assertIn('dynamo is compiling a FX graph to HLO', causes[2])
self.assertIn('dynamo is executing a compiled program', causes[3])
executation_causes = extract_execution_cause(lines)
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), 2)
self.assertIn('mark_step when dynamo processing input graphs',
executation_causes[0])
self.assertIn('dynamo is executing a compiled program',
executation_causes[1])

self.assertEqual(len(compilation_causes), 2)
self.assertIn('mark_step when dynamo processing input graphs',
compilation_causes[0])
self.assertIn('dynamo is compiling a FX graph to HLO',
compilation_causes[1])

# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)

# one graph info from dynamo compilation, one from dynamo execution, hash should match
self.assertEqual(graph_infos[2].hash, graph_infos[3].hash)
# this graph has two inputs (t1, 100) and one output
self.assertEqual(graph_infos[3].num_input, 2)
self.assertEqual(graph_infos[3].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_parallel_loader(self):
@@ -93,25 +170,55 @@ def test_parallel_loader(self):
host_to_device_transfer_threads=1)

for step, (data, target) in enumerate(train_device_loader):
pass
dummy = data + 100

with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), batch_size + 2)
for cause in causes:
executation_causes = extract_execution_cause(lines)
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), batch_size)
for cause in executation_causes:
self.assertIn('mark_step in parallel loader at step end', cause)

# We should only compile once.
self.assertEqual(len(compilation_causes), 1)
self.assertIn('mark_step in parallel loader at step end',
compilation_causes[0])

self.assertEqual(len(graph_infos), batch_size + 1)
# one graph info from compilation, batch size from execution, hash should match
for i in range(batch_size + 1):
self.assertEqual(graph_infos[0].hash, graph_infos[i].hash)
# this graph has two inputs (data, 100) and one output (dummy)
self.assertEqual(graph_infos[i].num_input, 2)
self.assertEqual(graph_infos[i].num_output, 1)
open(self.debug_file_name, 'w').close()

def test_print(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
t1 = torch.randn(5, 5, device=device)
print(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user code trying to access tensor value', causes[0])
executation_causes = extract_execution_cause(lines)
compilation_causes = extract_compilation_cause(lines)
graph_infos = extract_graph_infos(lines)

self.assertEqual(len(executation_causes), 1)
self.assertIn('user code trying to access tensor value',
executation_causes[0])

self.assertEqual(len(compilation_causes), 1)
self.assertIn('user code trying to access tensor value',
compilation_causes[0])

# one graph info from compilation, one from execution, hash should match
self.assertEqual(graph_infos[0].hash, graph_infos[1].hash)
# this graph has one input(random seed) and one output(t1)
self.assertEqual(graph_infos[1].num_input, 1)
self.assertEqual(graph_infos[1].num_output, 1)
open(self.debug_file_name, 'w').close()


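For reference, the block that the helpers above parse looks roughly like this in PT_XLA_DEBUG_FILE (a hand-written illustration based on the debug_util.cpp change below; the hash is made up and the cause/frame lines are abbreviated):

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   user mark_step ...
Compilation Analysis: Graph Info:
Compilation Analysis:   Graph Hash: 3e1b0a6f9c2d4b7a8e5f0c1d2a3b4c5d
Compilation Analysis:   Number of Graph Inputs: 1
Compilation Analysis:   Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis:   mark_step (...)

An analogous "Execution Analysis: ..." block with the same Graph Hash follows once the graph is executed, which is why the tests compare the hashes of consecutive GraphInfo entries.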
39 changes: 33 additions & 6 deletions torch_xla/csrc/debug_util.cpp
@@ -217,12 +217,25 @@ static bool endsWith(const std::string& str, const std::string& suffix) {
}

void DebugUtil::analyze_graph_execution_python_frame(
bool from_dynamo_executation) {
static bool is_master_process =
GraphAnalysisSource source, torch::lazy::hash_t graph_hash,
const xla::ProgramShape* program_shape) {
static const bool pt_xla_debug_enabled =
runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
static const bool is_master_process =
(runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
static std::string debug_file_name =
static const std::string debug_file_name =
runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
static std::string debug_output_prefix = "Execution Analysis: ";

static const std::string executation_output_prefix = "Execution Analysis: ";
static const std::string compilation_output_prefix = "Compilation Analysis: ";

if (!pt_xla_debug_enabled) {
return;
}

std::string debug_output_prefix = (source == GraphAnalysisSource::Compilation)
? compilation_output_prefix
: executation_output_prefix;
// TODO: Make this configurable.
if (!is_master_process) {
return;
@@ -237,8 +250,10 @@ void DebugUtil::analyze_graph_execution_python_frame(
<< "======================================================================"
"=========="
<< "\n";
ss << debug_output_prefix << "Execution Cause\n";
if (from_dynamo_executation) {
ss << debug_output_prefix
<< ((source == GraphAnalysisSource::Compilation) ? "Compilation Cause\n"
: "Execution Cause\n");
if (source == GraphAnalysisSource::DynamoExecution) {
// when execution is from dynamo compiled graph, the python stack will not
// show any dynamo related python file since frame is already replaced. We
// can either analyze the C++ call stack or rely on caller to pass a boolean
@@ -272,6 +287,18 @@
"mark_step\n";
}

ss << debug_output_prefix << "Graph Info: \n";
ss << debug_output_prefix
<< " Graph Hash: " << torch::lazy::HashToString(graph_hash) << "\n";
ss << debug_output_prefix
<< " Number of Graph Inputs: " << program_shape->parameters().size()
<< "\n";
ss << debug_output_prefix << " Number of Graph Outputs: "
<< (program_shape->result().IsTuple()
? program_shape->result().tuple_shapes_size()
: 1)
<< "\n";

ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
for (auto& location : frames) {
// if current frame `__call__` at pjrt.py, below stack will be python
9 changes: 8 additions & 1 deletion torch_xla/csrc/debug_util.h
@@ -19,6 +19,12 @@ class DebugUtil {
kStableHlo,
};

enum GraphAnalysisSource {
Compilation,
Execution,
DynamoExecution,
};

static GraphFormat GetDefaultGraphFormat();

// Return HLO/StableHLO graph of the index selected tensors in string format.
@@ -50,7 +56,8 @@ class DebugUtil {
// warning, this function should only be called when a graph execution is
// about to happen.
static void analyze_graph_execution_python_frame(
bool from_dynamo_executation = false);
GraphAnalysisSource source, torch::lazy::hash_t graph_hash = 0,
const xla::ProgramShape* program_shape = nullptr);
};

} // namespace torch_xla
20 changes: 13 additions & 7 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -611,10 +611,6 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier",
tsl::profiler::TraceMeLevel::kInfo);
MaybeDumpGraph("dynamo", hash);
if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
DebugUtil::analyze_graph_execution_python_frame(
/*from_dynamo_executation=*/true);
}
auto cachedComputation =
XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
@@ -627,6 +623,11 @@
<< ". Maybe the entry get "
"kicked out of the LRU cache";

DebugUtil::analyze_graph_execution_python_frame(
DebugUtil::GraphAnalysisSource::DynamoExecution,
/*graph_hash=*/hash,
/*program_shape=*/&(cachedComputation->computation->program_shape()));

// Create DataPlaceHolder that will get filled in async executions.
std::vector<xla::Shape>* output_shapes =
DeviceContextArena::Get()->GetOutputShapesByHash(hash);
@@ -956,6 +957,10 @@
std::vector<torch::lazy::BackendDataPtr> tensors_data,
std::vector<XLATensor::ShardingSpecPtr> sharding_specs,
ComputationCache::TypePtr cached_computation) {
DebugUtil::analyze_graph_execution_python_frame(
DebugUtil::GraphAnalysisSource::Execution,
/*graph_hash=*/coll->hash,
/*program_shape=*/&(cached_computation->computation->program_shape()));
tsl::profiler::TraceMe activity("ScheduleSyncTensorsGraph",
tsl::profiler::TraceMeLevel::kInfo);
TensorCollectionBarrier(coll);
@@ -1261,6 +1266,10 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
coll.device.toString(), devices),
&shape, should_wrap_parameter, is_sharded});

DebugUtil::analyze_graph_execution_python_frame(
DebugUtil::GraphAnalysisSource::Compilation,
/*graph_hash=*/coll.hash, /*program_shape=*/&program_shape);

TF_VLOG(3) << "Compiling IR graph hash "
<< torch::lazy::HashToString(coll.hash) << " on device "
<< coll.device << " ...";
@@ -1300,9 +1309,6 @@
const SyncTensorsConfig& config, bool warm_up_cache_only) {
tsl::profiler::TraceMe activity("SyncTensorsGraphInternal",
tsl::profiler::TraceMeLevel::kInfo);
if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
DebugUtil::analyze_graph_execution_python_frame();
}
SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
if (coll.indices.empty()) {
// Ensure previous execution is complete before exiting this
