
Error with Multi-GPU peft Reward Training #480

Closed
mnoukhov opened this issue Jun 29, 2023 · 15 comments
@mnoukhov
Contributor

mnoukhov commented Jun 29, 2023

There is an issue when you combine all four:

  • peft quantization
  • gradient checkpointing
  • multi-gpu ddp
  • two gradients on the same parameters (as you have in the loss function for Reward Trainer)
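
The fourth point refers to the pairwise loss in RewardTrainer, which runs the same model twice per batch, once on the chosen and once on the rejected completions. A minimal sketch of that pattern (illustrative only, with hypothetical tensor names, not the exact trl implementation):

import torch.nn.functional as F

# Two forward passes through the same (checkpointed, DDP-wrapped) parameters
rewards_chosen = model(input_ids=chosen_ids, attention_mask=chosen_mask).logits
rewards_rejected = model(input_ids=rejected_ids, attention_mask=rejected_mask).logits

# Pairwise ranking loss: the single backward pass then touches every shared parameter twice
loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
loss.backward()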

This is reproducible if you correctly enable gradient checkpointing in examples/multi-adapter-rl, as shown in PR #479, and then run in a multi-GPU setup:

accelerate launch --multi_gpu reward_modeling.py --gradient_checkpointing True

You will receive the error:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module
graph does not change over iterations. Parameter at index 127 has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration. You can set the environment variable
TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print parameter names for further debugging.

With TORCH_DISTRIBUTED_DEBUG=DETAIL, we find the affected parameter is a LoRA parameter. It is not related to pytorch/pytorch#60844 because find_unused_parameters is set to False.
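
For reference, the environment variable can simply be prepended to the same launch command from above:

TORCH_DISTRIBUTED_DEBUG=DETAIL accelerate launch --multi_gpu reward_modeling.py --gradient_checkpointing True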

This is likely a problem between peft and accelerate/DDP, but I'm putting the issue here because it affects RewardTrainer, and quantization + multi-GPU + gradient checkpointing is a common combination.

lvwerra assigned lvwerra and younesbelkada and unassigned lvwerra Jul 3, 2023
@Receiling

We have the same problem.
With the training scripts from peft, multi-gpu + gradient checkpointing is working normally.
But with RewardTrainer from trl, we have the aforementioned error, i.e., "RuntimeError: Expected to mark a variable ready only once. "

@younesbelkada
Contributor

Hi @mnoukhov @Receiling,
After discussing with @pacman100, the fix should go in transformers and should be similar to huggingface/transformers#24247.
We should add a flag to the training arguments and make the default behavior use use_reentrant=True for backward compatibility.

@lewtun
Member

lewtun commented Sep 5, 2023

Gently pinging @younesbelkada if there's any update on the transformers side to support gradient checkpointing with PEFT & DDP?

For context, the Llama 2 paper shows that training large reward models is an important ingredient in RLHF and enabling the RewardTrainer to scale to 70B params would be awesome 🚀 !

@mnoukhov
Contributor Author

No positive results yet, but a new negative result: this isn't related to #728. Using that fix, I still encounter the problem. I also checked that the fix from huggingface/peft#899 doesn't solve it; that issue also prevents multi-GPU reward training, since we need to use modules_to_save=["score"], but it isn't related to gradient checkpointing.
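
For context, a minimal sketch of the kind of peft config that refers to, assuming a sequence-classification reward model whose scalar head is named score (the LoRA hyperparameters below are placeholders):

from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,                       # placeholder rank
    lora_alpha=32,              # placeholder scaling
    lora_dropout=0.05,
    modules_to_save=["score"],  # keep the reward head trainable and saved alongside the adapters
)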

@mollynatsu

I'm getting the same error with accelerate launch --multi_gpu reward_modeling.py --load_in_8bit --gradient_checkpointing True with llama2-7b; however, if I disable either gradient_checkpointing or load_in_8bit, I get OOM.

@ChrisChros123

ChrisChros123 commented Sep 19, 2023

Hi all,
I also encounter this issue.
I am deriving my attempts from this example:

https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py
and I am fine-tuning meta-llama/Llama-2-7b-hf

But I am giving the TrainingArguments a path to ds_config.json where I specify

{
    "zero_optimization": {
        "stage": 3,
        ....

Since I encountered incompatible dtype issues, I used the workaround mentioned here:
huggingface/transformers#24445 (comment)
like this:
import deepspeed
from peft import get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)  # I added this to the original reward_modeling.py to avoid OOM
model = get_peft_model(model, peft_config)
...
deepspeed.zero.Init()

I run the whole thing with
torchrun --nproc_per_node=8 reward_modelling_gpu.py

I enabled gradient checkpointing to prevent the OOM error:

gradient_checkpointing: Optional[bool] = field(default=True)

I also had to set use_reentrant to True in /usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py

def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = True,
...

To avoid this warning:
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

Here is my logging now:

Traceback (most recent call last):
  File "/home/jupyter-data_nlp/reward_modelling_gpu.py", line 453, in <module>
    trainer.train(script_args.resume_from_checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1553, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1835, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2690, in training_step
    self.accelerator.backward(loss)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1923, in backward
    loss.backward(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 288, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 127 with name base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

@mnoukhov
Contributor Author

I have also found that use_reentrant=True does not fix the issue. It is likely an issue with how gradient checkpointing is implemented and how that interacts with quantized models in peft. I'm currently training reward models either on a single GPU or multi-GPU with bf16 + gradient checkpointing but without quantization.
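
As an illustration of that working setup, the launch looks roughly like this (the --bf16 flag name is an assumption about the script's arguments):

accelerate launch --multi_gpu reward_modeling.py --gradient_checkpointing True --bf16 True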

@lvwerra
Member

lvwerra commented Oct 2, 2023

Which models/setup are you using?

@mnoukhov
Contributor Author

mnoukhov commented Oct 5, 2023

I am using huggyllama/llama-7b

EDIT: I misunderstood "unused parameters". As long as the frozen parameters are used in the backward computation, they are fine. So if the model has two peft adapters, then there could be unused parameters, but not if there's just one.

After looking into this, I am mostly convinced that DDP does not work with gradient checkpointing when there are unused parameters in the forward computation or two forward passes. This means you should not use gradient checkpointing with peft and DDP regardless of the number of forward passes, but it will only explicitly fail with two forward passes on the same parameters.

This is a note from the Pytorch docs

DistributedDataParallel currently offers limited support for gradient checkpointing with torch.utils.checkpoint(). DDP will work as expected when there are no unused parameters in the model and each layer is checkpointed at most once (make sure you are not passing find_unused_parameters=True to DDP). We currently do not support the case where a layer is checkpointed multiple times, or when there are unused parameters in the checkpointed model.

Two peft adapters can create unused parameters, and two forward passes (as in the reward trainer) will cause a layer to be checkpointed twice.

I think a next step is to check whether DDP training with peft is worse with gradient checkpointing even without two forward passes. If so, peft should probably add a warning or error when a user tries to combine these things until DDP supports the use case. Or maybe switch to DataParallel although it feels like it will be deprecated.


github-actions bot commented Nov 7, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@lewtun
Member

lewtun commented Nov 8, 2023

I think this issue is resolved by #912, right @younesbelkada?

@younesbelkada
Contributor

Yes indeed! If you use the latest releases of transformers, trl, and peft, simply pass gradient_checkpointing_kwargs={"use_reentrant": False} and it should be resolved.
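
For anyone landing here, a minimal sketch of how that looks with the standard transformers training arguments (the other argument values are placeholders):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="reward_model",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # avoids the DDP "marked ready twice" error
    bf16=True,
)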

@mnoukhov
Contributor Author

I tested this and it is resolved by #912, thank you @younesbelkada !

Sorry I didn't do this myself earlier. I tried the use_reentrant trick you mentioned but it didn't work; I must have missed something in my tests. Thanks for doing it!

@younesbelkada
Contributor

Awesome that it worked @mnoukhov , thanks!

@Abyss-J

Abyss-J commented Jan 31, 2024

Same error here; setting gradient_checkpointing to False works, thanks!
