diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 08263e0181..1afd490bd1 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -18,7 +18,7 @@
 import torch
 from datasets import load_dataset
 from parameterized import parameterized
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available

@@ -110,29 +110,30 @@ def test_training_peft(self):
     def test_training_different_reward_model(self):
         # Use a reward model different from the model: different chat template, tokenization, etc.
         dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
-        reward_model = AutoModelForSequenceClassification.from_pretrained(
-            "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
-        )
-        # When training with the raw model, the score are too low to have an impact on the training
-        # within the few generations we are running here. We multiply the weights by 10000 to make sure
-        # the reward has an impact.
-        with torch.no_grad():
-            reward_model.score.weight *= 10000
+        reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
+        reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
+        reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)
+        # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token
+        # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation
+        # considers the score of the last token before the first pad token. To ensure correct reward calculations,
+        # we use a separate pad token instead.
+        reward_tokenizer.pad_token = "<|finetune_right_pad_id|>"

         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
                 learning_rate=0.1,  # increase the learning rate to speed up the test
                 per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
-                num_generations=5,  # reduce the number of generations to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
                 max_completion_length=32,  # reduce the completion length to reduce memory usage
                 report_to="none",
             )
             trainer = GRPOTrainer(
                 model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-                reward_model=reward_model,  # llama-based RM
+                reward_model=reward_model,
                 args=training_args,
                 train_dataset=dataset,
+                reward_processing_class=reward_tokenizer,
             )

             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
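
Why the dedicated pad token matters: the comment added in the test states that the reward is read at the last token before the first pad token, so reusing Llama's eos token as padding makes that lookup stop at the eos the chat template itself emits. The snippet below is a minimal standalone sketch of that failure mode, not the trainer's actual implementation; the helper name `last_scored_index` and the token ids are illustrative only.

```python
import torch


def last_scored_index(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Position of the last token before the first pad token in each row
    # (assumes right-padded sequences, as in the trainer's reward batches).
    positions = torch.arange(input_ids.size(1)).expand_as(input_ids)
    first_pad = positions.masked_fill(input_ids != pad_token_id, input_ids.size(1)).min(dim=1).values
    return first_pad - 1


eos_id, pad_id = 2, 99
# A sequence whose chat template already contains the eos token (id 2) mid-sequence.
ids = torch.tensor([[5, 7, 2, 8, 9, 99, 99]])
print(last_scored_index(ids, pad_token_id=pad_id))  # tensor([4]) -> scores the real last token
print(last_scored_index(ids, pad_token_id=eos_id))  # tensor([1]) -> stops at the in-template eos
```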