diff --git a/docs/source/reward_trainer.mdx b/docs/source/reward_trainer.mdx
index e3ea9e1985..249ec9bde1 100644
--- a/docs/source/reward_trainer.mdx
+++ b/docs/source/reward_trainer.mdx
@@ -6,29 +6,27 @@ Check out a complete flexible example inside [`examples/scripts`](https://github
 
 ## Expected dataset format
 
-The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
 
 <div style=
-Therefore the final dataset object should contain two 4 entries at least if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named:
+Therefore the final dataset object should contain at least 4 entries if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:
 
 - `input_ids_chosen`
 - `attention_mask_chosen`
 - `input_ids_rejected`
 - `attention_mask_rejected`
 
-The `j` and `k` suffixes are used to denote the two sentences in the paired dataset.
-
 ## Using the `RewardTrainer`
 
-After standardizing your dataset, you can use the `RewardTrainer` as a classic Hugging Face Trainer.
-You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
+After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
+You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
 
-### Leveraging the `peft` library to train a reward model
+### Leveraging 🤗 PEFT to train a reward model
 
-Just pass a `peft_config` in the key word arguments of `RewardTrainer`, and the trainer should automatically take care of converting the model into a PEFT model!
+Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
 
 ```python
 from peft import LoraConfig, task_type
diff --git a/examples/scripts/reward_trainer.py b/examples/scripts/reward_trainer.py
index 4382d21ac1..0e6007c28e 100644
--- a/examples/scripts/reward_trainer.py
+++ b/examples/scripts/reward_trainer.py
@@ -15,6 +15,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+from accelerate import Accelerator
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
@@ -26,15 +27,14 @@
 tqdm.pandas()
 
 
-# Define and parse arguments.
 @dataclass
 class ScriptArguments:
     """
-    The name of the Casual LM model we wish to fine with RewardTrainer
+    Hyperparameters to fine-tune a reward model on a given dataset with the `RewardTrainer`.
""" model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"}) - dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"}) dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"}) log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) logging_steps: Optional[int] = field(default=500, metadata={"help": "the number of update steps between two logs"}) @@ -48,6 +48,7 @@ class ScriptArguments: gradient_accumulation_steps: Optional[int] = field( default=16, metadata={"help": "the number of gradient accumulation steps"} ) + gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"}) load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"}) load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"}) use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"}) @@ -65,8 +66,8 @@ class ScriptArguments: quantization_config = BitsAndBytesConfig( load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit ) - # This means: fit the entire model on the GPU:0 - device_map = {"": 0} + # Copy the model to each device + device_map = {"": Accelerator().local_process_index} else: device_map = None quantization_config = None @@ -84,11 +85,8 @@ class ScriptArguments: train_dataset = load_dataset(script_args.dataset_name, split="train") -# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other. -# Then tokenize the dataset. 
+# Tokenize chosen/rejected pairs of inputs
 # Adapt this section to your needs for custom datasets
-
-
 def preprocess_function(examples):
     new_examples = {
         "input_ids_chosen": [],
@@ -97,18 +95,18 @@ def preprocess_function(examples):
         "attention_mask_rejected": [],
     }
     for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
-        tokenized_j = tokenizer(chosen, truncation=True)
-        tokenized_k = tokenizer(rejected, truncation=True)
+        tokenized_chosen = tokenizer(chosen, truncation=True)
+        tokenized_rejected = tokenizer(rejected, truncation=True)
 
-        new_examples["input_ids_chosen"].append(tokenized_j["input_ids"])
-        new_examples["attention_mask_chosen"].append(tokenized_j["attention_mask"])
-        new_examples["input_ids_rejected"].append(tokenized_k["input_ids"])
-        new_examples["attention_mask_rejected"].append(tokenized_k["attention_mask"])
+        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
 
     return new_examples
 
 
-# preprocess the dataset and filter out QAs that are longer than script_args.max_length
+# Preprocess the dataset and filter out examples that are longer than script_args.max_length
 train_dataset = train_dataset.map(
     preprocess_function,
     batched=True,
@@ -141,6 +139,7 @@ def preprocess_function(examples):
     per_device_train_batch_size=script_args.batch_size,
     num_train_epochs=script_args.num_train_epochs,
     gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+    gradient_checkpointing=script_args.gradient_checkpointing,
     learning_rate=script_args.learning_rate,
     report_to="wandb" if script_args.log_with == "wandb" else "tensorboard",
     remove_unused_columns=False,
diff --git a/examples/scripts/sentiment_tuning.py b/examples/scripts/sentiment_tuning.py
index 96a70b324d..25851c47f5 100644
--- a/examples/scripts/sentiment_tuning.py
+++ b/examples/scripts/sentiment_tuning.py
@@ -16,6 +16,7 @@
 from typing import Optional
 
 import torch
+from accelerate import Accelerator
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
@@ -149,7 +150,8 @@ def collator(data):
         task_type="CAUSAL_LM",
     )
     ref_model = None
-    device_map = {"": 0}
+    # Copy the model to each device
+    device_map = {"": Accelerator().local_process_index}
 
 model = trl_model_class.from_pretrained(
     config.model_name,
diff --git a/examples/scripts/sft_trainer.py b/examples/scripts/sft_trainer.py
index 79a59b6005..e2fce2575a 100644
--- a/examples/scripts/sft_trainer.py
+++ b/examples/scripts/sft_trainer.py
@@ -16,6 +16,7 @@
 from typing import Optional
 
 import torch
+from accelerate import Accelerator
 from datasets import load_dataset
 from peft import LoraConfig
 from tqdm import tqdm
@@ -75,8 +76,8 @@ class ScriptArguments:
     quantization_config = BitsAndBytesConfig(
         load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
     )
-    # This means: fit the entire model on the GPU:0
-    device_map = {"": 0}
+    # Copy the model to each device
+    device_map = {"": Accelerator().local_process_index}
     torch_dtype = torch.bfloat16
 else:
     device_map = None
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 3ba4912afd..b32466beac 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -113,7 +113,7 @@ def __init__(
             )
         elif is_peft_available() and peft_config is not None:
getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): - model = prepare_model_for_int8_training(model) + model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing) model = get_peft_model(model, peft_config) diff --git a/trl/trainer/training_configs.py b/trl/trainer/training_configs.py index a8622b1904..cde659a927 100644 --- a/trl/trainer/training_configs.py +++ b/trl/trainer/training_configs.py @@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments): Parameters: max_length (`int`, *optional*, defaults to `None`): The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ max_length: Optional[int] = field( @@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments): "help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator." }, ) + gradient_checkpointing: Optional[bool] = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + )