
Enable gradient checkpointing to be disabled for reward modelling #725

Merged: 6 commits merged into main from disable-grad-ckpt on Sep 6, 2023

Conversation

lewtun (Member) commented Sep 1, 2023

Edit: to be merged after #726

This PR allows users to disable gradient checkpointing for PEFT models in the RewardTrainer. The motivation is to bypass the `RuntimeError: Expected to mark a variable ready only once.` issue noted in #480, at the expense of using more VRAM. More generally, I think it's good when a library allows users to configure settings like these, since there are trade-offs involved (i.e. trading computation time for memory).

With this change, it is now possible to train PEFT reward models in multi-GPU contexts as follows:

```bash
# The example script didn't work when num_gpus > 1
accelerate launch --multi_gpu --num_processes 4 examples/scripts/reward_trainer.py --use_peft --load_in_8bit --gradient_checkpointing false
```

When gradient checkpointing is activated, the script still fails due to #480 (which requires a deeper fix):

```bash
# Current behaviour is the same
accelerate launch --multi_gpu --num_processes 4 examples/scripts/reward_trainer.py --use_peft --load_in_8bit --gradient_checkpointing true
```

```
Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/reward_trainer.py", line 177, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1809, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2665, in training_step
    self.accelerator.backward(loss)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 1923, in backward
    loss.backward(**kwargs)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/fsx/lewis/miniconda/envs/trl/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 93 has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration. You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.
```

Question: this change overrides the default value of the RewardTrainer from `true` to `false`, which may lead to unexpected behaviour with user code. Should I add a warning in the logger?

Update: I think I can solve the question above with my proposal to have dedicated training arguments per trainer. I'll open a separate PR for that.
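
For reference, here is a minimal sketch of how the new flag can be used from Python rather than the CLI. This is illustrative only: the model name, batch size, LoRA settings, and the toy dataset are placeholders rather than content from this PR; only the `gradient_checkpointing=False` argument reflects the change introduced here.

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_name = "facebook/opt-350m"  # placeholder backbone
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

# Toy dataset already tokenized into the paired columns the RewardTrainer expects.
train_dataset = Dataset.from_dict({
    "input_ids_chosen": [[1, 2, 3]],
    "attention_mask_chosen": [[1, 1, 1]],
    "input_ids_rejected": [[1, 2]],
    "attention_mask_rejected": [[1, 1]],
})

# Turning checkpointing off sidesteps the DDP "marked ready twice" error (#480)
# at the cost of extra VRAM; the default (True) keeps the previous behaviour.
training_args = RewardConfig(
    output_dir="reward_model",
    per_device_train_batch_size=4,
    max_length=512,
    gradient_checkpointing=False,
)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    peft_config=LoraConfig(task_type="SEQ_CLS", r=16, lora_alpha=32),
)
trainer.train()
```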

```diff
@@ -149,7 +150,8 @@ def collator(data):
         task_type="CAUSAL_LM",
     )
     ref_model = None
-    device_map = {"": 0}
+    # Copy the model to each device
+    device_map = {"": Accelerator().local_process_index}
```
lewtun (Member Author):

I noticed none of the example scripts actually work in multi-GPU settings unless we add this trick

Contributor:

Makes sense!
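
A hedged sketch of what this trick looks like at the model-loading call. The model name and the `load_in_8bit` kwarg below are illustrative (they mirror the example script's flags) rather than lifted from the diff:

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# With device_map={"": 0} every DDP process would load the model onto cuda:0.
# Using the local process index gives each process a full copy on its own GPU.
current_device = Accelerator().local_process_index
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",             # placeholder model name
    load_in_8bit=True,               # matches --load_in_8bit in the example script
    device_map={"": current_device},
)
```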

HuggingFaceDocBuilderDev commented Sep 1, 2023

The documentation is not available anymore as the PR was closed or merged.

lvwerra (Member) left a comment

Thanks for adding this. Just a few small nits, looks generally good to me.

```diff
     train_dataset = train_dataset.map(
         preprocess_function,
         batched=True,
+        num_proc=4,
     )
     train_dataset = train_dataset.filter(
         lambda x: len(x["input_ids_chosen"]) <= script_args.seq_length
-        and len(x["input_ids_rejected"]) <= script_args.seq_length
+        and len(x["input_ids_rejected"]) <= script_args.seq_length,
+        num_proc=4,
```
Member:
why do you change to four?

lewtun (Member Author):
Because the HHH dataset is quite large and I was bored of waiting :) Will revert it back to 1

younesbelkada (Contributor) left a comment
Thanks a lot for the workaround and for adding this! LGTM!

```diff
@@ -6,29 +6,27 @@ Check out a complete flexible example inside [`examples/scripts`](https://github

 ## Expected dataset format

-The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
```
lewtun (Member Author):
@lvwerra @younesbelkada FYI I made a few minor tweaks to the docs since your last review
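
For readers landing on the updated docs, here is a rough sketch of the kind of preprocessing the example script's filter (shown earlier) assumes: each raw `chosen`/`rejected` pair is tokenized into the paired columns the `RewardTrainer` consumes. The tokenizer name and the exact helper are illustrative, not copied from the script:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")  # placeholder tokenizer

def preprocess_function(examples):
    # Tokenize the preferred ("chosen") and dispreferred ("rejected") texts into
    # the paired columns expected by the RewardTrainer.
    new_examples = {
        "input_ids_chosen": [], "attention_mask_chosen": [],
        "input_ids_rejected": [], "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tok_chosen = tokenizer(chosen)
        tok_rejected = tokenizer(rejected)
        new_examples["input_ids_chosen"].append(tok_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tok_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tok_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tok_rejected["attention_mask"])
    return new_examples

train_dataset = load_dataset("Anthropic/hh-rlhf", split="train")
train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=4)
```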

```diff
@@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments):
             "help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."
         },
     )
+    gradient_checkpointing: Optional[bool] = field(
+        default=True,
```
lewtun (Member Author):
Here we override the default value of transformers.TrainingArguments to be True for backwards compatibility with how PEFT prepares k-bit models
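
A hedged sketch of the pattern being described, using the PEFT helper that prepares quantized models. Whether the trainer threads the flag through in exactly this shape is paraphrased rather than quoted from the diff, and the wrapper name `maybe_prepare_kbit_model` is hypothetical:

```python
from peft import prepare_model_for_kbit_training

def maybe_prepare_kbit_model(model, gradient_checkpointing: bool):
    # With default=True in RewardConfig, the k-bit preparation step still enables
    # gradient checkpointing unless the user explicitly opts out.
    return prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=gradient_checkpointing,
    )
```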

younesbelkada (Contributor) left a comment
Thanks a lot @lewtun !!

lewtun merged commit 453c4ec into main on Sep 6, 2023
lewtun deleted the disable-grad-ckpt branch on September 6, 2023 12:08
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
Enable gradient checkpointing to be disabled for reward modelling (huggingface#725)

* Enable gradient checkpointing to be disabled for reward modelling

* Update examples/scripts/reward_trainer.py

Co-authored-by: Leandro von Werra <[email protected]>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* Tidy docs

* Remove commas

---------

Co-authored-by: Leandro von Werra <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
Enable gradient checkpointing to be disabled for reward modelling (huggingface#725)