Skip to content

Commit

Permalink
[LoweringContext] SPMD propagation (#8471)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws authored Dec 10, 2024
1 parent 1812817 commit b068cab
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 88 deletions.
21 changes: 7 additions & 14 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ MAX_GRAPH_SIZE=500
GRAPH_CHECK_FREQUENCY=100
VERBOSITY=2

# Utils file
source "${CDIR}/utils/run_tests_utils.sh"

# Note [Keep Going]
#
# Set the `CONTINUE_ON_ERROR` flag to `true` to make the CI tests continue on error.
Expand Down Expand Up @@ -112,16 +115,6 @@ function run_eager_debug {
XLA_USE_EAGER_DEBUG_MODE=1 run_test "$@"
}

function run_save_tensor_ir {
echo "Running in save tensor file mode: $@"
XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="text" run_test "$@"
}

function run_save_tensor_hlo {
echo "Running in save tensor file mode: $@"
XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
}

function run_pt_xla_debug {
echo "Running in save tensor file mode: $@"
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
Expand Down Expand Up @@ -193,16 +186,16 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_num_output.py"
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_test "$CDIR/dynamo/test_dynamo_config.py"
run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_save_tensor_ir run_test "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_test "$CDIR/test_data_type.py"
run_use_bf16 "$CDIR/test_data_type.py"
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/test_fp8.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
run_save_tensor_ir "$CDIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_hlo "$CDIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_ir run_test "$CDIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_graph_dump.py"
}

function run_xla_op_tests2 {
Expand Down Expand Up @@ -248,7 +241,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_save_tensor_hlo run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
Expand Down
9 changes: 3 additions & 6 deletions test/spmd/test_spmd_graph_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def setUpClass(cls):
def test_dump_with_output_sharding(self):
save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
if not save_file:
assert False, "This test should be run with XLA_SAVE_TENSORS_FILE"
assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
should_dump_output_sharding = (save_format == 'hlo')
save_file += '.0'
device = xm.xla_device()
Expand All @@ -35,12 +34,10 @@ def test_dump_with_output_sharding(self):
xla_sharded_x = xs.mark_sharding(xla_x, self._get_mesh((1, self.n_devices)),
partition_spec)
xla_res = xla_x + xla_y
xm.mark_step()
with open(save_file, 'rb') as f:
current_line = sum(1 for line in f)
with open(save_file, 'rb') as f:
xm.mark_step()
lines = f.readlines()
self.assertGreater(len(lines), current_line)
self.assertGreater(len(lines), 0)
if should_dump_output_sharding:
self.assertIn('OUTPUT_SHARDING_END', str(lines[-2]))
else:
Expand Down
88 changes: 86 additions & 2 deletions test/spmd/test_spmd_lowering_context.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
import re
import sys
from pathlib import Path

import unittest

import test_xla_sharding_base

import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.core.xla_model as xm
import contextlib


class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
Expand All @@ -18,10 +21,91 @@ class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):
def setUpClass(cls):
super().setUpClass()

def _get_computation_hlo_txt(self, ctx):
  # Wrap the context's HLO module proto in a computation named
  # "my_custom_comp" and hand back the HLO text that xla_builder produces
  # for that computation.
  return xb.get_computation_hlo(
      xb.computation_from_module_proto("my_custom_comp", ctx.hlo()))

def test_basic(self):
# Builds a LoweringContext over the results of a small function applied to
# two tensors marked with SPMD sharding, then checks the context's HLO, the
# sharding specs of the results, the graph dump file and execution counters.
#
# Preconditions (asserted below): XLA_SAVE_TENSORS_FILE is set and
# XLA_SAVE_TENSORS_FMT is 'hlo'.
# NOTE(review): the bare `assert` statements here are skipped when Python
# runs with -O; the unittest `self.assert*` helpers are not.
save_file = os.getenv('XLA_SAVE_TENSORS_FILE')
save_format = os.getenv('XLA_SAVE_TENSORS_FMT')
assert save_file, "This test should be run with XLA_SAVE_TENSORS_FILE"
save_file += '.0'  # Identify a single device
assert save_format == 'hlo', "This test should be run with XLA_SAVE_TENSORS_FMT=hlo"

# 2-D mesh named ('x', 'y'); with n_devices == 1 both axes evaluate to 1.
model_axis = max(1, self.n_devices // 2)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

# `a` is marked with partition spec ('x',), `b` with (None, 'y').
device = xm.xla_device()
a = torch.zeros(2048, device=device, requires_grad=True)
xs.mark_sharding(a, spmd_mesh, ('x',))
b = torch.randn([32, 2048], device=device, requires_grad=True)
xs.mark_sharding(b, spmd_mesh, (None, 'y'))

def fn(x, y):
x = x + 1
return x, y * 2

# The context is built from the still-pending results before
# torch_xla.sync() is called.
result = fn(a, b)
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build(list(result))
torch_xla.sync()

# Sanity HLO check.
hlo_text = ctx.hlo_text()
self.assertIn('MyCustomName', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "add"', hlo_text)
self.assertIn('sharding', hlo_text)

# Ensure that the corresponding input parameters contain the expected sharding.
# NOTE(review): in both patterns below the shape is not regex-escaped. After
# f-string formatting the pattern contains `f32[2048]{0}` (resp.
# `f32[32,2048]{0}`), which `re` parses as a character class repeated zero
# times, so only the literal `f32` is actually required before `.*sharding=`.
# Escaping the brackets/braces would tighten the check, but the exact layout
# suffix printed in the HLO text should be confirmed first.
hlo_comp_txt = self._get_computation_hlo_txt(ctx)
a_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(a)
self.assertRegex(
hlo_comp_txt,
rf'%custom-call.*.*f32[2048]{{0}}.*sharding={re.escape(a_sharding_spec)}'
)
b_sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(b)
self.assertRegex(
hlo_comp_txt,
rf'%custom-call.*f32[32,2048]{{0}}.*sharding={re.escape(b_sharding_spec)}'
)

# Ensure that the results retain the same sharding specs.
result_a, result_b = result
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(result_a), a_sharding_spec)
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(result_b), b_sharding_spec)

# The dump file must contain exactly one END_GRAPH marker, so this check
# depends on the file not holding output from an earlier run.
hlo_content = Path(save_file).read_text()
assert len(re.findall('END_GRAPH',
hlo_content)) == 1, "There is a single graph"

# Extract the content between OUTPUT_SHARDING_BEGIN and OUTPUT_SHARDING_END
pattern = r'#OUTPUT_SHARDING_BEGIN\n(.*?)\n#OUTPUT_SHARDING_END'
match = re.search(pattern, hlo_content, re.DOTALL)
assert match is not None, "#OUTPUT_SHARDING not found in the file"
assert len(match.groups()
) == 1, f"Expected 1 group, but found {len(match.groups())}"
expected_output = match.group(1).strip().split('\n')

# Assert that the output sharding match our expectation.
# Four entries are expected, alternating a's and b's shape and sharding spec.
assert len(expected_output
) == 4, f"Expected 4 lines, but found {len(expected_output)}"
assert expected_output[0] == f"f32[2048] {a_sharding_spec}"
assert expected_output[1] == f"f32[32,2048] {b_sharding_spec}"
assert expected_output[2] == f"f32[2048] {a_sharding_spec}"
assert expected_output[3] == f"f32[32,2048] {b_sharding_spec}"
# NOTE(review): counters are not cleared inside this test, so these exact
# values presumably rely on no earlier execution in the same process —
# confirm against test ordering.
self.assertTrue(met.counter_value("ExecuteReplicated") == 1)
self.assertTrue(met.counter_value("ExecuteComputation") is None)

def test_device_parameter_id_tensor_mapping(self):
met.clear_all()

model_axis = min(8, self.n_devices)
model_axis = max(1, self.n_devices // 2)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))
Expand Down
102 changes: 53 additions & 49 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
@@ -1,66 +1,70 @@
#!/bin/bash
set -xue
CDIR="$(cd "$(dirname "$0")" ; pwd -P)"
TEST_CDIR="$(dirname "$CDIR")"

source "${TEST_CDIR}/utils/run_tests_utils.sh"

# TODO: merge with other run_tests
python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_mp_input_sharding.py
python3 test/spmd/test_spmd_lowering_context.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_xla_virtual_device.py
python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
python3 test/spmd/test_fsdp_v2.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
python3 test/test_autocast.py
python3 test/test_fp8.py
python3 test/test_grad_checkpoint.py
python3 test/dynamo/test_dynamo.py
python3 test/dynamo/test_dynamo_dynamic_shape.py
python3 test/spmd/test_spmd_debugging.py
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrapping.py
python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_while_loop.py
python3 test/scan/test_scan.py
python3 test/scan/test_scan_spmd.py
python3 test/scan/test_scan_layers.py
python3 test/test_pallas.py -v
python3 test/test_pallas_spmd.py
python3 test/test_tpu_paged_attention_kernel.py
python3 test/test_input_output_aliases.py
python3 test/test_gmm.py
python3 test/eager/test_eager_spmd.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py
python3 test/quantized_ops/test_dot_general.py
python3 "$TEST_CDIR/test_operations.py" -v
python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py"
python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py"
python3 "$TEST_CDIR/spmd/test_mp_input_sharding.py"
run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py"
python3 "$TEST_CDIR/spmd/test_xla_sharding.py"
python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py"
python3 "$TEST_CDIR/spmd/test_xla_distributed_checkpoint.py"
python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py"
python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py"
python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py"
python3 "$TEST_CDIR/spmd/test_fsdp_v2.py"
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v
python3 "$TEST_CDIR/test_autocast.py"
python3 "$TEST_CDIR/test_fp8.py"
python3 "$TEST_CDIR/test_grad_checkpoint.py"
python3 "$TEST_CDIR/dynamo/test_dynamo.py"
python3 "$TEST_CDIR/dynamo/test_dynamo_dynamic_shape.py"
python3 "$TEST_CDIR/spmd/test_spmd_debugging.py"
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python3 "$TEST_CDIR/spmd/test_spmd_parameter_wrapping.py"
python3 "$TEST_CDIR/pjrt/test_dtypes.py"
python3 "$TEST_CDIR/pjrt/test_dynamic_plugin_tpu.py"
python3 "$TEST_CDIR/test_while_loop.py"
python3 "$TEST_CDIR/scan/test_scan.py"
python3 "$TEST_CDIR/scan/test_scan_spmd.py"
python3 "$TEST_CDIR/scan/test_scan_layers.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
python3 "$TEST_CDIR/test_input_output_aliases.py"
python3 "$TEST_CDIR/test_gmm.py"
python3 "$TEST_CDIR/eager/test_eager_spmd.py"
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py"
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py"
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py"
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py"
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py"

# Run examples; each test should take <2 minutes
python3 examples/data_parallel/train_resnet_spmd_data_parallel.py
python3 examples/fsdp/train_decoder_only_fsdp_v2.py
python3 examples/train_resnet_amp.py
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py"
python3 "$TEST_CDIR/../examples/fsdp/train_decoder_only_fsdp_v2.py"
python3 "$TEST_CDIR/../examples/train_resnet_amp.py"

# HACK: don't confuse local `torch_xla` folder with installed package
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559
# Eager tests will take more HBM, only run them on TPU v4 CI
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())")
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
python3 test/dynamo/test_traceable_collectives.py
python3 examples/data_parallel/train_resnet_xla_ddp.py
python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py
python3 examples/eager/train_decoder_only_eager.py
python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py
python3 examples/eager/train_decoder_only_eager_with_compile.py
python3 examples/eager/train_decoder_only_eager_multi_process.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 "$TEST_CDIR/dynamo/test_traceable_collectives.py"
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_xla_ddp.py"
python3 "$TEST_CDIR/../examples/fsdp/train_resnet_fsdp_auto_wrap.py"
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager.py"
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_spmd_data_parallel.py"
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_with_compile.py"
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_multi_process.py"
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shapes.py" -v
fi

if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then
# Test `tpu-info` CLI compatibility
python3 test/tpu/tpu_info/test_cli.py
python3 "$CDIR/tpu_info/test_cli.py"
fi
56 changes: 56 additions & 0 deletions test/utils/run_tests_utils.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/bin/bash
set -exo pipefail

# Run a test with tensor saving enabled, using a specified graph format, and
# remove the graph dump files afterwards.
#
# Usage: run_save_tensor <exec> <format> [test arguments...]
#
# Arguments:
#   exec: The executable or function to run the test (python3 or any function)
#   format: The graph format to use with XLA_SAVE_TENSORS_FMT
#   test arguments: Arguments to pass to the test
#
# Environment:
#   Sets XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT for the test command.
#
# Returns:
#   2 if fewer than two arguments are given, 1 if a dump file already exists,
#   otherwise the exit status of the test command.
#
# Cleanup:
#   Every file matching "${base_file}*" (one per device ordinal) is removed
#   once the test command returns. If the calling shell runs with `set -e` and
#   this function is not called from a conditional context, a failing test
#   aborts the shell before cleanup, so the dump files are retained.
function run_save_tensor {
  if (( $# < 2 )); then
    echo "Usage: run_save_tensor <exec> <format> [test arguments...]" >&2
    return 2
  fi
  local run_test_func="$1"
  local file_graph_format="$2"
  shift 2

  echo "Running in save tensor file mode: $*"
  local base_file="/tmp/xla_test_save_ir.txt"

  # Check if the file already exists, for any device ordinal number.
  # compgen -G succeeds only when the glob matches at least one path.
  if compgen -G "${base_file}*" > /dev/null; then
    echo "Error: File ${base_file} or a numbered version already exists. Please remove it before running the test." >&2
    return 1
  fi

  # Deliberately not wrapped in `if`/`||`: that would disable errexit inside
  # the test function being invoked.
  XLA_SAVE_TENSORS_FILE="$base_file" XLA_SAVE_TENSORS_FMT="$file_graph_format" "$run_test_func" "$@"
  local test_status=$?

  # Clean up every dump file once the test finalizes, so the existence check
  # above does not trip on leftovers from this run.
  local dump_file
  local removed=0
  for dump_file in "${base_file}"*; do
    # An unmatched glob is left as the literal pattern; skip non-files.
    if [[ -f "$dump_file" ]]; then
      echo "Cleaning up temporary file: $dump_file"
      rm -f -- "$dump_file"
      removed=1
    fi
  done
  if (( removed == 0 )); then
    echo "Warning: Expected output file not found" >&2
  fi
  return "$test_status"
}

# Run a test with tensor saving enabled and XLA_SAVE_TENSORS_FMT="text".
#
# Usage: run_save_tensor_ir <exec> [test arguments...]
#
# Returns 2 when <exec> is missing, otherwise run_save_tensor's status.
function run_save_tensor_ir {
  if (( $# < 1 )); then
    echo "Usage: run_save_tensor_ir <exec> [test arguments...]" >&2
    return 2
  fi
  local run_test_func="$1"
  shift
  # No banner here: run_save_tensor prints "Running in save tensor file mode"
  # itself, so echoing it in this wrapper would log it twice.
  run_save_tensor "$run_test_func" "text" "$@"
}

# Run a test with tensor saving enabled and XLA_SAVE_TENSORS_FMT="hlo".
#
# Usage: run_save_tensor_hlo <exec> [test arguments...]
#
# Returns 2 when <exec> is missing, otherwise run_save_tensor's status.
function run_save_tensor_hlo {
  if (( $# < 1 )); then
    echo "Usage: run_save_tensor_hlo <exec> [test arguments...]" >&2
    return 2
  fi
  local run_test_func="$1"
  shift
  # No banner here: run_save_tensor prints "Running in save tensor file mode"
  # itself, so echoing it in this wrapper would log it twice.
  run_save_tensor "$run_test_func" "hlo" "$@"
}
Loading

0 comments on commit b068cab

Please sign in to comment.