diff --git a/.circleci/common.sh b/.circleci/common.sh index 22a7fcb692a..7dfb771b798 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -137,7 +137,7 @@ function run_torch_xla_python_tests() { # single-host-multi-process num_devices=$(nvidia-smi --list-gpus | wc -l) - test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18 + python3 test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18 torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18 # single-host-SPMD