diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx
index e3ea9e1985..249ec9bde1 100644
--- a/docs/source/reward_trainer.mdx
+++ b/docs/source/reward_trainer.mdx
@@ -6,29 +6,27 @@ Check out a complete flexible example inside [`examples/scripts`](https://github
## Expected dataset format
-The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
-Therefore the final dataset object should contain two 4 entries at least if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named:
+Therefore the final dataset object should contain at least four entries if you use the default [`RewardDataCollatorWithPadding`] data collator; a short preprocessing sketch follows the list. The entries should be named:
- `input_ids_chosen`
- `attention_mask_chosen`
- `input_ids_rejected`
- `attention_mask_rejected`
-The `j` and `k` suffixes are used to denote the two sentences in the paired dataset.
-
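+
+Here is that sketch, adapted from the `examples/scripts/reward_trainer.py` script in this repository: the `chosen`/`rejected` pairs of [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) are tokenized into the four entries above (the `facebook/opt-350m` tokenizer is only an illustrative choice):
+
+```python
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+
+def preprocess_function(examples):
+    new_examples = {
+        "input_ids_chosen": [],
+        "attention_mask_chosen": [],
+        "input_ids_rejected": [],
+        "attention_mask_rejected": [],
+    }
+    # Tokenize each chosen/rejected pair into the four expected columns
+    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
+        tokenized_chosen = tokenizer(chosen, truncation=True)
+        tokenized_rejected = tokenizer(rejected, truncation=True)
+        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
+    return new_examples
+
+
+train_dataset = load_dataset("Anthropic/hh-rlhf", split="train")
+train_dataset = train_dataset.map(preprocess_function, batched=True)
+```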
## Using the `RewardTrainer`
-After standardizing your dataset, you can use the `RewardTrainer` as a classic Hugging Face Trainer.
-You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
+After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
+You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] that configures the training hyperparameters.
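+
+A rough sketch is shown below; it assumes `train_dataset` is already in the format described above, that [`RewardConfig`] is importable from the top-level `trl` package, and that the checkpoint, `output_dir`, and hyperparameter values are placeholders to adapt to your setup:
+
+```python
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from trl import RewardConfig, RewardTrainer
+
+# A reward model is a sequence classifier with a single scalar output
+model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-350m", num_labels=1)
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+training_args = RewardConfig(
+    output_dir="reward_model",
+    per_device_train_batch_size=8,
+    num_train_epochs=1,
+    gradient_checkpointing=True,
+    max_length=512,  # required when using the default RewardDataCollatorWithPadding
+)
+
+trainer = RewardTrainer(
+    model=model,
+    args=training_args,
+    tokenizer=tokenizer,
+    train_dataset=train_dataset,
+)
+trainer.train()
+```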
-### Leveraging the `peft` library to train a reward model
+### Leveraging 🤗 PEFT to train a reward model
-Just pass a `peft_config` in the key word arguments of `RewardTrainer`, and the trainer should automatically take care of converting the model into a PEFT model!
+Just pass a `peft_config` as a keyword argument to [`RewardTrainer`], and the trainer will automatically take care of converting the model into a PEFT model!
```python
from peft import LoraConfig, TaskType
diff --git a/examples/scripts/reward_trainer.py b/examples/scripts/reward_trainer.py
index 4382d21ac1..0e6007c28e 100644
--- a/examples/scripts/reward_trainer.py
+++ b/examples/scripts/reward_trainer.py
@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from typing import Optional
+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -26,15 +27,14 @@
tqdm.pandas()
-# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
- The name of the Casual LM model we wish to fine with RewardTrainer
+ Hyperparameters to fine-tune a reward model on a given dataset with the `RewardTrainer`.
"""
model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
- dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the model name"})
+ dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
logging_steps: Optional[int] = field(default=500, metadata={"help": "the number of update steps between two logs"})
@@ -48,6 +48,7 @@ class ScriptArguments:
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
+ gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
@@ -65,8 +66,8 @@ class ScriptArguments:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
- # This means: fit the entire model on the GPU:0
- device_map = {"": 0}
+ # Copy the model to each device
+ device_map = {"": Accelerator().local_process_index}
else:
device_map = None
quantization_config = None
@@ -84,11 +85,8 @@ class ScriptArguments:
train_dataset = load_dataset(script_args.dataset_name, split="train")
-# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other.
-# Then tokenize the dataset.
+# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
-
-
def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
@@ -97,18 +95,18 @@ def preprocess_function(examples):
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
- tokenized_j = tokenizer(chosen, truncation=True)
- tokenized_k = tokenizer(rejected, truncation=True)
+ tokenized_chosen = tokenizer(chosen, truncation=True)
+ tokenized_rejected = tokenizer(rejected, truncation=True)
- new_examples["input_ids_chosen"].append(tokenized_j["input_ids"])
- new_examples["attention_mask_chosen"].append(tokenized_j["attention_mask"])
- new_examples["input_ids_rejected"].append(tokenized_k["input_ids"])
- new_examples["attention_mask_rejected"].append(tokenized_k["attention_mask"])
+ new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+ new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+ new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+ new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples
-# preprocess the dataset and filter out QAs that are longer than script_args.max_length
+# Preprocess the dataset and filter out examples that are longer than script_args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
@@ -141,6 +139,7 @@ def preprocess_function(examples):
per_device_train_batch_size=script_args.batch_size,
num_train_epochs=script_args.num_train_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+ gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
report_to="wandb" if script_args.log_with == "wandb" else "tensorboard",
remove_unused_columns=False,
diff --git a/examples/scripts/sentiment_tuning.py b/examples/scripts/sentiment_tuning.py
index 96a70b324d..25851c47f5 100644
--- a/examples/scripts/sentiment_tuning.py
+++ b/examples/scripts/sentiment_tuning.py
@@ -16,6 +16,7 @@
from typing import Optional
import torch
+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -149,7 +150,8 @@ def collator(data):
task_type="CAUSAL_LM",
)
ref_model = None
- device_map = {"": 0}
+ # Copy the model to each device
+ device_map = {"": Accelerator().local_process_index}
model = trl_model_class.from_pretrained(
config.model_name,
diff --git a/examples/scripts/sft_trainer.py b/examples/scripts/sft_trainer.py
index 79a59b6005..e2fce2575a 100644
--- a/examples/scripts/sft_trainer.py
+++ b/examples/scripts/sft_trainer.py
@@ -16,6 +16,7 @@
from typing import Optional
import torch
+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -75,8 +76,8 @@ class ScriptArguments:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
- # This means: fit the entire model on the GPU:0
- device_map = {"": 0}
+ # Copy the model to each device
+ device_map = {"": Accelerator().local_process_index}
torch_dtype = torch.bfloat16
else:
device_map = None
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 3ba4912afd..b32466beac 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -113,7 +113,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
- model = prepare_model_for_int8_training(model)
+ model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
model = get_peft_model(model, peft_config)
diff --git a/trl/trainer/training_configs.py b/trl/trainer/training_configs.py
index a8622b1904..cde659a927 100644
--- a/trl/trainer/training_configs.py
+++ b/trl/trainer/training_configs.py
@@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments):
Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
+ gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
"""
max_length: Optional[int] = field(
@@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments):
"help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."
},
)
+ gradient_checkpointing: Optional[bool] = field(
+ default=True,
+ metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of a slower backward pass."
+ },
+ )