Skip to content

Commit

Permalink
Improve IR dump graph test cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws committed Dec 9, 2024
1 parent b18ba99 commit 9a98a36
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
43 changes: 41 additions & 2 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,53 @@ function run_eager_debug {
XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
}

# Run a test with tensor saving enabled, using a specified graph format, and
# remove the graph dump files once the test has returned.
#
# The dump path is a fixed base name; files are matched as "<base>*" so that
# per-device-ordinal suffixes (e.g. ".0", ".1") are covered both by the
# pre-flight check and by the cleanup. If the shell aborts before the cleanup
# step is reached (for example when the caller runs with errexit and the test
# fails), the dump files are left in place for inspection.
#
# Usage: run_save_tensor <format> [test arguments...]
#
# Arguments:
#   format: The graph format to use with XLA_SAVE_TENSORS_FMT
#   test arguments: Arguments to pass to the test
#
# Environment:
#   Sets XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT for the run_test call.
#
# Returns:
#   1 if a dump file already exists before the test starts (the test is not
#   run); otherwise the exit status of run_test.
function run_save_tensor {
  local file_graph_format="$1"
  shift

  # "$*" rather than "$@": the arguments are being joined into one message.
  echo "Running in save tensor file mode: $*"
  local base_file="/tmp/xla_test_save_ir.txt"
  local dump_file

  # Check if the file already exists, for any device ordinal number. A glob
  # loop is used instead of parsing `ls`, so an unmatched pattern (with or
  # without nullglob) is simply treated as "nothing there".
  for dump_file in "${base_file}"*; do
    if [[ -e "$dump_file" || -L "$dump_file" ]]; then
      echo "Error: File ${base_file} or a numbered version already exists. Please remove it before running the test." >&2
      return 1
    fi
  done

  XLA_SAVE_TENSORS_FILE="$base_file" XLA_SAVE_TENSORS_FMT="$file_graph_format" run_test "$@"
  local test_status=$?

  # Clean up every dump file once the test finalizes. Removing only the first
  # match would leave the other device ordinals behind and make the next
  # invocation fail the pre-flight check above.
  local removed_any=0
  for dump_file in "${base_file}"*; do
    if [[ -f "$dump_file" ]]; then
      echo "Cleaning up temporary file: $dump_file"
      rm -- "$dump_file"
      removed_any=1
    fi
  done
  if [[ "$removed_any" -eq 0 ]]; then
    echo "Warning: Expected output file not found" >&2
  fi

  return "$test_status"
}

# Run a test with tensor saving enabled, dumping graphs in the "text" format.
#
# Usage: run_save_tensor_ir [test arguments...]
#
# Delegates entirely to run_save_tensor, which prints the mode banner, sets
# the XLA_SAVE_TENSORS_* variables and invokes the test. The test must not
# also be launched directly from here: doing so would execute it twice and
# leave a dump file behind before run_save_tensor performs its own
# "file already exists" check.
#
# Returns: the exit status of run_save_tensor.
function run_save_tensor_ir {
  run_save_tensor "text" "$@"
}

# Run a test with tensor saving enabled, dumping graphs in the "hlo" format.
#
# Usage: run_save_tensor_hlo [test arguments...]
#
# Delegates entirely to run_save_tensor, which prints the mode banner, sets
# the XLA_SAVE_TENSORS_* variables and invokes the test. The test must not
# also be launched directly from here: doing so would execute it twice and
# leave a dump file behind before run_save_tensor performs its own
# "file already exists" check.
#
# Returns: the exit status of run_save_tensor.
function run_save_tensor_hlo {
  run_save_tensor "hlo" "$@"
}

function run_pt_xla_debug {
Expand Down
10 changes: 0 additions & 10 deletions test/spmd/test_spmd_lowering_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ def test_basic(self):
save_file += '.0' # Identify a single device
assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"

# Ensure that there is no existing file to begin with.
try:
os.remove(save_file)
except:
pass

model_axis = min(8, self.n_devices)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
Expand Down Expand Up @@ -108,10 +102,6 @@ def fn(x, y):
self.assertTrue(met.counter_value("ExecuteReplicated") == 1)
self.assertTrue(met.counter_value("ExecuteComputation") is None)

# Remove the file once the test is complete.
# TODO(rpsilva-aws): Add a proper cleanup wrapper to avoid lingering files.
os.remove(save_file)

def test_device_parameter_id_tensor_mapping(self):
met.clear_all()

Expand Down

0 comments on commit 9a98a36

Please sign in to comment.