From adb3be224769cf7d85bc3e28b4068ab68bc4f101 Mon Sep 17 00:00:00 2001
From: Chen9154
Date: Wed, 12 Apr 2023 03:20:09 +0800
Subject: [PATCH] [fix] update pairwise dataloader. (#395)

* Update train_reward_model_gptj.py

In forward() of reward_model.py (Line 62), if "chosen" and "rejected" are exactly
the same, the "inference" flag is set to True, which should never happen during
training. However, in the PairwiseDataset class, "chosen" and "rejected" can become
identical after truncation, so we filter those pairs out of the training data.

* Update gptj_reward_test.py

Apply the same filter to the PairwiseDataset defined in gptj_reward_test.py, for
the same reason.
---
 examples/summarize_rlhf/reward_model/gptj_reward_test.py | 9 +++++----
 .../reward_model/train_reward_model_gptj.py              | 9 +++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/examples/summarize_rlhf/reward_model/gptj_reward_test.py b/examples/summarize_rlhf/reward_model/gptj_reward_test.py
index 09fd2e5e7..b3dd0bee9 100644
--- a/examples/summarize_rlhf/reward_model/gptj_reward_test.py
+++ b/examples/summarize_rlhf/reward_model/gptj_reward_test.py
@@ -59,10 +59,11 @@ def __init__(self, pairs, tokenizer, max_length):
                 padding="max_length",
                 return_tensors="pt",
             )
-            self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
-            self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
-            self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
-            self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])
+            if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item():
+                self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
+                self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
+                self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
+                self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

     def __len__(self):
         return len(self.chosen_input_ids)
diff --git a/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
index ba15d7da2..474d514e3 100644
--- a/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+++ b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
@@ -48,10 +48,11 @@ def __init__(self, pairs, tokenizer, max_length):
                 padding="max_length",
                 return_tensors="pt",
             )
-            self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
-            self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
-            self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
-            self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])
+            if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item():
+                self.chosen_input_ids.append(chosen_encodings_dict["input_ids"])
+                self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"])
+                self.rejected_input_ids.append(rejected_encodings_dict["input_ids"])
+                self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"])

     def __len__(self):
         return len(self.chosen_input_ids)
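
For context, the following is a minimal, self-contained sketch of the filtering
idea the patch introduces, written as a standalone helper rather than inside
PairwiseDataset.__init__. The helper name build_pairwise_examples, the sample
texts, and the use of the "gpt2" tokenizer in the smoke test are illustrative
assumptions only (the actual scripts load a GPT-J tokenizer); the comparison
itself is the same expression the patch adds.

import torch
from transformers import AutoTokenizer


def build_pairwise_examples(pairs, tokenizer, max_length):
    """Tokenize (chosen, rejected) pairs, dropping pairs that collide after truncation."""
    chosen_input_ids, chosen_attn_masks = [], []
    rejected_input_ids, rejected_attn_masks = [], []
    for pair in pairs:
        chosen_enc = tokenizer(
            pair["chosen"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        rejected_enc = tokenizer(
            pair["rejected"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Same condition as the patch: keep the pair only when the truncated
        # token ids differ, so the reward model never sees chosen == rejected
        # during training (which would flip its "inference" branch).
        if not torch.all(torch.eq(chosen_enc["input_ids"], rejected_enc["input_ids"])).item():
            chosen_input_ids.append(chosen_enc["input_ids"])
            chosen_attn_masks.append(chosen_enc["attention_mask"])
            rejected_input_ids.append(rejected_enc["input_ids"])
            rejected_attn_masks.append(rejected_enc["attention_mask"])
    return chosen_input_ids, chosen_attn_masks, rejected_input_ids, rejected_attn_masks


if __name__ == "__main__":
    # Hypothetical smoke test: the second pair differs only after the truncation
    # point, so it should be dropped and only the first pair kept.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    shared_prefix = "the same long post " * 10
    pairs = [
        {"chosen": "TL;DR: a concise summary", "rejected": "TL;DR: too long and rambling"},
        {"chosen": shared_prefix + "good summary", "rejected": shared_prefix + "bad summary"},
    ]
    kept = build_pairwise_examples(pairs, tokenizer, max_length=16)
    print(f"{len(kept[0])} of {len(pairs)} pairs kept")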