From ad4bebbf1a034f07d45c1dab64512fbd6ade2405 Mon Sep 17 00:00:00 2001
From: "yezhenan.yza"
Date: Thu, 16 Jan 2025 16:44:13 +0800
Subject: [PATCH] feat: support context parallel for dpo

---
 .../models/nlp/gpt/megatron_gpt_dpo_model.py  | 44 ++++++++++++++++++++++-
 nemo_aligner/utils/distributed.py             |  7 ++-
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
index c5404ac7b..9353d6059 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -84,6 +84,9 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, cu_seqlens
             pi_logprobs - ref_logprobs, labels, cu_seqlens, average_log_probs=average_log_probs
         )
 
+        if parallel_state.get_context_parallel_world_size() > 1:
+            torch.distributed.all_reduce(batch_logs, group=parallel_state.get_context_parallel_group())
+
         num_examples_on_this_rank = torch.tensor(batch_logs.size(), device=torch.cuda.current_device())
         num_examples = [torch.zeros_like(num_examples_on_this_rank) for _ in range(dp_group.size())]
         torch.distributed.all_gather(num_examples, num_examples_on_this_rank, group=dp_group)
@@ -97,6 +100,37 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, cu_seqlens
 
         return out_chosen.flatten(), out_rejected.flatten()
 
+    def get_batch_on_this_context_parallel_rank(self, batch: dict):
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size > 1:
+            ref_logs = {}
+            if "ref_policy_log_probs_chosen" in batch:
+                ref_logs["ref_policy_log_probs_chosen"] = batch.pop("ref_policy_log_probs_chosen")
+
+            if "ref_policy_log_probs_rejected" in batch:
+                ref_logs["ref_policy_log_probs_rejected"] = batch.pop("ref_policy_log_probs_rejected")
+
+            cp_rank = parallel_state.get_context_parallel_rank()
+            for key, val in batch.items():
+                if val is not None:
+                    seq_dim = 1 if key != 'attention_mask' else 2
+                    val = val.view(
+                        *val.shape[0:seq_dim],
+                        2 * cp_size,
+                        val.shape[seq_dim] // (2 * cp_size),
+                        *val.shape[(seq_dim + 1) :],
+                    )
+                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
+                        non_blocking=True
+                    )
+                    val = val.index_select(seq_dim, index)
+                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                    batch[key] = val
+            batch.update(ref_logs)
+
+
+        return batch
+
     def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False):
         def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
             batch = next(dataloader_iter)
@@ -138,6 +172,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
 
             batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}
 
+            batch = self.get_batch_on_this_context_parallel_rank(batch)
+
             tokens, labels, ref_logprobs, gt_rewards, cu_seqlens = None, None, None, None, None
             if packed: ## packed sequence
                 tokens = batch["input_ids"]
@@ -275,8 +311,10 @@ def loss_func(output_tensor):
                     average_log_probs=self.preference_avg_log_probs,
                 )
 
+                cp_size = parallel_state.get_context_parallel_world_size()
+
                 return (
-                    loss,
+                    loss * cp_size,
                     {
                         "avg": reduced_loss,
                         "avg_sft_loss": reduced_sft_loss,
@@ -348,6 +386,10 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, cu_seqlens=No
         rewards = self.get_reduced_masked_logps(
             pi_logprobs - ref_logprobs, labels, cu_seqlens=cu_seqlens, average_log_probs=average_log_probs,
         )
+
+        if parallel_state.get_context_parallel_world_size() > 1:
+            torch.distributed.all_reduce(rewards, group=parallel_state.get_context_parallel_group())
+
         chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
         rewards_delta = chosen_rewards - reject_rewards
 
diff --git a/nemo_aligner/utils/distributed.py b/nemo_aligner/utils/distributed.py
index 654502ae4..d7705e41b 100755
--- a/nemo_aligner/utils/distributed.py
+++ b/nemo_aligner/utils/distributed.py
@@ -328,12 +328,11 @@ def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_sta
     @staticmethod
     def backward(ctx, grad_output):
         softmax, target_mask, masked_target = ctx.saved_tensors
-        partition_vocab_size = softmax.size(-1)
 
         # 1 if it's the chosen log prob, 0 otherwise
-        is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-            masked_target, num_classes=partition_vocab_size
-        )
+        is_chosen = torch.zeros_like(softmax)
+        is_chosen.scatter_(-1, masked_target.unsqueeze(-1), 1.0)
+        is_chosen = is_chosen.masked_fill(target_mask.unsqueeze(-1), 0.0)
 
         grad_input = is_chosen.float().sub_(softmax)
 
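
Reviewer note (illustration only, not part of the patch): the per-rank slicing in get_batch_on_this_context_parallel_rank follows the load-balanced context-parallel layout, where the sequence is cut into 2 * cp_size chunks and each rank keeps chunk cp_rank plus its mirror chunk (2 * cp_size - 1 - cp_rank). The CPU-only sketch below reproduces that indexing with a hypothetical helper name (split_for_cp_rank) so the chunk assignment can be checked without a distributed run.

    # Minimal, CPU-only sketch of the load-balanced context-parallel split.
    # split_for_cp_rank is a hypothetical name used only for this illustration.
    import torch


    def split_for_cp_rank(val: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1) -> torch.Tensor:
        """Return the two sequence chunks owned by cp_rank, concatenated along seq_dim."""
        assert val.shape[seq_dim] % (2 * cp_size) == 0, "sequence length must divide evenly"
        # Reshape the sequence dimension into 2 * cp_size chunks.
        val = val.view(
            *val.shape[:seq_dim], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size), *val.shape[seq_dim + 1 :]
        )
        # Keep chunk cp_rank and its mirror chunk from the other end of the sequence.
        index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
        val = val.index_select(seq_dim, index)
        # Merge the two kept chunks back into a single (shorter) sequence dimension.
        return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2 :])


    if __name__ == "__main__":
        cp_size = 2
        tokens = torch.arange(8).unsqueeze(0)  # shape [1, 8], chunk length 2
        for cp_rank in range(cp_size):
            print(cp_rank, split_for_cp_rank(tokens, cp_rank, cp_size).tolist())
        # rank 0 keeps chunks 0 and 3 -> [[0, 1, 6, 7]]
        # rank 1 keeps chunks 1 and 2 -> [[2, 3, 4, 5]]

Because each rank only sees its own slice of the sequence, the per-sequence log-prob sums (and hence the DPO rewards and loss terms) are partial on every rank; that is why the patch all-reduces them over the context-parallel group in gather_and_split_rewards and loss_func, and scales the returned loss by cp_size.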