From dc5e3693d456198039d0e95745d3b1db6cc22b3c Mon Sep 17 00:00:00 2001
From: DesmonDay <908660116@qq.com>
Date: Wed, 13 Nov 2024 11:27:40 +0800
Subject: [PATCH] update condition

---
 paddlenlp/trainer/trainer.py                             | 8 +++++++-
 paddlenlp/trainer/trainer_utils.py                       | 8 ++++++++
 paddlenlp/trainer/training_args.py                       | 8 +-------
 paddlenlp/trainer/unified_checkpoint/check_completion.py | 3 ++-
 paddlenlp/trainer/unified_checkpoint/utils.py            | 2 +-
 5 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 3b0bcc1fd301..a5301e290d08 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -141,6 +141,7 @@
     set_seed,
     should_skip_data,
     speed_metrics,
+    split_parallel_config,
 )
 from .training_args import TrainingArguments
 from .unified_checkpoint import UnifiedCheckpointHandler
@@ -2057,6 +2058,7 @@ def get_expected_keys(inputs, keys):
                 hasattr(self.args, "enable_sharding_comm_overlap")
                 and self.args.enable_sharding_comm_overlap
                 and self.args.unified_checkpoint
+                and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
             ):
                 model.register_sharding_comm_overlap_hook(self.optimizer)
 
@@ -2848,7 +2850,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 opt_state_dict = None
             else:
                 model = self.model
-                if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap:
+                if (
+                    hasattr(self.args, "enable_sharding_comm_overlap")
+                    and self.args.enable_sharding_comm_overlap
+                    and "split_param" in split_parallel_config(self.args.sharding_parallel_config)
+                ):
                     model = self.model_wrapped
                 opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
                     model=model,
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index ca816b585e3b..30488e960f14 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -1126,3 +1126,11 @@ def should_skip_data(global_step, skip_data_intervals):
             skip_flag = True
             break
     return skip_flag
+
+
+def split_parallel_config(parallel_config):
+    if "," in parallel_config:
+        parallel_config = set(parallel_config.split(","))
+    else:
+        parallel_config = set(parallel_config.split(" "))
+    return parallel_config
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 0359e28bac1a..25cf62309983 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -37,6 +37,7 @@
     OptimizerNames,
     SchedulerType,
     ShardingOption,
+    split_parallel_config,
 )
 
 try:
@@ -1096,13 +1097,6 @@ def __post_init__(self):
                 logger.warning("set amp_master_grad to false since amp is disabled.")
                 self.amp_master_grad = False
 
-        def split_parallel_config(parallel_config):
-            if "," in parallel_config:
-                parallel_config = set(parallel_config.split(","))
-            else:
-                parallel_config = set(parallel_config.split(" "))
-            return parallel_config
-
         # use_hybrid_parallel
 
         if self.use_hybrid_parallel:
diff --git a/paddlenlp/trainer/unified_checkpoint/check_completion.py b/paddlenlp/trainer/unified_checkpoint/check_completion.py
index 8165a4542820..626d25875740 100644
--- a/paddlenlp/trainer/unified_checkpoint/check_completion.py
+++ b/paddlenlp/trainer/unified_checkpoint/check_completion.py
@@ -150,7 +150,8 @@ def check_unified_optimizer(args, model, optimizer, resume_from_checkpoint, safe
     sharding_group = hcg.get_sharding_parallel_group()
     sharding_rank = sharding_group.rank
     dp_rank = dp_group.rank if dp_group.nranks > 1 else 0
-    struct2static_name_mappings = {k: v.name for k, v in model.state_dict().items()}
+    model_state_dict = get_expected_state_dict(model)
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
     if is_sharding_split_param_mode(args):
         # We do not check optimizer files completion for split_param, since it is very complicated. Directly support local resume.
diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py
index 3a61ee912ed3..9bd9fdcc65b7 100644
--- a/paddlenlp/trainer/unified_checkpoint/utils.py
+++ b/paddlenlp/trainer/unified_checkpoint/utils.py
@@ -227,7 +227,7 @@ def get_expected_keys(args, sharded_metadata, model, optimizer, is_master_weight
         params2rank = optimizer._param2rank
 
     model_state_dict = get_expected_state_dict(model)
-    struct2static_name_mappings = {k: v.name for k, v in get_expected_state_dict(model).items()}
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
     expected_keys = []
     for key in list(sharded_metadata["all_optimizer_keys"]):
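
For reference, a minimal standalone sketch (not part of the patch) of how the relocated split_parallel_config helper resolves the new "split_param" checks above. The function body mirrors the hunk added to paddlenlp/trainer/trainer_utils.py; the config strings passed to it below are hypothetical example values, not ones required by the patch.

# Sketch only: same logic as the helper moved into trainer_utils.py.
def split_parallel_config(parallel_config):
    # A parallel-config string may be comma- or space-separated; normalize it to a set of flags.
    if "," in parallel_config:
        parallel_config = set(parallel_config.split(","))
    else:
        parallel_config = set(parallel_config.split(" "))
    return parallel_config

# Hypothetical sharding_parallel_config values, for illustration only.
print("split_param" in split_parallel_config("split_param,enable_stage1_overlap"))  # True
print("split_param" in split_parallel_config("enable_stage1_overlap"))              # False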