From 239641c3f584433e13f5f60ba53730a31a6a9ff6 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 7 Dec 2023 18:34:39 +0530 Subject: [PATCH 1/3] fix resuming from ckpt when suing FSDP with FULL_STATE_DICT --- src/transformers/trainer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3c9e4420124012..2049298da67c6e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2023,10 +2023,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None): weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) - is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any( - FSDP_MODEL_NAME in folder_name - for folder_name in os.listdir(resume_from_checkpoint) - if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and ( + # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used + any( + FSDP_MODEL_NAME in folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + ) + # this checks the FSDP state dict when `FULL_STATE_DICT` is used + or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) ) if is_fsdp_ckpt and not self.is_fsdp_enabled: From 8b6b8013dbcdc3f893daef9231605312950f7340 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 15 Dec 2023 23:41:36 +0530 Subject: [PATCH 2/3] update tests --- tests/fsdp/test_fsdp.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 2a9473c862ffa9..c2d539b3e99e75 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -34,6 +34,7 @@ slow, torch_device, ) +from transformers.trainer import FSDP_MODEL_NAME from transformers.trainer_callback import TrainerState from transformers.trainer_utils import FSDPOption, set_seed from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device @@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type): # resume from ckpt checkpoint = os.path.join(output_dir, "checkpoint-115") resume_args = args + f"--resume_from_checkpoint {checkpoint}".split() + + is_fsdp_ckpt = os.path.isdir(checkpoint) and ( + # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used + any( + FSDP_MODEL_NAME in folder_name + for folder_name in os.listdir(checkpoint) + if os.path.isdir(os.path.join(checkpoint, folder_name)) + ) + # this checks the FSDP state dict when `FULL_STATE_DICT` is used + or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin")) + ) + self.assertTrue(is_fsdp_ckpt) + logs_resume = self.run_cmd_and_get_logs( use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir ) From e8b3cc5112fd34928bfbc095f077478455550d03 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 15 Dec 2023 23:50:09 +0530 Subject: [PATCH 3/3] fix tests --- tests/fsdp/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index c2d539b3e99e75..d883f29ed3698c 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -34,7 +34,6 @@ slow, torch_device, ) -from transformers.trainer import FSDP_MODEL_NAME from transformers.trainer_callback import TrainerState from transformers.trainer_utils import FSDPOption, set_seed from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device @@ -42,6 +41,7 @@ if is_torch_available(): from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1 + from transformers.trainer import FSDP_MODEL_NAME else: is_torch_greater_or_equal_than_2_1 = False