feat: support context parallel for dpo
yezhenan.yza committed Jan 16, 2025
1 parent 968ba12 commit ad4bebb
Showing 2 changed files with 46 additions and 5 deletions.
44 changes: 43 additions & 1 deletion nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -84,6 +84,9 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, cu_seqlens
pi_logprobs - ref_logprobs, labels, cu_seqlens, average_log_probs=average_log_probs
)

if parallel_state.get_context_parallel_world_size() > 1:
torch.distributed.all_reduce(batch_logs, group=parallel_state.get_context_parallel_group())

num_examples_on_this_rank = torch.tensor(batch_logs.size(), device=torch.cuda.current_device())
num_examples = [torch.zeros_like(num_examples_on_this_rank) for _ in range(dp_group.size())]
torch.distributed.all_gather(num_examples, num_examples_on_this_rank, group=dp_group)
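Context for the all_reduce added above: with context parallelism, each CP rank computes token log-probs only for its local sequence chunks, so the per-sample values coming out of get_reduced_masked_logps are partial sums; summing them across the context-parallel group recovers the full-sequence value. A minimal standalone sketch of that invariant (toy tensors, assuming summed rather than averaged log-probs; not part of the commit):

```python
import torch

# Toy check of the invariant the new all_reduce relies on: summing masked token
# log-probs chunk by chunk and then adding the per-chunk results gives the same
# number as summing over the full sequence at once.
logprobs = torch.randn(1, 8)                                   # one sequence, 8 tokens
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.bool)

full = (logprobs * mask).sum()

cp_size = 2
chunks = logprobs.chunk(2 * cp_size, dim=-1)                   # what the CP ranks would hold
mask_chunks = mask.chunk(2 * cp_size, dim=-1)
partials = [(c * m).sum() for c, m in zip(chunks, mask_chunks)]

assert torch.allclose(full, sum(partials))                     # all_reduce(SUM) over the CP group reconstructs it
```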
@@ -97,6 +100,37 @@ def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, cu_seqlens

return out_chosen.flatten(), out_rejected.flatten()

    def get_batch_on_this_context_parallel_rank(self, batch: dict):
        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size > 1:
            ref_logs = {}
            if "ref_policy_log_probs_chosen" in batch:
                ref_logs["ref_policy_log_probs_chosen"] = batch.pop("ref_policy_log_probs_chosen")

            if "ref_policy_log_probs_rejected" in batch:
                ref_logs["ref_policy_log_probs_rejected"] = batch.pop("ref_policy_log_probs_rejected")

            cp_rank = parallel_state.get_context_parallel_rank()
            for key, val in batch.items():
                if val is not None:
                    seq_dim = 1 if key != 'attention_mask' else 2
                    val = val.view(
                        *val.shape[0:seq_dim],
                        2 * cp_size,
                        val.shape[seq_dim] // (2 * cp_size),
                        *val.shape[(seq_dim + 1) :],
                    )
                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
                        non_blocking=True
                    )
                    val = val.index_select(seq_dim, index)
                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
                    batch[key] = val
            batch.update(ref_logs)

        return batch
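The method above mirrors the Megatron-style context-parallel batch split: the sequence dimension is cut into 2*cp_size chunks and rank r keeps chunk r together with its mirror chunk 2*cp_size-1-r, which keeps per-rank causal-attention cost roughly balanced (early chunks attend to few keys, late chunks to many). A CPU-only sketch of the same view/index_select/view pattern (split_for_cp_rank is a hypothetical helper name, not part of the commit):

```python
import torch

def split_for_cp_rank(seq: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1) -> torch.Tensor:
    """Keep the two chunks of `seq` assigned to one context-parallel rank."""
    s = seq.shape[seq_dim]
    assert s % (2 * cp_size) == 0, "sequence length must be divisible by 2 * cp_size"
    # Split into 2*cp_size chunks along the sequence dimension.
    chunked = seq.view(*seq.shape[:seq_dim], 2 * cp_size, s // (2 * cp_size), *seq.shape[seq_dim + 1:])
    # Rank r keeps chunk r and its mirror chunk, so early and late tokens are spread evenly.
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank])
    picked = chunked.index_select(seq_dim, index)
    return picked.view(*picked.shape[:seq_dim], -1, *picked.shape[seq_dim + 2:])

tokens = torch.arange(16).view(1, 16)                    # batch=1, seqlen=16
print(split_for_cp_rank(tokens, cp_rank=0, cp_size=2))   # chunks 0 and 3 -> tokens 0..3 and 12..15
```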

def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
batch = next(dataloader_iter)
@@ -138,6 +172,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):

batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}

batch = self.get_batch_on_this_context_parallel_rank(batch)

tokens, labels, ref_logprobs, gt_rewards, cu_seqlens = None, None, None, None, None
if packed: ## packed sequence
tokens = batch["input_ids"]
@@ -275,8 +311,10 @@ def loss_func(output_tensor):
average_log_probs=self.preference_avg_log_probs,
)

cp_size = parallel_state.get_context_parallel_world_size()

return (
-loss,
+loss * cp_size,
{
"avg": reduced_loss,
"avg_sft_loss": reduced_sft_loss,
Expand Down Expand Up @@ -348,6 +386,10 @@ def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, cu_seqlens=No
rewards = self.get_reduced_masked_logps(
pi_logprobs - ref_logprobs, labels, cu_seqlens=cu_seqlens, average_log_probs=average_log_probs,
)

if parallel_state.get_context_parallel_world_size() > 1:
torch.distributed.all_reduce(rewards, group=parallel_state.get_context_parallel_group())

chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
rewards_delta = chosen_rewards - reject_rewards

7 changes: 3 additions & 4 deletions nemo_aligner/utils/distributed.py
@@ -328,12 +328,11 @@ def forward(ctx, vocab_parallel_logits, target, inference_only=False, higher_sta
@staticmethod
def backward(ctx, grad_output):
softmax, target_mask, masked_target = ctx.saved_tensors
-partition_vocab_size = softmax.size(-1)

# 1 if it's the chosen log prob, 0 otherwise
-is_chosen = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(
-    masked_target, num_classes=partition_vocab_size
-)
+is_chosen = torch.zeros_like(softmax)
+is_chosen.scatter_(-1, masked_target.unsqueeze(-1), 1.0)
+is_chosen = is_chosen.masked_fill(target_mask.unsqueeze(-1), 0.0)

grad_input = is_chosen.float().sub_(softmax)
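The backward rewrite above replaces the int64 one-hot matrix with an in-place scatter_ into a float tensor shaped like softmax, producing the same values without materializing an extra (batch, seq, vocab-partition) int64 tensor. A small equivalence check with toy shapes (target_mask is treated here as True where the target token falls outside this rank's vocab partition, which is an assumption about its semantics):

```python
import torch

torch.manual_seed(0)
batch, seq, vocab_partition = 2, 3, 5
softmax = torch.rand(batch, seq, vocab_partition)
masked_target = torch.randint(0, vocab_partition, (batch, seq))
target_mask = torch.rand(batch, seq) < 0.3

# Old construction: boolean mask times an int64 one-hot matrix.
old = (~target_mask).unsqueeze(-1) * torch.nn.functional.one_hot(masked_target, num_classes=vocab_partition)

# New construction: scatter 1.0 into a float tensor, then zero out masked positions.
new = torch.zeros_like(softmax)
new.scatter_(-1, masked_target.unsqueeze(-1), 1.0)
new = new.masked_fill(target_mask.unsqueeze(-1), 0.0)

assert torch.equal(old.float(), new)
```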

