diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index f3d39ba445..4f956865b3 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -52,6 +52,8 @@
     title: Iterative SFT
   - local: xpo_trainer
     title: XPO
+  - local: vas_trainer
+    title: VAS
   title: Trainers
 - local: models
   title: Model Classes
diff --git a/docs/source/vas_trainer.mdx b/docs/source/vas_trainer.mdx
new file mode 100644
index 0000000000..5722bdd884
--- /dev/null
+++ b/docs/source/vas_trainer.mdx
@@ -0,0 +1,110 @@
+# VAS Trainer
+
+[![](https://img.shields.io/badge/All_models-VAS-blue)](https://huggingface.co/models?other=vas,trl)
+
+## Overview
+Value Augmented Sampling (VAS) is an inference-time algorithm introduced in the paper [Value Augmented Sampling for Language Model Alignment and Personalization](https://arxiv.org/abs/2405.06639) by [Seungwook Han](https://huggingface.co/hanseungwook), Idan Shenfeld, Akash Srivastava, Yoon Kim, and Pulkit Agrawal. At a high level, VAS improves the quality of language model generations by using a value model to guide the sampling process: the value model is trained to predict the quality of the generated text, and sampling is modified to prefer tokens that are predicted to lead to higher-quality completions.
+
+This allows VAS to optimize for arbitrary reward functions without modifying the language model's weights, in contrast to traditional reinforcement learning methods, which train the language model itself against the reward function.
+
+## Quick start
+
+This example demonstrates how to train a value model using the VAS method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) as the reward model. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback); the script below loads them from [trl-lib/ultrafeedback_binarized](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized).
+
+Below is the script to train the model:
+
+```python
+# train_vas.py
+from datasets import load_dataset
+from trl import VASConfig, VASTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
+
+ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+value_model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=1)
+reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+train_dataset = train_dataset.map(
+    lambda x: {"input_ids": tokenizer(x["prompt"][0]["content"]).input_ids},
+    remove_columns=train_dataset.column_names,
+)
+
+training_args = VASConfig(output_dir="Qwen2-0.5B-VAS", logging_steps=10)
+trainer = VASTrainer(ref_policy=ref_policy, value_model=value_model, reward_model=reward_model, config=training_args, processing_class=tokenizer, train_dataset=train_dataset)
+trainer.train()
+```
+
+Execute the script using the following command:
+
+```bash
+accelerate launch train_vas.py
+```
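+
+Since VAS is an inference-time method, the trained value model is used at generation time to steer the reference policy. The snippet below is a minimal, illustrative sketch (not part of this PR's examples); it mirrors what `VASTrainer.generate_completions` does internally and assumes the value-model checkpoint saved in `output_dir` above can be reloaded with `AutoModelForSequenceClassification` and wrapped with the internal `ValueWrapper` helper:
+
+```python
+# generate_vas.py (illustrative sketch, not a tested example)
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig
+from trl.extras.vas_sampler import VASSampler
+from trl.trainer.vas_trainer import ValueWrapper
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+# Reload the value model trained above and wrap it the same way VASTrainer does internally.
+# "Qwen2-0.5B-VAS" is the output_dir from the training script; adjust to your checkpoint path.
+value_model = ValueWrapper(AutoModelForSequenceClassification.from_pretrained("Qwen2-0.5B-VAS", num_labels=1))
+
+sampler = VASSampler(
+    model=policy,
+    value_model=value_model,
+    beta=2.0,   # weight of the value scores relative to the policy logits
+    top_k=10,   # number of candidate tokens scored by the value model at each step
+    generation_config=GenerationConfig(max_new_tokens=64, do_sample=True),
+)
+
+inputs = tokenizer("What is the best programming language?", return_tensors="pt")
+output_ids = sampler.generate(inputs.input_ids, attention_mask=inputs.attention_mask)
+print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
+```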
+
+## Expected dataset type
+
+VAS requires a [prompt-only dataset](dataset_formats#prompt-only). The [`VASTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+
+## Usage tips
+
+<Tip>
+
+Make sure that the reference policy, value model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
+
+</Tip>
+
+### Encourage EOS token generation
+
+We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`VASConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`VASConfig`]:
+
+```python
+training_args = VASConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
+```
+
+### Logging Completions
+
+To better understand your model’s behavior during training, you can log sample completions periodically using the `num_sample_generations` argument of [`VASConfig`]. Note that generations at the beginning of training will appear much worse than those of the reference policy, as the value function is initialized randomly.
+
+## Example script
+
+We provide an example script to train a model using the VAS method. The script is available in [`examples/scripts/vas.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vas.py).
+
+To test the VAS script with the [Pythia 1B model](https://huggingface.co/EleutherAI/pythia-1b-deduped) on the [TL;DR dataset](https://huggingface.co/datasets/trl-internal-testing/tldr-preference-sft-trl-style), run the following command:
+
+```bash
+accelerate launch examples/scripts/vas.py \
+    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
+    --output_dir models/minimal/vas_tldr \
+    --total_episodes 1000 \
+    --model_name_or_path EleutherAI/pythia-1b-deduped \
+    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
+    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
+    --response_length 256 \
+    --stop_token eos
+```
+
+## Logged metrics
+
+While training and evaluating, we record the following metrics:
+
+* `objective/scores`: The mean scores returned by the reward model.
+* `loss/value_loss`: The average value-function loss, i.e. the squared error between the predicted values and the computed returns.
+* `train/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
+* `train/lr`: The current learning rate used by the optimizer.
+* `train/episode`: The current global step or episode count in the training process.
+* `train/eps`: Tracks the number of episodes per second.
+
+## VASTrainer
+
+[[autodoc]] VASTrainer
+
+## VASConfig
+
+[[autodoc]] VASConfig
\ No newline at end of file
diff --git a/examples/scripts/vas.py b/examples/scripts/vas.py
new file mode 100644
index 0000000000..9ea110975d
--- /dev/null
+++ b/examples/scripts/vas.py
@@ -0,0 +1,108 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ModelConfig, VASConfig, VASTrainer, ScriptArguments +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, VASConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, + padding_side="left", + trust_remote_code=model_config.trust_remote_code, + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + value_model = AutoModelForSequenceClassification.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1, attn_implementation="flash_attention_2", + ) + value_model.config.pad_token_id = tokenizer.pad_token_id + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1, torch_dtype="bfloat16", attn_implementation="flash_attention_2", + ) + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_config.trust_remote_code, torch_dtype="float16", attn_implementation="flash_attention_2", + ) + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, split=script_args.dataset_train_split) + eval_samples = 20 + train_dataset = dataset.select(range(len(dataset) - eval_samples)) + eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + dataset_text_field = "prompt" + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field], + padding=False, + ) + return {"input_ids": outputs["input_ids"]} + + return dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + + ################ + # Training + ################ + trainer = VASTrainer( + config=training_args, + processing_class=tokenizer, + ref_policy=ref_policy, + reward_model=reward_model, + value_model=value_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/tests/test_vas_trainer.py b/tests/test_vas_trainer.py new file mode 100644 index 0000000000..32b988d1c8 --- /dev/null +++ b/tests/test_vas_trainer.py @@ -0,0 +1,81 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import VASTrainer, VASConfig +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +class TestVASTrainer(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + self.value_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) + self.reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m") + self.reward_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @parameterized.expand([("standard_prompt_only",)]) + def test_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = VASConfig( + output_dir=tmp_dir, + lam=0.95, + batch_size=1, + total_episodes=10, + learning_rate=5.0e-7, + per_device_eval_batch_size=2, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + outputs = tokenizer( + element["prompt"], + padding=False, + ) + return {"input_ids": outputs["input_ids"]} + + return dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + + dummy_dataset['train'] = prepare_dataset(dummy_dataset['train'], self.tokenizer) + dummy_dataset['test'] = prepare_dataset(dummy_dataset['test'], self.tokenizer) + trainer = VASTrainer( + 
ref_policy=self.ref_model, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + config=training_args, + ) + trainer.train() + + self.assertIn("loss/value_loss", trainer.state.log_history[-1]) diff --git a/trl/__init__.py b/trl/__init__.py index 1c12c2ade9..32a4b34107 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -87,6 +87,8 @@ "RLOOTrainer", "SFTConfig", "SFTTrainer", + "VASConfig", + "VASTrainer", "WinRateCallback", "XPOConfig", "XPOTrainer", @@ -176,6 +178,8 @@ RLOOTrainer, SFTConfig, SFTTrainer, + VASConfig, + VASTrainer, WinRateCallback, XPOConfig, XPOTrainer, diff --git a/trl/extras/vas_sampler.py b/trl/extras/vas_sampler.py new file mode 100644 index 0000000000..f07dc17a5d --- /dev/null +++ b/trl/extras/vas_sampler.py @@ -0,0 +1,212 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +from transformers import GenerationConfig, LogitsProcessor + +from ..core import set_seed +from ..models import ( + SUPPORTED_ARCHITECTURES, + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, +) + + +class VASSampler: + def __init__( + self, + model: PreTrainedModelWrapper, + value_model: Union[AutoModelForSeq2SeqLMWithValueHead, AutoModelForCausalLMWithValueHead], + beta: float = 1.0, + top_k: int = 10, + value_model_batch_size: int = 1, + seed: Optional[int] = None, + generation_config: Optional[GenerationConfig] = None, + ) -> None: + """ + VASSampler is used to generate responses from a model trained with the VAS framework (see /trainers/VASTrainer.py). + Args: + model (`PreTrainedModelWrapper`): + The pretrained model to use for generation + value_model (`AutoModelForSeq2SeqLMWithValueHead` or `AutoModelForCausalLMWithValueHead`): + The Pretrained Value model use to augment the sampling process + beta (`float`): + The value to use for weighting the Value outputs versus the logits of the LLM + top_k (`int`): + The number of top-k tokens that will be evaluated using the Value model + value_model_batch_size (`int`): + Batch size for the Value model, can be different from the batch size of the LLM + seed (`int`, *optional*): + Random seed used to control generation + generation_config (`GenerationConfig`, *optional*): + Generation config passed to the underlying model's `generate` method. 
+                See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details
+
+        """
+        if seed is not None:
+            set_seed(seed)
+
+        self.model = model
+        self.value_model = value_model
+        self.beta = beta
+        self.top_k = top_k
+        self.value_model_batch_size = value_model_batch_size
+        self.gen_config = generation_config
+
+        # Create a VAS logits processor
+        self.logits_processor = VASLogitsProcessor(
+            self.value_model, beta=self.beta, top_k=self.top_k, value_model_batch_size=self.value_model_batch_size
+        )
+
+    def generate(
+        self,
+        tokenized_query: Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]],
+        attention_mask: Optional[Union[List[int], torch.Tensor, List[torch.Tensor], List[List[int]]]] = None,
+        device: Optional[Union[str, torch.device]] = None,
+        **generation_kwargs,
+    ) -> torch.Tensor:
+        """
+        Generate a response using the VAS framework.
+
+        Args:
+            tokenized_query (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[List[int]]`):
+                Either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
+            attention_mask (`List[int]` or `torch.Tensor` or `List[torch.Tensor]` or `List[List[int]]`, *optional*):
+                Attention mask matching the shape of `tokenized_query`
+            device (`str` or `torch.device`, *optional*):
+                The device on which the model will be loaded
+            **generation_kwargs (`dict`, *optional*):
+                Additional keyword arguments passed along to the underlying model's `generate` method.
+                This is used to override the generation config
+
+        Returns:
+            `torch.Tensor`: The generated sequences (query and response token ids), as returned by the underlying model's `generate` method
+        """
+        if device is None:
+            device = tokenized_query[0].device if isinstance(tokenized_query, torch.Tensor) else "cpu"
+
+        # generate the response
+        outputs = self.model.generate(
+            tokenized_query.to(device),
+            attention_mask=attention_mask.to(device) if attention_mask is not None else None,
+            logits_processor=[self.logits_processor],
+            generation_config=self.gen_config,
+            **generation_kwargs,
+        )
+
+        return outputs
+
+
+class VASLogitsProcessor(LogitsProcessor, torch.nn.Module):
+    """
+    A logits processor that augments the language model logits with a Value model, following the VAS decoding scheme.
+
+    value_model: AutoModelForCausalLMWithValueHead, the Value model to use
+    beta: float, the beta value to use for weighting the Value model outputs against the LLM logits
+    top_k: int, the number of top-k tokens to evaluate with the Value model
+    value_model_batch_size: int, the batch size of tokens to evaluate at once
+    """
+
+    def __init__(
+        self,
+        value_model: Union[AutoModelForSeq2SeqLMWithValueHead, AutoModelForCausalLMWithValueHead],
+        beta: float = 1.0,
+        top_k: int = 10,
+        value_model_batch_size: int = 1,
+    ):
+        """
+        A logits processor that augments the output logits with a Value model as per the VAS decoding scheme.
+ + Args: + value_model (`AutoModelForSeq2SeqLMWithValueHead` or `AutoModelForCausalLMWithValueHead`): + The Pretrained Value model use to augment the sampling process + beta (`float`): + The value to use for weighting the Value outputs versus the logits of the LLM + top_k (`int`): + The number of top-k tokens that will be evaluated using the Value model + value_model_batch_size (`int`): + Batch size for the Value model, can be different from the batch size of the LLM + """ + super().__init__() + self.value_model = value_model + self.beta = beta + self.top_k = top_k + self.value_model_batch_size = value_model_batch_size + self.pad_token_id = self.value_model.value_model.config.pad_token_id + + assert self.top_k > 0, "topk must be larger than zero" + + self.last_input_ids = None + self.past_key_values = None + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + augmented_outputs = torch.clone(scores) + batch_size = input_ids.shape[0] + + orig_input_ids = input_ids + attention_mask = input_ids != 0 + position_ids = attention_mask.cumsum(1) - attention_mask.long() + + if ( + self.last_input_ids is not None + and (input_ids[0, :-1].shape == self.last_input_ids.shape) + and torch.all(input_ids[0, :-1] == self.last_input_ids) + ): + # if the last input ids are the same as the current input ids, we can reuse the past key values + _, past_key_values = self.value_model(input_ids[:, -1:], + attention_mask=attention_mask[:, -1:], + position_ids=position_ids[:, -1:], + past_key_values=self.past_key_values, + return_past_key_values=True, + return_dict=True, + output_hidden_states=True,) + else: + _, past_key_values = self.value_model(input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_past_key_values=True, + return_dict=True, + output_hidden_states=True,) + self.past_key_values = past_key_values + self.last_input_ids = input_ids[0, :] + + values = torch.zeros_like(scores, device=scores.device) + topk_ids = torch.topk(scores, self.top_k, dim=-1).indices + + for i in range(0, topk_ids.shape[1], self.value_model_batch_size): + curr_topk_ids = topk_ids[:, i : i + self.value_model_batch_size] + curr_input_ids = curr_topk_ids + curr_input_ids = curr_input_ids.reshape((batch_size * self.value_model_batch_size, -1)) + curr_attention_mask = curr_input_ids != 0 + curr_position_ids = curr_attention_mask.cumsum(1) - curr_attention_mask.long() + + value, _ = self.value_model(curr_input_ids, + attention_mask=curr_attention_mask, + position_ids=curr_position_ids, + past_key_values=tuple( + (t1.repeat(curr_topk_ids.shape[1], 1, 1, 1), + t2.repeat(curr_topk_ids.shape[1], 1, 1, 1)) + for t1, t2 in self.past_key_values), + return_past_key_values=True, + return_dict=True, + output_hidden_states=True,) + value = value.reshape((batch_size, self.value_model_batch_size, -1))[:, :, -1].to(values.dtype) + values = values.scatter_(1, curr_topk_ids, value) + + values = values.scatter_( + 1, topk_ids, values.gather(1, topk_ids) - values.gather(1, topk_ids).mean(-1, keepdim=True) + ) + augmented_outputs += self.beta * values + + return augmented_outputs diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index f0eba412c6..2bc8c309d8 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -54,6 +54,8 @@ "orpo_trainer": ["ORPOTrainer"], "ppo_config": ["PPOConfig"], "ppo_trainer": ["PPOTrainer"], + "vas_config": ["VASConfig"], + "vas_trainer": ["VASTrainer"], "ppov2_config": ["PPOv2Config"], "ppov2_trainer": ["PPOv2Trainer"], 
"reward_config": ["RewardConfig"], @@ -136,6 +138,8 @@ empty_cache, peft_module_casting_to_bf16, ) + from .vas_config import VASConfig + from .vas_trainer import VASTrainer from .xpo_config import XPOConfig from .xpo_trainer import XPOTrainer diff --git a/trl/trainer/vas_config.py b/trl/trainer/vas_config.py new file mode 100644 index 0000000000..372dd87479 --- /dev/null +++ b/trl/trainer/vas_config.py @@ -0,0 +1,61 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass + +from ..trainer.utils import OnPolicyConfig + + +@dataclass +class VASConfig(OnPolicyConfig): + r""" + Configuration class for the [`VASTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): + Name of this experiment. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. + cliprange (`float`, *optional*, defaults to `0.2`): + Clip range. + vf_coef (`float`, *optional*, defaults to `0.1`): + Value function coefficient. + cliprange_value (`float`, *optional*, defaults to `0.2`): + Clip range for the value function. + gamma (`float`, *optional*, defaults to `1.0`): + Discount factor. + lam (`float`, *optional*, defaults to `0.95`): + Lambda value for GAE. + """ + + exp_name: str = os.path.basename(__file__)[: -len(".py")] + reward_model_path: str = "EleutherAI/pythia-160m" + num_vas_epochs: int = 2 + whiten_rewards: bool = False + gamma: float = 1.0 + lam: float = 0.95 + generation_beta: float = 2.0 + save_safetensors: bool = False + num_sample_generations: int = 0 diff --git a/trl/trainer/vas_trainer.py b/trl/trainer/vas_trainer.py new file mode 100644 index 0000000000..b20c390d39 --- /dev/null +++ b/trl/trainer/vas_trainer.py @@ -0,0 +1,609 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback + +from ..core import masked_mean, masked_whiten +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + prepare_deepspeed, + print_rich_table, + truncate_response, +) +from .vas_config import VASConfig +from ..extras.vas_sampler import VASSampler +from .utils import generate_model_card + + +if is_wandb_available(): + import wandb + + +INVALID_LOGPROB = 1.0 + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class ValueWrapper(nn.Module): + def __init__(self, value_model) -> None: + super().__init__() + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) + + def forward(self, input_ids, return_past_key_values=False, **kwargs): + output = self.critic_backbone( + input_ids, + **kwargs, + ) + logits = self.value_model.score(output.hidden_states[-1]) + + if return_past_key_values: + return logits, output.past_key_values + else: + return logits, None + + +class VASTrainer(Trainer): + _tag_names = ["trl", "vas"] + + def __init__( + self, + config: VASConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ], + ref_policy: nn.Module, + reward_model: nn.Module, + value_model: nn.Module, + train_dataset: Dataset, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + # less commonly used + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[List[TrainerCallback]] = None, + ) -> None: + + self.args = config + args = config + self.processing_class = processing_class + + self.ref_policy = ref_policy + self.ref_policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.ref_policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 + + ######### + # calculate various batch sizes + ######### 
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision="fp16") + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [ref_policy, value_model, reward_model]: + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = processing_class.eos_token_id + self.model = ValueWrapper(value_model) + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) 
+ + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=DataCollatorWithPadding(self.processing_class), + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.value_model # save the value model + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_vas_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches * args.num_mini_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = 
self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + scores = [] + sequence_lengths = [] + values = [] + with unwrap_model_for_generation(self.ref_policy, self.accelerator) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + + full_value, _ = forward(model, query_response, processing_class.pad_token_id) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (full_value, value, score, unwrapped_model) + torch.cuda.empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. 
compute rewards + rewards = torch.zeros_like(values) + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for vas_epoch_idx in range(args.num_vas_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + vpred, _ = forward(model, mb_query_responses, processing_class.pad_token_id) + vpred = vpred[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vf_loss = torch.square(vpred - mb_return) + vf_loss = 0.5 * masked_mean(vf_loss, ~padding_mask_p1[micro_batch_inds]) + loss = vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del (vpred, vf_loss, mb_return, mb_advantage, mb_values, mb_responses, mb_query_responses, ) + # fmt: on + torch.cuda.empty_cache() + with torch.no_grad(): + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item() + metrics["loss/value_loss"] = self.accelerator.gather(loss).mean().item() + metrics["num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, 
self.state, self.control) + del scores, metrics + torch.cuda.empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + torch.cuda.empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + returns, + ) + torch.cuda.empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation(self.ref_policy, self.accelerator) as unwrapped_ref_policy: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + sampler = VASSampler( + model=unwrapped_ref_policy, + value_model=self.model, + generation_config=generation_config, + beta=args.generation_beta, + ) + context_length = query.shape[1] + attention_mask = query != processing_class.pad_token_id + input_ids = torch.masked_fill(query, ~attention_mask, 0) + query_response = sampler.generate(input_ids, attention_mask) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
+        """
+        if not self.is_world_process_zero():
+            return
+
+        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
+            base_model = self.model.config._name_or_path
+        else:
+            base_model = None
+
+        tags = tags or []
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if hasattr(self.model.config, "unsloth_version"):
+            tags.append("unsloth")
+
+        citation = textwrap.dedent("""\
+        @article{han2024value,
+            title = {{Value Augmented Sampling for Language Model Alignment and Personalization}},
+            author = {Seungwook Han and Idan Shenfeld and Akash Srivastava and Yoon Kim and Pulkit Agrawal},
+            year = 2024,
+            eprint = {arXiv:2405.06639}
+        }""")
+
+        model_card = generate_model_card(
+            base_model=base_model,
+            model_name=model_name,
+            hub_model_id=self.hub_model_id,
+            dataset_name=dataset_name,
+            tags=tags,
+            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
+            trainer_name="VAS",
+            trainer_citation=citation,
+            paper_title="Value Augmented Sampling for Language Model Alignment and Personalization",
+            paper_id="2405.06639",
+        )
+
+        model_card.save(os.path.join(self.args.output_dir, "README.md"))