From 62a18c4ad8ee48608ede1623b3cae11029dd66c4 Mon Sep 17 00:00:00 2001
From: Difer <707065510@qq.com>
Date: Wed, 17 Jul 2024 20:09:25 +0800
Subject: [PATCH] [Trainer] Enable parallel_config to use commas as delimiters.
 (#8677)

---
 legacy/examples/RLHF/trainer_utils.py | 10 +++----
 llm/alignment/ppo/trainer_utils.py    | 10 +++----
 paddlenlp/trainer/trainer.py          | 25 +++++++----------
 paddlenlp/trainer/training_args.py    | 39 ++++++++++-----------------
 4 files changed, 34 insertions(+), 50 deletions(-)

diff --git a/legacy/examples/RLHF/trainer_utils.py b/legacy/examples/RLHF/trainer_utils.py
index 865d34cea653..e10a339851fe 100644
--- a/legacy/examples/RLHF/trainer_utils.py
+++ b/legacy/examples/RLHF/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
             fused_allreduce_gradients(list(model.parameters()), None)
 
         # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = False
+        enable_release_grads = False
+        if args.pipeline_parallel_degree > 1:
+            enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+            enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
         # Case 3: Pipeline parallel mode, overlap with dp
         if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
diff --git a/llm/alignment/ppo/trainer_utils.py b/llm/alignment/ppo/trainer_utils.py
index 865d34cea653..e10a339851fe 100644
--- a/llm/alignment/ppo/trainer_utils.py
+++ b/llm/alignment/ppo/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
             fused_allreduce_gradients(list(model.parameters()), None)
 
         # Pipeline parallel mode, handle gradient reduce here to overlap
-        pipeline_parallel_config = (
-            set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-        )
-        enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-        enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+        enable_dp_comm_overlap = False
+        enable_release_grads = False
+        if args.pipeline_parallel_degree > 1:
+            enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+            enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
         # Case 3: Pipeline parallel mode, overlap with dp
         if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 01e5fccbc02e..85e9386bfc5e 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1083,17 +1083,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                         fused_allreduce_gradients_no_sync(list(model.parameters()), None)
 
                     # Pipeline parallel mode, handle gradient reduce here to overlap
-                    pipeline_parallel_config = (
-                        set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
-                    )
-                    sharding_parallel_config = (
-                        set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
-                    )
-                    enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
-                    enable_release_grads = (
-                        "enable_release_grads" in pipeline_parallel_config
-                        or "enable_release_grads" in sharding_parallel_config
-                    )
+                    enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+
+                    enable_release_grads = False
+                    if args.sharding_parallel_degree > 1:
+                        enable_release_grads = "enable_release_grads" in args.sharding_parallel_config
+                    if not enable_release_grads and args.pipeline_parallel_degree > 1:
+                        enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config
 
                     # Case 3: Pipeline parallel mode, overlap with dp
                     if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
@@ -1992,8 +1988,7 @@ def get_expected_keys(inputs, keys):
                         "please upgrade your paddle (using nightly version)."
                     )
 
-                sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
-                if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+                if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
                     model._set_reduce_overlap(True)
                     optimizer._set_broadcast_overlap(True, model)
 
@@ -2133,9 +2128,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
     def _enable_delay_scale_loss(self):
         key = "enable_delay_scale_loss"
         if self.args.pipeline_parallel_degree > 1:
-            return key in self.args.pipeline_parallel_config.split(" ")
+            return key in self.args.pipeline_parallel_config
         elif self.args.tensor_parallel_degree > 1:
-            return key in self.args.tensor_parallel_config.split(" ")
+            return key in self.args.tensor_parallel_config
         else:
             return False
 
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index b027b1178247..2aa7b632ba40 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1022,6 +1022,13 @@ def __post_init__(self):
                 logger.warning("set amp_master_grad to false since amp is disabled.")
                 self.amp_master_grad = False
 
+        def split_parallel_config(parallel_config):
+            if "," in parallel_config:
+                parallel_config = set(parallel_config.split(","))
+            else:
+                parallel_config = set(parallel_config.split(" "))
+            return parallel_config
+
         # use_hybrid_parallel
         if self.use_hybrid_parallel:
 
@@ -1039,10 +1046,7 @@
             strategy = fleet.DistributedStrategy()
             assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
             if self.pipeline_parallel_degree > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1116,10 +1120,7 @@
             if self.tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
             strategy.hybrid_configs = hybrid_configs
 
             if self.sharding_parallel_degree > 1:
-                if " " in self.sharding_parallel_config:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-                else:
-                    sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+                sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
+
                 for x in sharding_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():
 
             # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
             if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-                if " " in self.pipeline_parallel_config:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-                else:
-                    pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+                pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
                 for x in pipeline_parallel_config:
                     if len(x) > 0:
                         if x not in [
@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():
 
             if self.tensor_parallel_degree > 1:
                 mp_optimization = strategy.mp_optimization
-
-                if " " in self.tensor_parallel_config:
-                    mp_config = set(self.tensor_parallel_config.split(" "))
-                else:
-                    mp_config = set(self.tensor_parallel_config.split(","))
+                mp_config = split_parallel_config(self.tensor_parallel_config)
 
                 for x in mp_config:
                     if len(x) > 0:
@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
             elif ShardingOption.FULL_SHARD in self.sharding:
                 sharding.stage = 3
 
-            if " " in self.sharding_parallel_config:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-            else:
-                sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+            sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
             for x in sharding_parallel_config:
                 if len(x) > 0:
                     if x not in [