[RewardTrainer] Enable gradient checkpointing for all multi-GPU training modes #835
Comments
Thanks for the deep investigation and report! Will have a look in the next few weeks.
@younesbelkada related: are we correctly doing gradient checkpointing with
e.g. as part of the SFT Trainer (https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L150)? In huggingface/transformers#25841 (comment), it seems like you're saying that we should do things the other way around.
I ran into the exact same issue for full training with ZeRO-3, and I fixed it by adding
P.S. I don't quite understand the reason.
I hit this issue and fixed it by using
Hi everyone,
How is this different from setting gradient_checkpointing=False?
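For context on the question above: gradient checkpointing does not change what is computed, only when. Activations inside a checkpointed segment are dropped during the forward pass and recomputed during backward, trading extra compute for lower activation memory. The following minimal pure-PyTorch sketch (the toy block is hypothetical, not the RewardTrainer's model) checks that the gradients match with and without checkpointing:

```python
import torch
from torch.utils.checkpoint import checkpoint


def grads_with_and_without_checkpointing(seed: int = 0):
    """Return (plain_grads, ckpt_grads) for the same toy block and input."""
    torch.manual_seed(seed)
    # Hypothetical toy block standing in for one transformer layer.
    block = torch.nn.Sequential(
        torch.nn.Linear(8, 16), torch.nn.GELU(), torch.nn.Linear(16, 8)
    )
    x = torch.randn(4, 8, requires_grad=True)

    # 1) Regular forward/backward: all intermediate activations are stored.
    block.zero_grad()
    block(x).pow(2).sum().backward()
    plain = [p.grad.clone() for p in block.parameters()]

    # 2) Checkpointed forward/backward: activations inside `block` are
    #    recomputed during backward instead of being kept in memory.
    block.zero_grad()
    checkpoint(block, x, use_reentrant=False).pow(2).sum().backward()
    ckpt = [p.grad.clone() for p in block.parameters()]
    return plain, ckpt


plain, ckpt = grads_with_and_without_checkpointing()
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))  # True
```

So gradient_checkpointing=False should yield the same gradients but keeps every activation alive, which raises peak memory; the checkpointed path is the one that exercises the bugs tracked here.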
I encountered
We currently have a few issues like #831 and #480 where gradient checkpointing + DDP does not work with the RewardTrainer. Let's use this issue to collect the various training modes we'd like to support and track the status of their fixes. Ideally, these would all be treated as integration tests so that we catch any regressions in the future.
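One way to turn the modes below into parametrized integration tests is to describe each mode as data. This is a hypothetical sketch: the TrainingMode class, its field names and the helper are illustrative and not part of TRL. It encodes only the settings that distinguish the modes in this issue, with mode 2 split into its 4-bit and 8-bit variants:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TrainingMode:
    """Hypothetical description of one multi-GPU RewardTrainer setup to test."""
    name: str
    use_lora: bool
    load_in_bits: Optional[int]  # 4 or 8 for quantized base weights, else None
    zero_stage: Optional[int]    # DeepSpeed ZeRO stage, or None for plain DDP


MODES = [
    TrainingMode("lora", use_lora=True, load_in_bits=None, zero_stage=None),
    TrainingMode("lora-4bit", use_lora=True, load_in_bits=4, zero_stage=None),
    TrainingMode("lora-8bit", use_lora=True, load_in_bits=8, zero_stage=None),
    TrainingMode("full-ddp", use_lora=False, load_in_bits=None, zero_stage=None),
    TrainingMode("full-zero1", use_lora=False, load_in_bits=None, zero_stage=1),
    TrainingMode("full-zero2", use_lora=False, load_in_bits=None, zero_stage=2),
    TrainingMode("full-zero3", use_lora=False, load_in_bits=None, zero_stage=3),
]


def training_args_overrides(mode: TrainingMode) -> dict:
    """Extra TrainingArguments kwargs for a mode (all modes use checkpointing)."""
    overrides = {"gradient_checkpointing": True}
    if mode.zero_stage is not None:
        # The ZeRO reproductions in this issue also add bf16=True.
        overrides["bf16"] = True
    return overrides
```

With pytest, this list could feed @pytest.mark.parametrize("mode", MODES, ids=lambda m: m.name), with each test launching a short multi-GPU RewardTrainer run for that mode.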
More details on each mode are provided below.
1. LoRA without quantization
Currently, this mode gives a warning that gradients are None on the inputs (i.e. the model doesn't learn):
Command to reproduce:
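A "gradients are None" symptom is consistent with how reentrant activation checkpointing behaves in PyTorch: if no tensor passed into a checkpointed segment requires grad (as happens when the embeddings are frozen for LoRA), PyTorch warns that none of the inputs require grad, and the trainable weights inside that segment never receive gradients. The self-contained sketch below reproduces that effect on toy modules; frozen_embed, block and head are hypothetical stand-ins for the frozen embeddings, a decoder layer holding adapter weights, and the reward head. It also shows two common workarounds: making the segment input require grad, or switching to non-reentrant checkpointing.

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint


def block_grad_is_none(use_reentrant: bool, input_requires_grad: bool) -> bool:
    """Return True if the checkpointed block's weight ends up with no gradient."""
    torch.manual_seed(0)
    frozen_embed = torch.nn.Embedding(10, 8)
    frozen_embed.weight.requires_grad_(False)  # frozen, as with LoRA
    block = torch.nn.Linear(8, 8)              # trainable weights inside the checkpoint
    head = torch.nn.Linear(8, 1)               # trainable head outside the checkpoint

    hidden = frozen_embed(torch.tensor([[1, 2, 3]]))  # requires_grad is False here
    if input_requires_grad:
        hidden.requires_grad_(True)

    with warnings.catch_warnings():
        # Reentrant checkpointing warns when no input requires grad.
        warnings.simplefilter("ignore")
        out = checkpoint(block, hidden, use_reentrant=use_reentrant)
    head(out).sum().backward()
    return block.weight.grad is None


print(block_grad_is_none(use_reentrant=True, input_requires_grad=False))   # True: block never learns
print(block_grad_is_none(use_reentrant=True, input_requires_grad=True))    # False
print(block_grad_is_none(use_reentrant=False, input_requires_grad=False))  # False
```

Recent transformers versions expose the non-reentrant variant through gradient_checkpointing_kwargs={"use_reentrant": False}; whether that alone fixes the multi-GPU RewardTrainer cases here is an assumption that still needs testing.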
2. LoRA with 4-bit or 8-bit quantization
Commands to reproduce:
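With 4-bit or 8-bit base weights the embeddings are frozen as well, so reentrant checkpointing again needs an input that requires grad. As far as I know, transformers' PreTrainedModel.enable_input_require_grads() (which peft's prepare_model_for_kbit_training calls when gradient checkpointing is requested) handles this with a forward hook on the input embeddings that marks their output as requiring grad. Below is a standalone PyTorch sketch of that hook mechanism on a toy model (ToyModel and make_inputs_require_grad are hypothetical names); it is not the transformers code itself.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in: frozen embeddings, a checkpointed trainable layer, a head."""

    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 8)
        self.embed.weight.requires_grad_(False)  # frozen, like quantized base weights
        self.layer = torch.nn.Linear(8, 8)       # stands in for trainable adapter weights
        self.head = torch.nn.Linear(8, 1)

    def forward(self, ids):
        hidden = self.embed(ids)
        # Reentrant checkpointing only tracks gradients if an input requires grad.
        hidden = checkpoint(self.layer, hidden, use_reentrant=True)
        return self.head(hidden)


def make_inputs_require_grad(module, inputs, output):
    # Same idea as enable_input_require_grads(): mark the embedding output as
    # requiring grad so the checkpointed segment stays in the autograd graph.
    output.requires_grad_(True)


model = ToyModel()
handle = model.embed.register_forward_hook(make_inputs_require_grad)
model(torch.tensor([[1, 2, 3]])).sum().backward()
print(model.layer.weight.grad is not None)  # True
handle.remove()
```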
3. Full training with DDP
The error is similar to the one for 4-bit / 8-bit with DDP:
Command to reproduce:
4. Full training with ZeRO-1
ZeRO-1 works fine and can be tested with the following (add bf16=True to the TrainingArguments):
5. Full training with ZeRO-2
ZeRO-2 throws the following error:
Command to reproduce (add bf16=True to the TrainingArguments):
6. Full training with ZeRO-3
Error:
Command to reproduce (add bf16=True to the TrainingArguments):
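The three ZeRO modes differ mainly in the DeepSpeed stage. As a hypothetical sketch (the helper name and values are illustrative, not the configs behind the reproduction commands), a per-stage config can be built as a dict and passed through the deepspeed argument of TrainingArguments, which accepts either a JSON file path or a dict. The "auto" values are filled in by the transformers DeepSpeed integration from the TrainingArguments, including bf16=True.

```python
import json


def zero_config(stage: int) -> dict:
    """Build a minimal DeepSpeed config dict for ZeRO stage 1, 2 or 3 (illustrative)."""
    if stage not in (1, 2, 3):
        raise ValueError(f"unsupported ZeRO stage: {stage}")
    config = {
        # "auto" lets the transformers integration fill these from TrainingArguments.
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "bf16": {"enabled": "auto"},  # follows bf16=True in TrainingArguments
        "zero_optimization": {"stage": stage},
    }
    if stage == 3:
        # Gather the sharded weights into one 16-bit state dict when saving.
        config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
    return config


for stage in (1, 2, 3):
    print(json.dumps(zero_config(stage)))
```

A run would then use something like TrainingArguments(output_dir="reward_model", bf16=True, gradient_checkpointing=True, deepspeed=zero_config(2)) under a distributed launcher; the output directory name is a placeholder.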