-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable gradient checkpointing to be disabled for reward modelling #725
Conversation
@@ -149,7 +150,8 @@ def collator(data): | |||
task_type="CAUSAL_LM", | |||
) | |||
ref_model = None | |||
device_map = {"": 0} | |||
# Copy the model to each device | |||
device_map = {"": Accelerator().local_process_index} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed none of the example scripts actually work in multi-GPU settings unless we add this trick
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense!
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this. Just a few small nits, looks generally good to me.
examples/scripts/reward_trainer.py
Outdated
train_dataset = train_dataset.map( | ||
preprocess_function, | ||
batched=True, | ||
num_proc=4, | ||
) | ||
train_dataset = train_dataset.filter( | ||
lambda x: len(x["input_ids_chosen"]) <= script_args.seq_length | ||
and len(x["input_ids_rejected"]) <= script_args.seq_length | ||
and len(x["input_ids_rejected"]) <= script_args.seq_length, | ||
num_proc=4, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you change to four?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because the HHH dataset is quite large and I was bored of waiting :) Will revert it back to 1
Co-authored-by: Leandro von Werra <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for workaround and for adding this! LGTM!
@@ -149,7 +150,8 @@ def collator(data): | |||
task_type="CAUSAL_LM", | |||
) | |||
ref_model = None | |||
device_map = {"": 0} | |||
# Copy the model to each device | |||
device_map = {"": Accelerator().local_process_index} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense!
Co-authored-by: Younes Belkada <[email protected]>
@@ -6,29 +6,27 @@ Check out a complete flexible example inside [`examples/scripts`](https://github | |||
|
|||
## Expected dataset format | |||
|
|||
The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: | |||
The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lvwerra @younesbelkada FYI I made a few minor tweaks to the docs since your last review
@@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments): | |||
"help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator." | |||
}, | |||
) | |||
gradient_checkpointing: Optional[bool] = field( | |||
default=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we override the default value of transformers.TrainingArguments
to be True
for backwards compatibility with how PEFT prepares k-bit models
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @lewtun !!
…ggingface#725) * Enable gradient checkpointing to be disabled for reward modelling * Update examples/scripts/reward_trainer.py Co-authored-by: Leandro von Werra <[email protected]> * Apply suggestions from code review Co-authored-by: Younes Belkada <[email protected]> * Tidy docs * Remove commas --------- Co-authored-by: Leandro von Werra <[email protected]> Co-authored-by: Younes Belkada <[email protected]>
…ggingface#725) * Enable gradient checkpointing to be disabled for reward modelling * Update examples/scripts/reward_trainer.py Co-authored-by: Leandro von Werra <[email protected]> * Apply suggestions from code review Co-authored-by: Younes Belkada <[email protected]> * Tidy docs * Remove commas --------- Co-authored-by: Leandro von Werra <[email protected]> Co-authored-by: Younes Belkada <[email protected]>
Edit: to be merged after #726
This PR allows users to disable gradient checkpointing for PEFT models in the
RewardTrainer
. The motivation for this is to bypass theRuntimeError: Expected to mark a variable ready only once.
issue noted in #480 at the expense of using more vRAM. More generally, I think it's good when a library allows users to configure settings like these, since there are tradeoffs involved (i.e. computation time for memory).With this change, it is now possible to train PEFT reward models in multi-GPU contexts as follows:
When gradient checkpointing is activated the script still fails due to #480 (which requires a deeper-level fix):
Question: this change overrides the default value of the
RewardTrainer
fromtrue
tofalse
which may lead to unexpected behaviour with user code. Should I add a warning in the logger?Update: I think I can solve the question above with my proposal to have dedicated training arguments per trainer. I'll open a separate PR for that.