diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index e0f62ddc75..9c107231ba 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -515,6 +515,8 @@ def test_fsdp_mixed_with_sync( '0.24.0', '0.25.0', '0.26.0', + '0.27.0', + '0.28.0', ], ) @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') @@ -534,9 +536,12 @@ def test_fsdp_load_old_checkpoint( if composer_version == '0.18.1' and state_dict_type == 'full' and precision == 'amp_bf16' and sharding_strategy == 'FULL_SHARD': pytest.skip('TODO: This checkpoint is missing') - if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')) or ( - composer_version in ['0.24.0', '0.25.0'] and version.parse(torch.__version__) < version.parse('2.4.0') - ) or (composer_version in '0.26.0' and version.parse(torch.__version__) < version.parse('2.5.0')): + if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0') + ) or (composer_version in ['0.24.0', '0.25.0'] and + version.parse(torch.__version__) < version.parse('2.4.0')) or ( + composer_version in ['0.26.0', '0.27.0', '0.28.0'] and + version.parse(torch.__version__) < version.parse('2.5.0') + ): pytest.skip('Current torch version is older than torch version that checkpoint was written with.') if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']: