From 238d2e3c44366aba9dc5c770c95475765a6725cb Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Sat, 16 Dec 2023 19:41:43 +0530
Subject: [PATCH] fix resuming from ckpt when using FSDP with FULL_STATE_DICT
 (#27891)

* fix resuming from ckpt when using FSDP with FULL_STATE_DICT

* update tests

* fix tests
---
 src/transformers/trainer.py | 13 +++++++++----
 tests/fsdp/test_fsdp.py     | 14 ++++++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ffe5f5c0d1556b..9cd0bf0685e6c9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2033,10 +2033,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
-        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
-            FSDP_MODEL_NAME in folder_name
-            for folder_name in os.listdir(resume_from_checkpoint)
-            if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )
 
         if is_fsdp_ckpt and not self.is_fsdp_enabled:
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 2a9473c862ffa9..d883f29ed3698c 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -41,6 +41,7 @@
 
 if is_torch_available():
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
+    from transformers.trainer import FSDP_MODEL_NAME
 else:
     is_torch_greater_or_equal_than_2_1 = False
 
@@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type):
         # resume from ckpt
         checkpoint = os.path.join(output_dir, "checkpoint-115")
         resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
         logs_resume = self.run_cmd_and_get_logs(
             use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
         )
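Below is a minimal standalone sketch of the checkpoint-detection logic this patch introduces, for readers who want to check a checkpoint directory outside the Trainer. It assumes FSDP_MODEL_NAME resolves to "pytorch_model_fsdp" (its value in transformers.trainer at the time of this patch) and that sharded checkpoints are written to a subfolder such as pytorch_model_fsdp_0; the helper name is_fsdp_checkpoint and the temporary-directory demo are illustrative, not part of the library.

import os
import tempfile

# Assumed to match transformers.trainer.FSDP_MODEL_NAME ("pytorch_model_fsdp");
# treat the exact value as an assumption, not a guarantee.
FSDP_MODEL_NAME = "pytorch_model_fsdp"


def is_fsdp_checkpoint(checkpoint_dir: str) -> bool:
    """Approximates the detection logic added in this patch."""
    if not os.path.isdir(checkpoint_dir):
        return False
    # SHARDED_STATE_DICT: shards live in a subdirectory such as pytorch_model_fsdp_0/
    has_sharded_dir = any(
        FSDP_MODEL_NAME in folder_name
        for folder_name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, folder_name))
    )
    # FULL_STATE_DICT: a single consolidated pytorch_model_fsdp.bin file
    has_full_file = os.path.isfile(os.path.join(checkpoint_dir, f"{FSDP_MODEL_NAME}.bin"))
    return has_sharded_dir or has_full_file


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as ckpt:
        # FULL_STATE_DICT layout: the old directories-only check missed this case.
        open(os.path.join(ckpt, f"{FSDP_MODEL_NAME}.bin"), "w").close()
        print(is_fsdp_checkpoint(ckpt))  # True with the patched check

    with tempfile.TemporaryDirectory() as ckpt:
        # SHARDED_STATE_DICT layout: recognized both before and after the patch.
        os.makedirs(os.path.join(ckpt, f"{FSDP_MODEL_NAME}_0"))
        print(is_fsdp_checkpoint(ckpt))  # True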