Optimized KTO loss
ertkonuk committed on Jan 12, 2024. Merge commit f7dafe0 with parents 4d3bdc8 and 9b20120.
Showing 1 changed file with 6 additions and 5 deletions: nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
@@ -46,6 +46,7 @@
     "kto": "kto_loss_func",
 }
 
+
 class MegatronGPTDPOModel(MegatronGPTModel, SupervisedInterface):
     """
     Megatron GPT DPO Model Training.
@@ -204,14 +205,16 @@ def get_reduced_masked_logps(self, logps, labels, average_log_probs=False):
         return (logps * loss_mask).sum(-1)
 
     def loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
-        loss, chosen_rewards, reject_rewards = self.preference_loss(pi_logprobs, ref_logprobs, labels, average_log_probs)
+        loss, chosen_rewards, reject_rewards = self.preference_loss(
+            pi_logprobs, ref_logprobs, labels, average_log_probs
+        )
 
         with torch.no_grad():
             comp = chosen_rewards > reject_rewards
             acc_chosen = comp.float().mean()
 
         return loss, acc_chosen
 
     def dpo_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
         rewards = self.get_reduced_masked_logps(
             pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs
@@ -245,9 +248,7 @@ def kto_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
 
         chosen_rewards, reject_rewards = self.split_output_tensor(rewards)
 
-        rewards_kl = self.get_reduced_masked_logps(
-            pi_logprobs - ref_logprobs, labels, average_log_probs=True
-        )
+        rewards_kl = self.get_reduced_masked_logps(pi_logprobs - ref_logprobs, labels, average_log_probs=True)
 
         chosen_kl, reject_kl = self.split_output_tensor(rewards_kl)
         loss = (
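The string map in the first hunk pairs each preference-loss name with a method name. A hypothetical sketch of how such a map could be resolved to self.preference_loss follows; the actual wiring is not shown in this diff, and the dict name, the "dpo" entry, and the constructor argument are assumptions for illustration.

    # Hypothetical dispatch sketch; only the "kto" entry and the two method names
    # are visible in the diff, everything else here is assumed.
    LOSS_NAME_TO_METHOD = {
        "dpo": "dpo_loss_func",
        "kto": "kto_loss_func",
    }

    class PreferenceLossDispatchSketch:
        def __init__(self, loss_name: str = "kto"):
            # Bind the configured loss method once, so loss_func can call it uniformly.
            self.preference_loss = getattr(self, LOSS_NAME_TO_METHOD[loss_name])

        def dpo_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
            raise NotImplementedError

        def kto_loss_func(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False):
            raise NotImplementedError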
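For readers comparing against the KTO formulation: below is a minimal, self-contained sketch of a paired KTO loss built from the quantities visible in the last hunk (per-example log-ratio rewards plus token-averaged KL baselines from average_log_probs=True). It follows the published KTO formulation (Ethayarajh et al.) rather than this repository; beta, the clamp on the KL terms, the detach, and the cross-use of chosen/reject KL as baselines are assumptions, and the expression behind the collapsed loss = ( line may differ.

    import torch

    def kto_pair_loss_sketch(chosen_rewards, reject_rewards, chosen_kl, reject_kl, beta=0.1):
        # Rewards are summed log pi(y|x) - log pi_ref(y|x); the KL terms are
        # token-averaged log-ratios used as a gradient-free baseline (assumed).
        chosen_kl = chosen_kl.detach().clamp(min=0)
        reject_kl = reject_kl.detach().clamp(min=0)

        # Desirable responses are pushed above the baseline, undesirable ones below it.
        chosen_losses = 1.0 - torch.sigmoid(beta * (chosen_rewards - reject_kl))
        reject_losses = 1.0 - torch.sigmoid(beta * (chosen_kl - reject_rewards))

        return torch.cat((chosen_losses, reject_losses), dim=0).mean()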
