Revert "Fix gradient checkpointing + fp16 autocast for most models" #24420
Conversation
The documentation is not available anymore as the PR was closed or merged.
I'd be pro reverting until we manage to resolve this or find another solution. cc @ydshieh too as the testing master :) Looking at our daily CI, it seems this hasn't affected our "normal" models - is that right? Are there any tests we should be running to verify this?
That PR, #24247, was merged yesterday. The daily CI was triggered this morning and hasn't finished yet, so we don't know yet what that PR brings.
Ok, I just did some benchmarks by observing the peak memory usage of different training setups, and it seems to affect most of the models regardless of the modality.
Note that before #24247 the last PEFT layer always had a `None` grad and therefore never got updated. But the surprising thing is that the last layer shouldn't cause a 2x memory increase; in the worst case it should cause a (1 + 1/num_layers)x increase (roughly the cost of keeping gradients for one extra layer out of num_layers). I will investigate further and post updates here.
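For anyone who wants to reproduce this kind of comparison, a minimal peak-memory probe could look like the sketch below. This is not the exact benchmark script used above; the checkpoint, sequence length, optimizer, and learning rate are illustrative, and it doesn't reproduce the Trainer/Accelerate mixed-precision setup - it only shows how the peak memory of one training step can be read out.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative setup: swap in whichever checkpoint / quantization config you want to compare.
model_id = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.gradient_checkpointing_enable()
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Reset the peak-memory counter, run one forward/backward/step, then read the peak back.
torch.cuda.reset_peak_memory_stats()

inputs = tokenizer("hello " * 256, return_tensors="pt").to(0)
loss = model(**inputs, labels=inputs["input_ids"]).loss
loss.backward()
optimizer.step()

print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```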
@younesbelkada Thanks for investigating and sharing! Could you also add a model with no quantization to the table for reference?
Sure, yes! Will update the table soon.
From the updated observations above, it seems to affect the quantized models only.
We can merge this PR and revert the change, as it is leading to a huge increase in VRAM usage for quantized models. The minimal example below doesn't lead to the final layer having `None` grads. Please note the way Accelerate does the mixed-precision handling, which is now used in Trainer too. I don't know why this works and why using autocast as a context manager fails (results in `None` grads):

```diff
  import torch
  from transformers import AutoModelForCausalLM
  from types import MethodType
  from accelerate.utils import convert_outputs_to_fp32

  model_id = "facebook/opt-350m"
  model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
  model.gradient_checkpointing_enable()
  model.train()

+ model.forward = MethodType(torch.cuda.amp.autocast(dtype=torch.bfloat16)(model.forward.__func__), model)
+ model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

  assert model.training and model.is_gradient_checkpointing

  optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)

- with torch.cuda.amp.autocast(True, dtype=torch.float16):
  dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
  model.train()
  logits = model(dummy_input).logits
  loss = logits.mean()
  loss.backward()
  optimizer.step()

  for n, param in model.named_parameters():
      if param.grad is None:
          print(n)
```
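For convenience, here is the snippet with the diff applied, as a self-contained script. This is just a sketch that reassembles the lines above and assumes a single CUDA GPU; the comments reflect the behaviour described in this thread.

```python
import torch
from types import MethodType

from accelerate.utils import convert_outputs_to_fp32
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.gradient_checkpointing_enable()
model.train()

# Mimic Accelerate's mixed-precision handling: run the bound forward under autocast,
# then cast the outputs back to fp32.
model.forward = MethodType(torch.cuda.amp.autocast(dtype=torch.bfloat16)(model.forward.__func__), model)
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

assert model.training and model.is_gradient_checkpointing

optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)

dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
logits = model(dummy_input).logits
loss = logits.mean()
loss.backward()
optimizer.step()

# With the wrapping above, nothing should be printed here,
# i.e. no parameter is left with a None grad after the backward pass.
for n, param in model.named_parameters():
    if param.grad is None:
        print(n)
```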
Perfect, let's revert the PR then. cc @amyeroberts @sgugger, this is ready for review.
Revert LGTM. Thanks for acting so quickly on this and the detailed investigation!
Thanks very much for the support and quick feedback @amyeroberts, and big kudos to @pacman100 as well!
This PR reverts #24247.
The investigation initially started with the failing test in https://github.com/huggingface/peft/actions/runs/5340918925/jobs/9686171926 - a training setup that used to take 7GB now takes 15GB and OOMs. I went back through each commit and can confirm this commit caused it.
Instead of patching the initial issue on our side, I propose for now to revert the PR and just wait for the fix on the PyTorch side, as doubling the memory requirements is a lot for PEFT users.
I can confirm the training doesn't OOM before commit 285a480, hence this PR that reverts the commit.
cc @sgugger @pacman100 @amyeroberts
Putting it as a draft, as I need to dive a bit deeper to make sure this is the right thing to do.