
[Trainer] Enable parallel_config to use commas as delimiters. (Paddle…
Difers authored Jul 17, 2024
1 parent e89399a commit 62a18c4
Showing 4 changed files with 34 additions and 50 deletions.
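
For illustration, the option strings involved (pipeline_parallel_config, tensor_parallel_config, sharding_parallel_config) can now be written with either commas or spaces as delimiters. A minimal standalone sketch of the intended parsing, condensed from the split_parallel_config helper added to training_args.py below (option names are taken from this diff; this is not a call into the trainer itself):

def split_parallel_config(parallel_config):
    # Commas take precedence; otherwise fall back to the historical space delimiter.
    if "," in parallel_config:
        return set(parallel_config.split(","))
    return set(parallel_config.split(" "))

# Both spellings resolve to the same set of option tokens.
assert split_parallel_config("enable_dp_comm_overlap,enable_release_grads") == split_parallel_config(
    "enable_dp_comm_overlap enable_release_grads"
)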
10 changes: 5 additions & 5 deletions legacy/examples/RLHF/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
      fused_allreduce_gradients(list(model.parameters()), None)

  # Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+ enable_dp_comm_overlap = False
+ enable_release_grads = False
+ if args.pipeline_parallel_degree > 1:
+     enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

  # Case 3: Pipeline parallel mode, overlap with dp
  if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
10 changes: 5 additions & 5 deletions llm/alignment/ppo/trainer_utils.py
@@ -340,11 +340,11 @@ def full_training_step(self: Trainer, inputs: Dict[str, paddle.Tensor], **kwargs
      fused_allreduce_gradients(list(model.parameters()), None)

  # Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = "enable_release_grads" in pipeline_parallel_config
+ enable_dp_comm_overlap = False
+ enable_release_grads = False
+ if args.pipeline_parallel_degree > 1:
+     enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

  # Case 3: Pipeline parallel mode, overlap with dp
  if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
25 changes: 10 additions & 15 deletions paddlenlp/trainer/trainer.py
@@ -1083,17 +1083,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg)
      fused_allreduce_gradients_no_sync(list(model.parameters()), None)

  # Pipeline parallel mode, handle gradient reduce here to overlap
- pipeline_parallel_config = (
-     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
- )
- sharding_parallel_config = (
-     set(args.sharding_parallel_config.split(" ")) if args.sharding_parallel_degree > 1 else set()
- )
- enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
- enable_release_grads = (
-     "enable_release_grads" in pipeline_parallel_config
-     or "enable_release_grads" in sharding_parallel_config
- )
+ enable_dp_comm_overlap = "enable_dp_comm_overlap" in args.pipeline_parallel_config
+
+ enable_release_grads = False
+ if args.sharding_parallel_degree > 1:
+     enable_release_grads = "enable_release_grads" in args.sharding_parallel_config
+ if not enable_release_grads and args.pipeline_parallel_degree > 1:
+     enable_release_grads = "enable_release_grads" in args.pipeline_parallel_config

  # Case 3: Pipeline parallel mode, overlap with dp
  if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
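
The flag checks in the hunk above now test membership directly on the raw config string, which works the same for either delimiter. A quick illustrative check (the config value is hypothetical):

pipeline_parallel_config = "enable_dp_comm_overlap,enable_release_grads"  # hypothetical value
print("enable_dp_comm_overlap" in pipeline_parallel_config)  # True; also True for the space-delimited form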
@@ -1992,8 +1988,7 @@ def get_expected_keys(inputs, keys):
          "please upgrade your paddle (using nightly version)."
      )

- sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
- if level == "os_g" and "enable_stage2_overlap" in sharding_parallel_config:
+ if level == "os_g" and "enable_stage2_overlap" in self.args.sharding_parallel_config:
      model._set_reduce_overlap(True)
      optimizer._set_broadcast_overlap(True, model)

@@ -2133,9 +2128,9 @@ def compute_loss(self, model, inputs, return_outputs=False):
  def _enable_delay_scale_loss(self):
      key = "enable_delay_scale_loss"
      if self.args.pipeline_parallel_degree > 1:
-         return key in self.args.pipeline_parallel_config.split(" ")
+         return key in self.args.pipeline_parallel_config
      elif self.args.tensor_parallel_degree > 1:
-         return key in self.args.tensor_parallel_config.split(" ")
+         return key in self.args.tensor_parallel_config
      else:
          return False

39 changes: 14 additions & 25 deletions paddlenlp/trainer/training_args.py
@@ -1022,6 +1022,13 @@ def __post_init__(self):
      logger.warning("set amp_master_grad to false since amp is disabled.")
      self.amp_master_grad = False

+ def split_parallel_config(parallel_config):
+     if "," in parallel_config:
+         parallel_config = set(parallel_config.split(","))
+     else:
+         parallel_config = set(parallel_config.split(" "))
+     return parallel_config
+
  # use_hybrid_parallel
  if self.use_hybrid_parallel:

@@ -1039,10 +1046,7 @@ def __post_init__(self):
  strategy = fleet.DistributedStrategy()
  assert self.data_parallel_config == "", "data_parallle_config is not supported in hybrid parallel"
  if self.pipeline_parallel_degree > 1:
-     if " " in self.pipeline_parallel_config:
-         pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-     else:
-         pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+     pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
      for x in pipeline_parallel_config:
          if len(x) > 0:
              if x not in [
@@ -1116,10 +1120,7 @@ def __post_init__(self):
  if self.tensor_parallel_degree > 1:
      strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}

-     if " " in self.tensor_parallel_config:
-         mp_config = set(self.tensor_parallel_config.split(" "))
-     else:
-         mp_config = set(self.tensor_parallel_config.split(","))
+     mp_config = split_parallel_config(self.tensor_parallel_config)

      for x in mp_config:
          if len(x) > 0:
@@ -1225,10 +1226,8 @@ def is_segment_parallel_supported():
  strategy.hybrid_configs = hybrid_configs

  if self.sharding_parallel_degree > 1:
-     if " " in self.sharding_parallel_config:
-         sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
-     else:
-         sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+     sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
+
      for x in sharding_parallel_config:
          if len(x) > 0:
              if x not in [
@@ -1384,10 +1383,7 @@ def is_segment_parallel_supported():

  # navie-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
  if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
-     if " " in self.pipeline_parallel_config:
-         pipeline_parallel_config = set(self.pipeline_parallel_config.split(" "))
-     else:
-         pipeline_parallel_config = set(self.pipeline_parallel_config.split(","))
+     pipeline_parallel_config = split_parallel_config(self.pipeline_parallel_config)
      for x in pipeline_parallel_config:
          if len(x) > 0:
              if x not in [
@@ -1436,11 +1432,7 @@ def is_segment_parallel_supported():

  if self.tensor_parallel_degree > 1:
      mp_optimization = strategy.mp_optimization
-
-     if " " in self.tensor_parallel_config:
-         mp_config = set(self.tensor_parallel_config.split(" "))
-     else:
-         mp_config = set(self.tensor_parallel_config.split(","))
+     mp_config = split_parallel_config(self.tensor_parallel_config)

      for x in mp_config:
          if len(x) > 0:
@@ -1473,10 +1465,7 @@ def is_segment_parallel_supported():
  elif ShardingOption.FULL_SHARD in self.sharding:
      sharding.stage = 3

- if " " in self.sharding_parallel_config:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(" "))
- else:
-     sharding_parallel_config = set(self.sharding_parallel_config.split(","))
+ sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
  for x in sharding_parallel_config:
      if len(x) > 0:
          if x not in [
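
Each parsed set is still validated token by token against an allow-list, as in the "for x in ..." loops in the hunks above. A rough standalone sketch of that pattern (the allow-list is an illustrative subset and the error text is not the trainer's; split_parallel_config is the helper added earlier in this diff):

allowed = {"enable_dp_comm_overlap", "enable_release_grads", "enable_delay_scale_loss"}  # illustrative subset
for x in split_parallel_config("enable_dp_comm_overlap,enable_release_grads"):
    if len(x) > 0 and x not in allowed:
        raise ValueError(f"Found unknown parallel config option {x}")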
