From 934adc9d50694240d56b5a8c6f3e9bf61d27259c Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Sat, 7 Oct 2023 08:34:25 -0700 Subject: [PATCH] Add fsdp tests back to GPU CI (#5674) * use manfei's change * use the new flag * reduce data size * reduce batch size * keep reducing batch size to 64 * remove comments --- .circleci/common.sh | 6 +++--- test/test_train_mp_imagenet_fsdp.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.circleci/common.sh b/.circleci/common.sh index 1642480a6b40..88c5fed8efcb 100755 --- a/.circleci/common.sh +++ b/.circleci/common.sh @@ -151,9 +151,9 @@ function run_torch_xla_python_tests() { # GPU tests if [ -x "$(command -v nvidia-smi)" ]; then # These tests fail on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840) - # PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 - # PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 - # XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 + PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 + XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 # Syncfree SGD optimizer tests if [ -d ./torch_xla/amp/syncfree ]; then echo "Running Syncfree Optimizer Test" diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index a40f3bef74fa..fdfdc8a698c1 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -110,7 +110,7 @@ transformer_auto_wrap_policy) DEFAULT_KWARGS = dict( - batch_size=128, + batch_size=64, test_set_batch_size=64, num_epochs=18, momentum=0.9,