-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable gradient checkpointing to be disabled for reward modelling #725
Changes from all commits
b6a033a
0f2cbc2
e9f05db
de062d3
24f882d
3f106af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from typing import Optional | ||
|
||
import torch | ||
from accelerate import Accelerator | ||
from datasets import load_dataset | ||
from peft import LoraConfig | ||
from tqdm import tqdm | ||
|
@@ -149,7 +150,8 @@ def collator(data): | |
task_type="CAUSAL_LM", | ||
) | ||
ref_model = None | ||
device_map = {"": 0} | ||
# Copy the model to each device | ||
device_map = {"": Accelerator().local_process_index} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I noticed none of the example scripts actually work in multi-GPU settings unless we add this trick There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense! |
||
|
||
model = trl_model_class.from_pretrained( | ||
config.model_name, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments): | |
Parameters: | ||
max_length (`int`, *optional*, defaults to `None`): | ||
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. | ||
gradient_checkpointing (`bool`, *optional*, defaults to `True`): | ||
If True, use gradient checkpointing to save memory at the expense of slower backward pass. | ||
""" | ||
|
||
max_length: Optional[int] = field( | ||
|
@@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments): | |
"help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator." | ||
}, | ||
) | ||
gradient_checkpointing: Optional[bool] = field( | ||
default=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we override the default value of |
||
metadata={ | ||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." | ||
}, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lvwerra @younesbelkada FYI I made a few minor tweaks to the docs since your last review