
Revert "Fix gradient checkpointing + fp16 autocast for most models" #24420

Merged: 1 commit merged into main from revert-24247-fix-gc-poc on Jun 22, 2023

Conversation

younesbelkada (Contributor) commented Jun 22, 2023

This PR reverts #24247.

The investigation initially started with the failing test in https://github.com/huggingface/peft/actions/runs/5340918925/jobs/9686171926: a training setup that used to take 7GB now takes 15GB and OOMs. I went back through each commit and can confirm this commit caused it.

Instead of patching the initial issue on our side, I propose reverting the PR for now and waiting for the fix on the PyTorch side, as doubling the memory requirements is a lot to ask of PEFT users.

I can confirm the training doesn't OOM before commit 285a480, hence this PR reverting that commit.

cc @sgugger @pacman100 @amyeroberts

Marking it as a draft for now, as I need to dig a bit deeper to make sure this is the right thing to do.

younesbelkada marked this pull request as draft June 22, 2023 08:56
HuggingFaceDocBuilderDev commented Jun 22, 2023

The documentation is not available anymore as the PR was closed or merged.

amyeroberts (Collaborator) commented:

I'd be pro reverting until we manage to resolve this or find another solution.

cc @ydshieh too as the testing master :) Looking at our daily CI, it seems this hasn't affected our "normal" models - is this right? Are there any tests we should be running to verify this?

ydshieh (Collaborator) commented Jun 22, 2023

@amyeroberts

That PR, #24247, was merged yesterday. The daily CI was triggered this morning and hasn't finished yet, so we don't yet know what that PR brings.

ydshieh (Collaborator) commented Jun 22, 2023

There is also a push CI (non-slow tests only, so not a complete CI).

From the screenshot, it does look good though.

[Screenshot: push CI test results, 2023-06-22]

younesbelkada (Contributor, Author) commented Jun 22, 2023

OK, I just ran some benchmarks observing the peak memory usage of different training setups, and it seems to affect most models regardless of modality:

| Model | Quantization method | Use Reentrant == False (i.e. #24247 included) | Peak memory usage |
|---|---|---|---|
| openai/whisper-large | 8bit | Yes | OOM |
| openai/whisper-large | 8bit | No | 7.5GB |
| openai/whisper-large | 4bit | No | 5.1GB |
| openai/whisper-large | 4bit | Yes | 14.5GB |
| facebook/opt-6.7b | 8bit | Yes | 14.1GB |
| facebook/opt-6.7b | 8bit | No | 9.8GB |
| facebook/opt-1.3b | 16bit | Yes | 12.1GB |
| facebook/opt-1.3b | 16bit | No | 12.1GB |
| google/flan-t5-large | 16bit | Yes | 12.7GB |
| google/flan-t5-large | 16bit | No | 12.7GB |
| facebook/opt-1.3b | 8bit | Yes | 5.1GB |
| facebook/opt-1.3b | 8bit | No | 4.1GB |
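
For reference, a rough, hedged sketch of how peak-memory numbers of this kind can be collected for an 8-bit + LoRA training step. This is not the benchmark script used above; the batch size, sequence length, learning rate and default LoRA config are illustrative assumptions, and it needs a CUDA GPU with bitsandbytes and peft installed.

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load an 8-bit base model (one of the models from the table above) and attach LoRA adapters.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b", load_in_8bit=True, device_map={"": 0}
)
model.enable_input_require_grads()      # so gradient checkpointing works with frozen embeddings
model.gradient_checkpointing_enable()
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
model.train()

optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
input_ids = torch.randint(0, model.config.vocab_size, (4, 512), device=0)

# Measure the peak GPU memory of one optimization step.
torch.cuda.reset_peak_memory_stats()
outputs = model(input_ids=input_ids, labels=input_ids)
outputs.loss.backward()
optimizer.step()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")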

Note that before #24420 the last PEFT layer always had a None grad and therefore never got updated. But the surprising thing is that the last layer alone shouldn't cause a 2x memory increase; in the worst case it should cause a (1 + 1/num_layers)x increase (e.g. only about +4% for a 24-layer model).

I will investigate further and post updates here.
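
As a side note on the "Use Reentrant == False" column above, here is a minimal, self-contained sketch (not taken from either PR) of the two torch.utils.checkpoint modes it refers to. The toy linear layer and the frozen input are illustrative assumptions; the frozen-input behaviour shown is a general PyTorch property, not necessarily the exact failure discussed in this thread.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(128, 128)
x = torch.randn(4, 128)  # frozen input (requires_grad=False), as with embeddings in a PEFT setup

# Reentrant checkpointing (the long-standing PyTorch default): the segment runs as an
# autograd.Function, so with no tensor input requiring grad the output carries no grad_fn
# and a backward pass cannot reach the parameters inside the segment.
out = checkpoint(layer, x, use_reentrant=True)
print(out.requires_grad)  # False

# Non-reentrant checkpointing (what "Use Reentrant == False" enables): gradients reach
# the parameters even though the input itself is frozen.
out = checkpoint(layer, x, use_reentrant=False)
out.sum().backward()
print(layer.weight.grad is not None)  # True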

amyeroberts (Collaborator) commented:

@younesbelkada Thanks for investigating and sharing! Could you also add a model with no quantization for reference in the table?

younesbelkada (Contributor, Author) commented:

Sure, yes! I will update the table soon.

younesbelkada (Contributor, Author) commented:

From the updated observations above:

1. it seems to affect only the quantized models;
2. larger models are affected more.

pacman100 (Contributor) commented:

We can merge this PR and revert the change, as it is leading to a huge increase in VRAM usage for quantized models. The minimal example below doesn't lead to the final layer having None grads.

Please note the way Accelerate handles mixed precision, which is now used in the Trainer too: it decorates the model's forward with autocast and converts the outputs back to fp32. I don't know why this works while using autocast as a context manager fails (it results in None grads for the final layer).

import torch
from types import MethodType

from accelerate.utils import convert_outputs_to_fp32
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(model_id).to(0)

model.gradient_checkpointing_enable()
model.train()

# Mixed precision the way Accelerate (and now the Trainer) handles it: decorate the
# model's forward with autocast, then cast its outputs back to fp32, re-binding the
# wrapped function to the model each time.
model.forward = MethodType(torch.cuda.amp.autocast(dtype=torch.bfloat16)(model.forward.__func__), model)
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)

assert model.training and model.is_gradient_checkpointing

optimizer = torch.optim.Adam(model.parameters(), lr=1e-7)

# Running the forward pass under `with torch.cuda.amp.autocast(True, dtype=torch.float16):`
# instead of the wrapping above is what results in None grads for the final layer.
dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
logits = model(dummy_input).logits
loss = logits.mean()

loss.backward()
optimizer.step()

# Nothing should be printed here, i.e. no parameter should have a None grad.
for n, param in model.named_parameters():
    if param.grad is None:
        print(n)

younesbelkada marked this pull request as ready for review June 22, 2023 13:56
younesbelkada (Contributor, Author) commented Jun 22, 2023

Perfect, let's revert the PR then.
I can also confirm I don't get any None grads for LoRA layers when using Llama (as posted in the original issue); I believe the recent Accelerate integration silently fixed that bug and the user was on an older version of transformers.

cc @amyeroberts @sgugger this is ready for review
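
As a small aside, the None-grad check mentioned above boils down to something like this sketch (the helper name is made up):

import torch

def params_without_grad(model: torch.nn.Module):
    """Names of trainable parameters that did not receive a gradient in the last backward pass."""
    return [n for n, p in model.named_parameters() if p.requires_grad and p.grad is None]

# After one forward/backward pass on a PEFT/LoRA model, this list is expected to be empty,
# i.e. every LoRA parameter (including the last layer) received a gradient.
# print(params_without_grad(model))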

amyeroberts (Collaborator) left a comment


Revert LGTM. Thanks for acting so quickly on this and for the detailed investigation!

younesbelkada (Contributor, Author) commented Jun 22, 2023

Thanks very much for the support and quick feedback, @amyeroberts, and big kudos to @pacman100 as well!

younesbelkada merged commit 3ce3385 into main Jun 22, 2023
younesbelkada deleted the revert-24247-fix-gc-poc branch June 22, 2023 14:11