Re-enable AMP test on GPU CI. (#6790)
vanbasten23 authored Mar 29, 2024
1 parent 8f095fc commit 5f858e7
Showing 4 changed files with 49 additions and 38 deletions.
37 changes: 0 additions & 37 deletions .circleci/common.sh
@@ -132,43 +132,6 @@ function run_torch_xla_python_tests() {
    chmod -R 755 ~/htmlcov
  else
    ./test/run_tests.sh

    # CUDA tests
    if [ -x "$(command -v nvidia-smi)" ]; then
      # single-host-single-process
      PJRT_DEVICE=CUDA python3 test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_cores=1 --num_steps=25 --model=resnet18
      PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

      # single-host-multi-process
      num_devices=$(nvidia-smi --list-gpus | wc -l)
      PJRT_DEVICE=CUDA GPU_NUM_DEVICES=$GPU_NUM_DEVICES python3 test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18
      PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

      # single-host-SPMD
      # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
      XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18

      # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
      PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
      # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
      PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
      XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
      # Syncfree SGD optimizer tests
      if [ -d ./torch_xla/amp/syncfree ]; then
        echo "Running Syncfree Optimizer Test"
        PJRT_DEVICE=CUDA python test/test_syncfree_optimizers.py

        # Following test scripts are mainly useful for
        # performance evaluation & comparison among different
        # amp optimizers.
        # echo "Running ImageNet Test"
        # python test/test_train_mp_imagenet_amp.py --fake_data --num_epochs=1

        # disabled per https://github.com/pytorch/xla/pull/2809
        # echo "Running MNIST Test"
        # python test/test_train_mp_mnist_amp.py --fake_data --num_epochs=1
      fi
    fi
  fi
  popd
}
37 changes: 37 additions & 0 deletions test/run_tests.sh
@@ -240,6 +240,43 @@ function run_xla_op_tests3 {
  # NOTE: this line below is testing export and don't care about GPU
  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
  run_test "$CDIR/test_pallas.py"

  # CUDA tests
  if [ -x "$(command -v nvidia-smi)" ]; then
    # Please keep PJRT_DEVICE and GPU_NUM_DEVICES explicit in the following test commands.
    echo "single-host-single-process"
    PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1 python3 test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_cores=1 --num_steps=25 --model=resnet18
    PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

    echo "single-host-multi-process"
    num_devices=$(nvidia-smi --list-gpus | wc -l)
    PJRT_DEVICE=CUDA GPU_NUM_DEVICES=$num_devices python3 test/test_train_mp_imagenet.py --fake_data --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18
    PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

    echo "single-host-SPMD"
    # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
    XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18

    # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
    PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
    # TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
    PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
    XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
    # Syncfree SGD optimizer tests
    if [ -d ./torch_xla/amp/syncfree ]; then
      echo "Running Syncfree Optimizer Test"
      PJRT_DEVICE=CUDA python test/test_syncfree_optimizers.py

      # Following test scripts are mainly useful for
      # performance evaluation & comparison among different
      # amp optimizers.
      echo "Running ImageNet Test"
      PJRT_DEVICE=CUDA GPU_NUM_DEVICES=$num_devices python test/test_train_mp_imagenet_amp.py --fake_data --num_epochs=1 --batch_size 64 --num_steps=25 --model=resnet18

      echo "Running MNIST Test"
      PJRT_DEVICE=CUDA GPU_NUM_DEVICES=$num_devices python test/test_train_mp_mnist_amp.py --fake_data --num_epochs=1 --batch_size 64 --num_steps=25
    fi
  fi
}

#######################################################################################
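For context, the re-enabled "Running ImageNet Test" and "Running MNIST Test" commands above exercise torch_xla's automatic mixed precision path together with the sync-free optimizers on XLA:GPU, with --fake_data and --num_steps=25 keeping each run short. The snippet below is a minimal, single-process sketch of that pattern, not the CI scripts' actual code; the model, data, and hyperparameters are placeholders, and it assumes torch_xla.amp.autocast, torch_xla.amp.GradScaler, and torch_xla.amp.syncfree behave as in current torch_xla releases.

```python
# Illustrative sketch only (not the CI scripts' code): a single-process
# torch_xla AMP step on an XLA device, using the sync-free SGD optimizer
# that test_syncfree_optimizers.py and the AMP tests above exercise.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.amp.syncfree as syncfree
from torch_xla.amp import autocast, GradScaler

device = xm.xla_device()
model = nn.Linear(128, 10).to(device)                  # placeholder model
optimizer = syncfree.SGD(model.parameters(), lr=0.01)  # sync-free optimizer
scaler = GradScaler()                                  # grad scaler for XLA:GPU
loss_fn = nn.CrossEntropyLoss()

for step in range(25):                                 # mirrors --num_steps=25
    data = torch.randn(16, 128, device=device)         # fake data, like --fake_data
    target = torch.randint(0, 10, (16,), device=device)
    optimizer.zero_grad()
    with autocast(device):                             # run the forward pass in mixed precision
        loss = loss_fn(model(data), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    xm.mark_step()                                     # materialize the lazy XLA graph
```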
4 changes: 4 additions & 0 deletions test/test_train_mp_imagenet_amp.py
@@ -254,6 +254,8 @@ def train_loop_fn(loader, epoch):
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, epoch, writer))
      if FLAGS.num_steps and FLAGS.num_steps == step:
        break

  def test_loop_fn(loader, epoch):
    total_samples, correct = 0, 0
@@ -266,6 +268,8 @@ def test_loop_fn(loader, epoch):
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
      if FLAGS.num_steps and FLAGS.num_steps == step:
        break
    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
    return accuracy
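The two lines added to each loop above are what make the AMP scripts cheap enough to re-enable in CI: when a script is invoked with --num_steps (the run_tests.sh commands pass --num_steps=25), the loop exits early after that many steps, and the check is assumed to be a no-op when the flag keeps its falsy default. A stripped-down sketch of the pattern, with a simple stand-in for the scripts' parsed flags:

```python
# Stripped-down sketch of the step cap added above; FLAGS here is a simple
# stand-in for the test scripts' parsed options (assumed to default num_steps
# to a falsy value, so the check only fires when CI passes --num_steps=25).
from types import SimpleNamespace

FLAGS = SimpleNamespace(num_steps=25, log_steps=10)

for step in range(10_000):
    # ... one training or evaluation step would run here ...
    if step % FLAGS.log_steps == 0:
        print(f'step {step}')  # stands in for the xm.add_step_closure(...) logging
    if FLAGS.num_steps and FLAGS.num_steps == step:
        break  # cap the loop so a CI run finishes quickly
```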
9 changes: 8 additions & 1 deletion test/test_train_mp_mnist_amp.py
@@ -170,16 +170,23 @@ def train_loop_fn(loader):
      if step % flags.log_steps == 0:
        xm.add_step_closure(
            _train_update, args=(device, step, loss, tracker, writer))
      if FLAGS.num_steps and FLAGS.num_steps == step:
        break

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
    for step, (data, target) in enumerate(loader):
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum()
      total_samples += data.size()[0]
      if step % FLAGS.log_steps == 0:
        xm.add_step_closure(
            test_utils.print_test_update, args=(device, None, epoch, step))
      if FLAGS.num_steps and FLAGS.num_steps == step:
        break

    accuracy = 100.0 * correct.item() / total_samples
    accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
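Both test loops end by reducing the per-process accuracy with xm.mesh_reduce, which is how the multi-process GPU runs report a single accuracy number. A minimal sketch of that call is below; the value is made up, and it is meant to run inside a torch_xla process (for example one spawned by torch_xla.distributed.xla_multiprocessing), where each replica contributes its local value and the reduce function is applied to the gathered list.

```python
# Illustrative only: how the test loops aggregate accuracy across replicas.
# xm.mesh_reduce gathers each process's value under the given tag and applies
# the reduction function (np.mean here) to the gathered list.
import numpy as np
import torch_xla.core.xla_model as xm

local_accuracy = 87.5  # hypothetical per-process accuracy from the local data shard
global_accuracy = xm.mesh_reduce('test_accuracy', local_accuracy, np.mean)
print(f'mean accuracy across replicas: {global_accuracy:.2f}')
```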
