
Fix gradient checkpointing + fp16 autocast for most models #24247

Merged
merged 15 commits on Jun 21, 2023

Conversation

Contributor

@younesbelkada younesbelkada commented Jun 13, 2023

What does this PR do?

This PR fixes a bug users can encounter when using gradient checkpointing under the fp16 autocast context manager. Currently, if a user trains a model with autocast and gradient checkpointing, the last layer's weights never get updated.

Handy reproducible snippet
import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"

model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
model.gradient_checkpointing_enable()
model.train()

assert model.training and model.is_gradient_checkpointing

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
with torch.cuda.amp.autocast(True, dtype=torch.float16):
    dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
    model.train()
    logits = model(dummy_input).logits
    loss = logits.sum()

    loss.backward()
    optimizer.step()

    # Print every parameter that received no gradient; with the current
    # behavior this includes the last decoder layer's weights.
    for n, param in model.named_parameters():
        if param.grad is None:
            print(n)

As discussed internally, the fix seems to be to force-set use_reentrant=False when calling gradient checkpointing. Setting that boolean to False lifts the restriction that use_reentrant=True places on the inputs (with use_reentrant=True, at least one input tensor needs to require gradients, otherwise the checkpointed block's parameters receive no gradient). According to the PyTorch team, use_reentrant=True led to some silent bugs, and they are planning to remove that boolean in upcoming releases and use False by default.
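To illustrate the restriction on the inputs mentioned above, here is a minimal toy example (not taken from the PR): with use_reentrant=True, a checkpointed block whose inputs do not require gradients never receives parameter gradients, whereas use_reentrant=False handles that case correctly.

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(8, 8)   # checkpointed block
head = torch.nn.Linear(8, 1)    # regular (non-checkpointed) head
x = torch.randn(2, 8)           # activation tensor with requires_grad=False

# Reentrant checkpointing: the block's output does not require grad because no
# tensor input does, so backward never re-enters the block and its parameters
# get no gradient (PyTorch also emits a warning about this).
loss = head(checkpoint(block, x, use_reentrant=True)).sum()
loss.backward()
print(block.weight.grad)              # None
print(head.weight.grad is not None)   # True

# Non-reentrant checkpointing lifts that restriction: the block's parameters
# receive gradients even though its inputs do not require them.
block.zero_grad()
head.zero_grad()
loss = head(checkpoint(block, x, use_reentrant=False)).sum()
loss.backward()
print(block.weight.grad is not None)  # True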

This bug might be problematic for users who train adapters (using PEFT, for example), as they will see some downside in training performance. I propose a PoC that fixes this for the most common architectures until PyTorch removes support for use_reentrant=True in upcoming releases.

For more context, users who train models with PEFT end up using autocast inside the Trainer, as they use 4-bit / 8-bit base models.

Related: #23990

@younesbelkada younesbelkada changed the title from [PoC] Fix gradient opt to [PoC] Fix gradient checkpointing OPT + autocast on Jun 13, 2023
@younesbelkada younesbelkada requested a review from sgugger June 13, 2023 13:49
Comment on lines 705 to 721
if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters):
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
use_reentrant=False,
)
else:
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
causal_attention_mask,
head_mask[idx] if head_mask is not None else None,
None,
)
Collaborator

Maybe just define kwargs that are either an empty dict or a dict with the use_reentrant key. Also, this will need copying into all models and then potentially removing from all models. I think we should only touch all models once, so maybe define a util function we will reuse (and update when needed) in pytorch_utils.py?
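For illustration, a minimal sketch of what such a helper could look like (the name get_checkpointing_kwargs matches the call site in the later diff hunk; treat the exact location in pytorch_utils.py as an assumption):

import inspect

import torch


def get_checkpointing_kwargs():
    # Return {"use_reentrant": False} only when the installed version of
    # torch.utils.checkpoint.checkpoint accepts that keyword argument.
    if "use_reentrant" in inspect.signature(torch.utils.checkpoint.checkpoint).parameters:
        return {"use_reentrant": False}
    return {}

Each call site then stays a single torch.utils.checkpoint.checkpoint(...) call with **get_checkpointing_kwargs() appended, as in the diff hunk further down.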

HuggingFaceDocBuilderDev commented Jun 13, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 704 to 712
gradient_checkpointing_kwargs = get_checkpointing_kwargs()

layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(decoder_layer),
    hidden_states,
    causal_attention_mask,
    head_mask[idx] if head_mask is not None else None,
    None,
    **gradient_checkpointing_kwargs,
Collaborator

I don't know at all if it is better or not, but we could wrap the import and maybe do
torch.utils.checkpoint.checkpoint = custom_checkpoint

Collaborator

Yes, I was thinking of having a util function wrapping torch.utils.checkpoint.checkpoint, not just the kwargs creation; sorry if I was unclear.
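A minimal sketch of that idea (hypothetical wrapper, assuming it would live in pytorch_utils.py and be called by the models instead of torch.utils.checkpoint.checkpoint directly):

import inspect

import torch


def checkpoint(*args, **kwargs):
    # Forward to torch.utils.checkpoint.checkpoint, forcing use_reentrant=False
    # whenever the installed PyTorch version supports that argument.
    if "use_reentrant" in inspect.signature(torch.utils.checkpoint.checkpoint).parameters:
        kwargs.setdefault("use_reentrant", False)
    return torch.utils.checkpoint.checkpoint(*args, **kwargs)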

@@ -619,6 +619,41 @@ def test_training_gradient_checkpointing(self):
        loss = model(**inputs).loss
        loss.backward()

    @slow
    @pytest.mark.gradient_checkpointing_autocast_test
    def test_training_gradient_checkpointing_autocast(self):
Contributor Author
@younesbelkada younesbelkada Jun 14, 2023

Small note on the test: I have run the tests on all of transformers, and it fails for ~40 models:

FAILED tests/models/align/test_modeling_align.py::AlignTextModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param pooler.dense.weight
FAILED tests/models/altclip/test_modeling_altclip.py::AltCLIPVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight
FAILED tests/models/autoformer/test_modeling_autoformer.py::AutoformerModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision, but got a signal size of[14]
FAILED tests/models/beit/test_modeling_beit.py::BeitModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param beit.encoder.layer.2.lambda_1
FAILED tests/models/big_bird/test_modeling_big_bird.py::BigBirdModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param bert.pooler.weight
FAILED tests/models/blip/test_modeling_blip.py::BlipTextRetrievalModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_proj.weight
FAILED tests/models/canine/test_modeling_canine.py::CanineModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param canine.projection.conv.weight
FAILED tests/models/chinese_clip/test_modeling_chinese_clip.py::ChineseCLIPTextModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param pooler.dense.weight
FAILED tests/models/chinese_clip/test_modeling_chinese_clip.py::ChineseCLIPVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight
FAILED tests/models/clip/test_modeling_clip.py::CLIPVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight
FAILED tests/models/clipseg/test_modeling_clipseg.py::CLIPSegVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight
FAILED tests/models/clipseg/test_modeling_clipseg.py::CLIPSegModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param clip.logit_scale
FAILED tests/models/data2vec/test_modeling_data2vec_vision.py::Data2VecVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param data2vec_vision.encoder.layer.2.lambda_1
FAILED tests/models/dpt/test_modeling_dpt.py::DPTModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param dpt.layernorm.weight
FAILED tests/models/dpt/test_modeling_dpt_hybrid.py::DPTModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param dpt.layernorm.weight
FAILED tests/models/flava/test_modeling_flava.py::FlavaImageModelTest::test_training_gradient_checkpointing_autocast - AttributeError: 'FlavaImageModelTester' object has no attribute 'is_training'
FAILED tests/models/flava/test_modeling_flava.py::FlavaTextModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param pooler.dense.weight
FAILED tests/models/flava/test_modeling_flava.py::FlavaMultimodalModelTest::test_training_gradient_checkpointing_autocast - AttributeError: 'FlavaMultimodalModelTester' object has no attribute 'is_training'
FAILED tests/models/flava/test_modeling_flava.py::FlavaImageCodebookTest::test_training_gradient_checkpointing_autocast - AttributeError: 'FlavaImageCodebookTester' object has no attribute 'is_training'
FAILED tests/models/flava/test_modeling_flava.py::FlavaForPreTrainingTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param flava.text_model.pooler.dense.weight
FAILED tests/models/fnet/test_modeling_fnet.py::FNetModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision, but got a signal size of[7, 32]
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param multiple_choice_head.summary.weight
FAILED tests/models/graphormer/test_modeling_graphormer.py::GraphormerModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param encoder.graph_encoder.graph_attn_bias.edge_encoder.weight
FAILED tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param transformer.h.0.ln_1.weight
FAILED tests/models/informer/test_modeling_informer.py::InformerModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
FAILED tests/models/layoutlm/test_modeling_layoutlm.py::LayoutLMModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param layoutlm.pooler.dense.weight
FAILED tests/models/lilt/test_modeling_lilt.py::LiltModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param lilt.encoder.layer.1.attention.self.layout_value.weight
FAILED tests/models/luke/test_modeling_luke.py::LukeModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param luke.pooler.dense.weight
FAILED tests/models/marian/test_modeling_marian.py::MarianModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param model.encoder.embed_positions.weight
FAILED tests/models/owlvit/test_modeling_owlvit.py::OwlViTVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight
FAILED tests/models/owlvit/test_modeling_owlvit.py::OwlViTForObjectDetectionTest::test_training_gradient_checkpointing_autocast - RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [12, 4, 64]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
FAILED tests/models/pegasus/test_modeling_pegasus.py::PegasusModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param model.encoder.embed_positions.weight
FAILED tests/models/pix2struct/test_modeling_pix2struct.py::Pix2StructTextModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param layer.0.encoder_decoder_attention.attention.query.weight
FAILED tests/models/regnet/test_modeling_regnet.py::RegNetModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: GET was unable to find an engine to execute this computation
FAILED tests/models/roformer/test_modeling_roformer.py::RoFormerModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param roformer.encoder.embed_positions.weight
FAILED tests/models/sam/test_modeling_sam.py::SamModelTest::test_training_gradient_checkpointing_autocast - AttributeError: 'SamModelTester' object has no attribute 'is_training'
FAILED tests/models/speech_to_text/test_modeling_speech_to_text.py::Speech2TextModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param model.encoder.embed_positions.weights
FAILED tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
FAILED tests/models/time_series_transformer/test_modeling_time_series_transformer.py::TimeSeriesTransformerModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param model.encoder.embed_positions.weight
FAILED tests/models/van/test_modeling_van.py::VanModelTest::test_training_gradient_checkpointing_autocast - RuntimeError: GET was unable to find an engine to execute this computation
FAILED tests/models/vilt/test_modeling_vilt.py::ViltModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vilt.pooler.dense.weight
FAILED tests/models/visual_bert/test_modeling_visual_bert.py::VisualBertModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param visual_bert.pooler.dense.weight
FAILED tests/models/vit_mae/test_modeling_vit_mae.py::ViTMAEModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vit.embeddings.position_embeddings
FAILED tests/models/x_clip/test_modeling_x_clip.py::XCLIPVisionModelTest::test_training_gradient_checkpointing_autocast - AssertionError: False is not true : None gradient in param vision_model.post_layernorm.weight

It fails mostly for vision models. The fix targets a quite unique and niche use case (autocast float16 + gradient checkpointing + LLM), and it works for the main targeted use cases (large LLMs). I am not sure which approach we should follow here (fix everything in this PR or not?), so I would love to hear your thoughts. Also, I am not sure whether gradient checkpointing is supported out of the box for vision models (a lot of GC tests are skipped in the CI for vision models)?

Collaborator

Did the test pass before?

@younesbelkada younesbelkada changed the title from [PoC] Fix gradient checkpointing OPT + autocast to [PoC] Fix gradient checkpointing + fp16 autocast on Jun 14, 2023
@younesbelkada younesbelkada requested a review from sgugger June 14, 2023 15:59
Collaborator
@sgugger sgugger left a comment

I don't think the models are tested with gradient checkpointing so I'm not surprised a lot of them actually fail. Let's maybe mark them with supports_gradient_checkpointing=False to avoid adding failing tests to the CI?

@@ -549,6 +553,41 @@ def test_training_gradient_checkpointing(self):
        loss = model(**inputs).loss
        loss.backward()

    @slow
    @pytest.mark.gradient_checkpointing_autocast_test
Collaborator

Remove the marker; you can execute a single test with the -k flag.

model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
with torch.cuda.amp.autocast(True, dtype=torch.float16):
Collaborator

This requires CUDA; the test should be marked as such.
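One way to do that (a sketch, assuming the existing slow and require_torch_gpu decorators from transformers.testing_utils; the class name and test body are placeholders):

from transformers.testing_utils import require_torch_gpu, slow


class ExampleModelTest:
    @slow
    @require_torch_gpu
    def test_training_gradient_checkpointing_autocast(self):
        # The decorators ensure the test only runs on the slow, GPU-enabled CI
        # where torch.cuda.amp.autocast is actually available.
        ...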

@@ -1064,6 +1064,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

    is_parallelizable = False
    supports_gradient_checkpointing = False
    supports_cuda_fp16_gradient_checkpointing = None
Collaborator

No we are not adding an attribute like this for just one test in the common tester 😅

Contributor Author
@younesbelkada younesbelkada Jun 20, 2023

Actually, I don't think it's for the common tester only :D I think we should properly raise an exception or a warning for users if they use a model with gradient checkpointing under autocast + fp16 (I have attempted that here; it is currently commented out, but I plan to use something along that logic). The models that don't support that scenario are listed here: #24247 (comment)
Although we could add supports_gradient_checkpointing=False to those models, I don't think that would be correct behavior, since technically gradient checkpointing is supported in some cases (autocast + bf16, no autocast, ...).
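For reference, a sketch of the kind of check discussed here (hypothetical helper name and wording; the PR's attempted implementation is commented out and may differ):

import logging

import torch

logger = logging.getLogger(__name__)


def warn_if_fp16_autocast_checkpointing(model):
    # Warn when gradient checkpointing is active inside an fp16 autocast
    # context, since some architectures silently drop gradients for the last
    # layer in that configuration.
    if (
        model.is_gradient_checkpointing
        and torch.is_autocast_enabled()
        and torch.get_autocast_gpu_dtype() == torch.float16
    ):
        logger.warning(
            f"{model.__class__.__name__} may not compute gradients correctly when "
            "gradient checkpointing is used under fp16 autocast."
        )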

Collaborator

Even less so for the PreTrainedModel class; I misread where you put this.

Contributor Author
@younesbelkada younesbelkada Jun 20, 2023

Hmm, I see. What do you think the approach could be here? Since this behaviour is very specific to some models and not observed for other architectures, there is probably a much simpler approach 🤔

Collaborator

Honestly, this is a bug in PyTorch, so we should just not test it for the models where it doesn't work (by manually adding the test to each of those test files and skipping it) instead of adding a new model attribute.
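As an illustration of that suggestion (the test class name is an example, not the PR's actual diff):

import unittest


class CLIPVisionModelTest(unittest.TestCase):
    @unittest.skip(reason="Gradient checkpointing + fp16 autocast does not work for this architecture")
    def test_training_gradient_checkpointing_autocast(self):
        pass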

Contributor Author

OK sounds great!

@@ -43,6 +43,7 @@
TokenClassifierOutput,
Contributor Author

I am not really familiar with that file, so I just wanted to double-check whether the changes are all good here. @sgugger

Collaborator
@sgugger sgugger left a comment

Thanks for all your work on this!

@younesbelkada younesbelkada changed the title from [PoC] Fix gradient checkpointing + fp16 autocast to Fix gradient checkpointing + fp16 autocast for most models on Jun 20, 2023
@younesbelkada younesbelkada marked this pull request as ready for review June 20, 2023 16:53