diff --git a/examples/summarize_rlhf/reward_model/gptj_reward_test.py b/examples/summarize_rlhf/reward_model/gptj_reward_test.py index 09fd2e5e7..b3dd0bee9 100644 --- a/examples/summarize_rlhf/reward_model/gptj_reward_test.py +++ b/examples/summarize_rlhf/reward_model/gptj_reward_test.py @@ -59,10 +59,11 @@ def __init__(self, pairs, tokenizer, max_length): padding="max_length", return_tensors="pt", ) - self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) - self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) - self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) - self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) + if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item(): + self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) + self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) + self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) + self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) def __len__(self): return len(self.chosen_input_ids) diff --git a/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py index ba15d7da2..474d514e3 100644 --- a/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py +++ b/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py @@ -48,10 +48,11 @@ def __init__(self, pairs, tokenizer, max_length): padding="max_length", return_tensors="pt", ) - self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) - self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) - self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) - self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) + if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item(): + self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) + self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) + self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) + self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) def __len__(self): return len(self.chosen_input_ids)