diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py index fcb757cf6d56..6fd77f0a1738 100644 --- a/paddlenlp/trainer/trainer.py +++ b/paddlenlp/trainer/trainer.py @@ -2053,6 +2053,9 @@ def get_expected_keys(inputs, keys): self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer) self.optimizer = fleet.distributed_optimizer(self.optimizer) + if self.args.enable_sharding_comm_overlap: + model.register_sharding_comm_overlap_hook(self.optimizer) + # No pipeline mode, sharding only if not in_pipeline_parallel_mode and in_sharding_parallel_mode: # Sharded DDP! @@ -2840,8 +2843,11 @@ def _load_optimizer_and_scheduler(self, checkpoint): else: opt_state_dict = None else: + model = self.model + if hasattr(self.args, "enable_sharding_comm_overlap") and self.args.enable_sharding_comm_overlap: + model = self.model_wrapped opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer( - model=self.model, + model=model, optimizer=self.optimizer, resume_from_checkpoint=checkpoint, ) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index ed81f055bd80..0359e28bac1a 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1155,29 +1155,20 @@ def split_parallel_config(parallel_config): or "enable_dp_comm_overlap" in pipeline_parallel_config ) enable_dp_comm_overlap = using_comm_overlap and self.data_parallel_degree > 1 - enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1 + self.enable_sharding_comm_overlap = using_comm_overlap and self.sharding_parallel_degree > 1 assert not ( - enable_dp_comm_overlap and enable_sharding_comm_overlap + enable_dp_comm_overlap and self.enable_sharding_comm_overlap ), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time" - if enable_sharding_comm_overlap and not self.amp_master_grad: + if self.enable_sharding_comm_overlap and not self.amp_master_grad: raise ValueError( "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True." ) - if ( - enable_sharding_comm_overlap - and self.unified_checkpoint - and "split_param" in split_parallel_config(self.sharding_parallel_config) - ): - logger.warning( - "Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`." - ) - enable_sharding_comm_overlap = False dygraph_pp_configs = { "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False, "dp_comm_overlap": enable_dp_comm_overlap, - "sharding_comm_overlap": enable_sharding_comm_overlap, + "sharding_comm_overlap": self.enable_sharding_comm_overlap, "enable_timer": "enable_timer" in pipeline_parallel_config, "release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads, "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config, diff --git a/paddlenlp/trainer/unified_checkpoint/load_local.py b/paddlenlp/trainer/unified_checkpoint/load_local.py index 5d16fd4ef966..459eff7185d1 100644 --- a/paddlenlp/trainer/unified_checkpoint/load_local.py +++ b/paddlenlp/trainer/unified_checkpoint/load_local.py @@ -150,7 +150,7 @@ def _remove_unused_keys( def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoint, safe_serialization=False): # Special process with split param. if is_sharding_split_param_mode(args): - returned_optim_state_dict = load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) + returned_optim_state_dict = load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint) return returned_optim_state_dict # init and get optimizer LR_Scheduler diff --git a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py index f337b1a8186b..17a9f0782221 100644 --- a/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py +++ b/paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py @@ -15,6 +15,7 @@ import gc import os +from itertools import chain import paddle import paddle.distributed as dist @@ -22,7 +23,7 @@ from tqdm.auto import tqdm from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM -from paddlenlp.transformers.model_utils import load_state_dict +from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model from paddlenlp.utils.env import ( SAFE_MASTER_WEIGHTS_INDEX_NAME, SAFE_OPTIMIZER_INDEX_NAME, @@ -97,6 +98,7 @@ def gather_splited_param_for_optimizer(optimizer): global_rank = dist.get_rank() param_slice_info = {} param_shape_info = {} + for buffer in optimizer._inner_opt._comm_buffer_list: for key in buffer._sharding_param_grad_view.keys(): param_slice_info[key] = ( @@ -153,7 +155,7 @@ def gather_splited_param_for_optimizer(optimizer): return optim_state_dict, master_weights -def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint): +def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint): returned_optim_state_dict = nested_copy(optimizer.state_dict()) index_filename, index_filename_master_weights = SAFE_OPTIMIZER_INDEX_NAME, SAFE_MASTER_WEIGHTS_INDEX_NAME @@ -177,7 +179,13 @@ def load_unified_optimizer_split_param(model, optimizer, resume_from_checkpoint) expected_keys = [] param_slice_info = {} param_shape_info = {} - for buffer in optimizer._inner_opt._comm_buffer_list: + + comm_buffer_list = optimizer._inner_opt._comm_buffer_list + if args.enable_sharding_comm_overlap: + comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values())) + model = unwrap_model(model) + + for buffer in comm_buffer_list: for key in buffer._sharding_param_grad_view.keys(): begin = buffer._sharding_param_grad_view[key]._param_begin end = buffer._sharding_param_grad_view[key]._param_end diff --git a/paddlenlp/trainer/unified_checkpoint/utils.py b/paddlenlp/trainer/unified_checkpoint/utils.py index 74db0e20e184..3a61ee912ed3 100644 --- a/paddlenlp/trainer/unified_checkpoint/utils.py +++ b/paddlenlp/trainer/unified_checkpoint/utils.py @@ -24,7 +24,11 @@ from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM from paddlenlp.trainer.trainer_utils import ExplicitEnum, ShardingOption from paddlenlp.trainer.utils.helper import distributed_isfile -from paddlenlp.transformers.model_utils import PretrainedModel, get_parameter_dtype +from paddlenlp.transformers.model_utils import ( + PretrainedModel, + get_parameter_dtype, + unwrap_model, +) from paddlenlp.transformers.utils import dtype_byte_size from paddlenlp.utils.distributed import distributed_allgather, distributed_gather from paddlenlp.utils.env import ( @@ -193,6 +197,8 @@ def get_expected_state_dict(model_to_save): """ Get trainable state_dict of model_to_save. """ + model_to_save = unwrap_model(model_to_save) + if isinstance(model_to_save, PretrainedModel): state_dict = model_to_save.state_dict() if (