diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
index 71d486818..ece8e34da 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -46,6 +46,7 @@
     "kto": "kto_loss_func",
 }
 
+
 class MegatronGPTDPOModel(MegatronGPTModel, SupervisedInterface):
     """
     Megatron GPT DPO Model Training.
@@ -204,14 +205,16 @@ def get_reduced_masked_logps(self, logps, labels, average_log_probs=False):
         return (logps * loss_mask).sum(-1)
 
     def loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
-        loss, chosen_rewards, reject_rewards = self.preference_loss(pi_logprobs, ref_logprobs, labels, average_log_probs)
+        loss, chosen_rewards, reject_rewards = self.preference_loss(
+            pi_logprobs, ref_logprobs, labels, average_log_probs
+        )
 
         with torch.no_grad():
             comp = chosen_rewards > reject_rewards
             acc_chosen = comp.float().mean()
 
         return loss, acc_chosen
-    
+
     def dpo_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
         rewards = self.get_reduced_masked_logps(
             pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs
@@ -245,9 +248,7 @@ def kto_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=Fal
 
         chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
 
-        rewards_kl = self.get_reduced_masked_logps(
-            pi_logprobs - ref_logprobs, labels, average_log_probs=True
-        )
+        rewards_kl = self.get_reduced_masked_logps(pi_logprobs - ref_logprobs, labels, average_log_probs=True)
         chosen_kl, reject_kl = self.split_output_tensor(rewards_kl)
 
         loss = (