-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[LoweringContext] SPMD propagation (#8471)
- Loading branch information
1 parent
1812817
commit b068cab
Showing
8 changed files
with
239 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,70 @@ | ||
#!/bin/bash | ||
set -xue | ||
CDIR="$(cd "$(dirname "$0")" ; pwd -P)" | ||
TEST_CDIR="$(dirname "$CDIR")" | ||
|
||
source "${TEST_CDIR}/utils/run_tests_utils.sh" | ||
|
||
# TODO: merge with other run_tests | ||
python3 test/test_operations.py -v | ||
python3 test/pjrt/test_runtime_tpu.py | ||
python3 test/pjrt/test_collective_ops_tpu.py | ||
python3 test/spmd/test_mp_input_sharding.py | ||
python3 test/spmd/test_spmd_lowering_context.py | ||
python3 test/spmd/test_xla_sharding.py | ||
python3 test/spmd/test_xla_virtual_device.py | ||
python3 test/spmd/test_xla_distributed_checkpoint.py | ||
python3 test/spmd/test_train_spmd_linear_model.py | ||
python3 test/spmd/test_xla_spmd_python_api_interaction.py | ||
python3 test/spmd/test_xla_auto_sharding.py | ||
python3 test/spmd/test_fsdp_v2.py | ||
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v | ||
python3 test/test_autocast.py | ||
python3 test/test_fp8.py | ||
python3 test/test_grad_checkpoint.py | ||
python3 test/dynamo/test_dynamo.py | ||
python3 test/dynamo/test_dynamo_dynamic_shape.py | ||
python3 test/spmd/test_spmd_debugging.py | ||
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python test/spmd/test_spmd_parameter_wrapping.py | ||
python3 test/pjrt/test_dtypes.py | ||
python3 test/pjrt/test_dynamic_plugin_tpu.py | ||
python3 test/test_while_loop.py | ||
python3 test/scan/test_scan.py | ||
python3 test/scan/test_scan_spmd.py | ||
python3 test/scan/test_scan_layers.py | ||
python3 test/test_pallas.py -v | ||
python3 test/test_pallas_spmd.py | ||
python3 test/test_tpu_paged_attention_kernel.py | ||
python3 test/test_input_output_aliases.py | ||
python3 test/test_gmm.py | ||
python3 test/eager/test_eager_spmd.py | ||
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py | ||
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py | ||
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py | ||
python3 test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py | ||
python3 test/quantized_ops/test_dot_general.py | ||
python3 "$TEST_CDIR/test_operations.py" -v | ||
python3 "$TEST_CDIR/pjrt/test_runtime_tpu.py" | ||
python3 "$TEST_CDIR/pjrt/test_collective_ops_tpu.py" | ||
python3 "$TEST_CDIR/spmd/test_mp_input_sharding.py" | ||
run_save_tensor_hlo python3 "$TEST_CDIR/spmd/test_spmd_lowering_context.py" | ||
python3 "$TEST_CDIR/spmd/test_xla_sharding.py" | ||
python3 "$TEST_CDIR/spmd/test_xla_virtual_device.py" | ||
python3 "$TEST_CDIR/spmd/test_xla_distributed_checkpoint.py" | ||
python3 "$TEST_CDIR/spmd/test_train_spmd_linear_model.py" | ||
python3 "$TEST_CDIR/spmd/test_xla_spmd_python_api_interaction.py" | ||
python3 "$TEST_CDIR/spmd/test_xla_auto_sharding.py" | ||
python3 "$TEST_CDIR/spmd/test_fsdp_v2.py" | ||
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shape_models.py" -v | ||
python3 "$TEST_CDIR/test_autocast.py" | ||
python3 "$TEST_CDIR/test_fp8.py" | ||
python3 "$TEST_CDIR/test_grad_checkpoint.py" | ||
python3 "$TEST_CDIR/dynamo/test_dynamo.py" | ||
python3 "$TEST_CDIR/dynamo/test_dynamo_dynamic_shape.py" | ||
python3 "$TEST_CDIR/spmd/test_spmd_debugging.py" | ||
XLA_PARAMETER_WRAPPING_THREADSHOLD=1 python3 "$TEST_CDIR/spmd/test_spmd_parameter_wrapping.py" | ||
python3 "$TEST_CDIR/pjrt/test_dtypes.py" | ||
python3 "$TEST_CDIR/pjrt/test_dynamic_plugin_tpu.py" | ||
python3 "$TEST_CDIR/test_while_loop.py" | ||
python3 "$TEST_CDIR/scan/test_scan.py" | ||
python3 "$TEST_CDIR/scan/test_scan_spmd.py" | ||
python3 "$TEST_CDIR/scan/test_scan_layers.py" | ||
python3 "$TEST_CDIR/test_pallas.py" -v | ||
python3 "$TEST_CDIR/test_pallas_spmd.py" | ||
python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py" | ||
python3 "$TEST_CDIR/test_input_output_aliases.py" | ||
python3 "$TEST_CDIR/test_gmm.py" | ||
python3 "$TEST_CDIR/eager/test_eager_spmd.py" | ||
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_gather_xla_backend.py" | ||
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py" | ||
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py" | ||
python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py" | ||
python3 "$TEST_CDIR/quantized_ops/test_dot_general.py" | ||
|
||
# run examples, each test should takes <2 minutes | ||
python3 examples/data_parallel/train_resnet_spmd_data_parallel.py | ||
python3 examples/fsdp/train_decoder_only_fsdp_v2.py | ||
python3 examples/train_resnet_amp.py | ||
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py" | ||
python3 "$TEST_CDIR/../examples/fsdp/train_decoder_only_fsdp_v2.py" | ||
python3 "$TEST_CDIR/../examples/train_resnet_amp.py" | ||
|
||
# HACK: don't confuse local `torch_xla` folder with installed package | ||
# Python 3.11 has the permanent fix: https://stackoverflow.com/a/73636559 | ||
# Egaer tests will take more HBM, only run them on TPU v4 CI | ||
TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; print(torch_xla._internal.tpu.version())") | ||
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then | ||
python3 test/dynamo/test_traceable_collectives.py | ||
python3 examples/data_parallel/train_resnet_xla_ddp.py | ||
python3 examples/fsdp/train_resnet_fsdp_auto_wrap.py | ||
python3 examples/eager/train_decoder_only_eager.py | ||
python3 examples/eager/train_decoder_only_eager_spmd_data_parallel.py | ||
python3 examples/eager/train_decoder_only_eager_with_compile.py | ||
python3 examples/eager/train_decoder_only_eager_multi_process.py | ||
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v | ||
python3 "$TEST_CDIR/dynamo/test_traceable_collectives.py" | ||
python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_xla_ddp.py" | ||
python3 "$TEST_CDIR/../examples/fsdp/train_resnet_fsdp_auto_wrap.py" | ||
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager.py" | ||
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_spmd_data_parallel.py" | ||
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_with_compile.py" | ||
python3 "$TEST_CDIR/../examples/eager/train_decoder_only_eager_multi_process.py" | ||
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 "$TEST_CDIR/ds/test_dynamic_shapes.py" -v | ||
fi | ||
|
||
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" != "6" ]]; then | ||
# Test `tpu-info` CLI compatibility | ||
python3 test/tpu/tpu_info/test_cli.py | ||
python3 "$CDIR/tpu_info/test_cli.py" | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/bin/bash | ||
set -exo pipefail | ||
|
||
# Run a test with tensor saving enabled, using a specified graph format. The | ||
# graph dump files are cleaned after the test. In case the test crashes, the | ||
# file is retained. | ||
# | ||
# Usage: run_save_tensor <exec> <format> [test arguments...] | ||
# | ||
# Arguments: | ||
# exec: The executable or function to run the test (python3 or any function) | ||
# format: The graph format to use with XLA_SAVE_TENSORS_FMT | ||
# test arguments: Arguments to pass to the test | ||
# | ||
# Environment: | ||
# Sets XLA_SAVE_TENSORS_FILE and XLA_SAVE_TENSORS_FMT | ||
function run_save_tensor {
  local run_test_func="$1"
  local file_graph_format="$2"
  shift 2

  echo "Running in save tensor file mode: $@"
  local base_file="/tmp/xla_test_save_ir.txt"

  # Fail fast if a dump from a previous (possibly crashed) run is still
  # around. The runtime appends a device ordinal to the file name, so match
  # any suffix of the base path.
  if ls "${base_file}"* 1> /dev/null 2>&1; then
    echo "Error: File ${base_file} or a numbered version already exists. Please remove it before running the test."
    return 1
  fi

  # Run the test with tensor dumping enabled. Capture the exit status via
  # `||` so a failing test does not abort the surrounding script when
  # `set -e` is in effect (the file enables it); the status is still
  # propagated to the caller below.
  local test_status=0
  XLA_SAVE_TENSORS_FILE="$base_file" XLA_SAVE_TENSORS_FMT="$file_graph_format" \
    "$run_test_func" "$@" || test_status=$?

  # Clean up every dump file the test produced (one per device ordinal),
  # not just the first match — leftovers would trip the pre-check above on
  # the next invocation.
  local dump_file found=0
  for dump_file in "${base_file}"*; do
    [ -f "$dump_file" ] || continue
    found=1
    echo "Cleaning up temporary file: $dump_file"
    rm -f -- "$dump_file"
  done
  if [ "$found" -eq 0 ]; then
    echo "Warning: Expected output file not found"
  fi
  return $test_status
}
|
||
# Convenience wrapper: run a test with IR-text graph dumping enabled.
# Usage: run_save_tensor_ir <exec> [test arguments...]
run_save_tensor_ir() {
  local runner="$1"
  shift
  echo "Running in save tensor file mode: $@"
  run_save_tensor "$runner" "text" "$@"
}
|
||
# Convenience wrapper: run a test with HLO graph dumping enabled.
# Usage: run_save_tensor_hlo <exec> [test arguments...]
run_save_tensor_hlo() {
  local runner="$1"
  shift
  echo "Running in save tensor file mode: $@"
  run_save_tensor "$runner" "hlo" "$@"
}
Oops, something went wrong.