diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index a23955f069ec..32aa8cfdefe3 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1136,6 +1136,15 @@ def split_parallel_config(parallel_config):
                     raise ValueError(
                         "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
                     )
+                if (
+                    enable_sharding_comm_overlap
+                    and self.unified_checkpoint
+                    and "split_param" in split_parallel_config(self.sharding_parallel_config)
+                ):
+                    logger.warning(
+                        "Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`."
+                    )
+                    enable_sharding_comm_overlap = False
 
                 dygraph_pp_configs = {
                     "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
diff --git a/paddlenlp/trl/dpo_criterion.py b/paddlenlp/trl/dpo_criterion.py
index 2af2a8ef2096..be454e2ce4d1 100644
--- a/paddlenlp/trl/dpo_criterion.py
+++ b/paddlenlp/trl/dpo_criterion.py
@@ -287,10 +287,10 @@ def forward(
             )
             loss = dpo_loss + sft_loss
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.sft_loss.append(sft_loss)
-            infohub.dpo_loss.append(dpo_loss)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.sft_loss.append(sft_loss.detach())
+            infohub.dpo_loss.append(dpo_loss.detach())
             return loss
         else:
             return policy_chosen_logps, policy_rejected_logps, sft_loss, dpo_loss, loss
diff --git a/paddlenlp/trl/kto_criterion.py b/paddlenlp/trl/kto_criterion.py
index 52745e996999..a6ca6c4c837a 100644
--- a/paddlenlp/trl/kto_criterion.py
+++ b/paddlenlp/trl/kto_criterion.py
@@ -247,10 +247,10 @@ def forward(
                 reference_kl_logps,
             )
         if self.use_infohub:
-            infohub.policy_chosen_logps.append(policy_chosen_logps)
-            infohub.policy_rejected_logps.append(policy_rejected_logps)
-            infohub.policy_kl_logps.append(policy_kl_logps)
-            infohub.kl.append(kl)
+            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
+            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
+            infohub.policy_kl_logps.append(policy_kl_logps.detach())
+            infohub.kl.append(kl.detach())
             return loss
         else:
             return (