diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index 9eadaa219834ee..6aaf187b040b45 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -34,7 +34,11 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 09ee6eca62650e..0a7238ef606976 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -30,7 +30,12 @@ BaseModelOutputWithPoolingAndNoAttention, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -1100,7 +1105,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 26b3f59280810b..68a3a28a480128 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -30,7 +30,12 @@ BaseModelOutputWithPoolingAndProjection, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig @@ -651,7 +656,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -965,7 +970,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 0f8c045121c749..15b8c379352d8a 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -25,7 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_audio_spectrogram_transformer import ASTConfig @@ -343,7 +343,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 70587add17e721..981df3ab845c3b 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -34,6 +34,7 @@ Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_autoformer import AutoformerConfig @@ -1210,7 +1211,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1428,7 +1429,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 50452449021c32..9000ad3d060266 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -35,6 +35,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -849,7 +850,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1105,7 +1106,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index b17721fb2bcd32..b546f14001911c 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -34,7 +34,7 @@ SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -517,7 +517,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index fb92a0e84cc49e..37f236d4a60291 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -40,7 +40,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -598,7 +603,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index f92b7a0633e8cb..f20503c594dff1 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -25,7 +25,12 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -408,7 +413,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e1346a23c9db5b..5e80d0423f7443 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -37,7 +37,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1622,7 +1622,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 8d7906631d54f2..1ab72f0b49121c 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -36,6 +36,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -1945,7 +1946,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2291,7 +2292,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a9ecb11a61f1c2..c29c13547eb3eb 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -32,6 +32,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -594,7 +595,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8f2780772cbd39..f96531f51f7684 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -36,6 +36,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -779,7 +780,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1034,7 +1035,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ef8d51a2b0e7ba..b09dce88e02e72 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -34,6 +34,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -777,7 +778,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1031,7 +1032,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index f16b89b7a316e7..93bb26c5b969f6 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -620,7 +621,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 1f269cf852ee0d..38866578b6b021 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -34,6 +34,7 @@ find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_blip import BlipTextConfig @@ -427,7 +428,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 82a879771b786f..b326ff36c7ef3d 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -31,7 +31,12 @@ BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -492,7 +497,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -963,7 +968,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 4f6de49a144711..2144c43687ae2b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -33,6 +33,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_bloom import BloomConfig @@ -775,7 +776,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, alibi, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 4290241fbc097d..37424e03545a93 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -32,8 +32,13 @@ ModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -810,7 +815,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index e98840fbc6d2a6..25d11d24e14cfb 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -529,7 +534,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index a91d42f0395ee8..8406a9d1d42fa9 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -36,7 +36,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -800,7 +805,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 0adf5cfdcb1857..975857024e337f 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -31,7 +31,12 @@ BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -914,7 +919,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1023,7 +1028,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, ) diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index c4dbcb03f34df7..fa836700066b8b 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -30,7 +30,13 @@ BaseModelOutputWithPoolingAndCrossAttentions, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + meshgrid, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -947,7 +953,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: @@ -1601,7 +1607,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index ee9d660ef71347..6a96715a276bcb 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -644,7 +645,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 85b11965306861..fc37277c34d5bf 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -654,7 +655,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 8b1d34f59e7bf6..7cee097b2b1aaf 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_codegen import CodeGenConfig @@ -549,7 +550,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 023cb278484193..44d9cc9bb5994a 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1395,7 +1396,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index bbdba210c23330..49923ba1234ecd 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_convbert import ConvBertConfig @@ -639,7 +644,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 99e3a02febf4d2..8784fff414fb52 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -26,7 +26,8 @@ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_cvt import CvtConfig diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 168f342acd3200..72a53c292cee82 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -35,6 +35,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_data2vec_audio import Data2VecAudioConfig @@ -300,7 +301,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -600,7 +601,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 206fe1603b0045..45c182a95c3d73 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -34,7 +34,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -515,7 +520,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 77b424354892b9..cbef81d2a81bb8 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -33,7 +33,7 @@ SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -529,7 +529,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 9a0d43db3a0aec..260e713d5b9e78 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -31,7 +31,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_deberta import DebertaConfig @@ -464,7 +464,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 1596ad4ffad42e..22aef359240422 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -32,7 +32,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_deberta_v2 import DebertaV2Config @@ -508,7 +508,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 926947b1617de8..64d64191c484c0 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -643,7 +643,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 6469cf7a65df9e..fc195622c2ddae 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -41,7 +41,7 @@ ) from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import meshgrid +from ...pytorch_utils import meshgrid, torch_custom_checkpointing from ...utils import is_ninja_available, logging from ..auto import AutoBackbone from .configuration_deformable_detr import DeformableDetrConfig @@ -1380,7 +1380,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 8b03835812fcdf..4c5491935cad5a 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -33,7 +33,7 @@ MaskedImageModelingOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -364,7 +364,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index af218829d6f9ab..67427b4f4137d8 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -36,7 +36,7 @@ ) from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import meshgrid +from ...pytorch_utils import meshgrid, torch_custom_checkpointing from ...utils import is_torchvision_available, logging, requires_backends from ..auto import AutoBackbone from .configuration_deta import DetaConfig @@ -1272,7 +1272,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index c92c43e46d18e9..684129663fa846 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1130,7 +1131,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 65c48eb81f8368..07f9fee14ed656 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -756,7 +756,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 187a6c36656a8e..0630a3c48be941 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -39,7 +39,7 @@ ) from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ModelOutput, logging from ..auto import AutoBackbone from .configuration_dpt import DPTConfig @@ -535,7 +535,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index a7ee4ec9320204..3197e060bd1a7f 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -36,7 +36,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -576,7 +581,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b8df1b2d5035c3..a5f16a3a867f7f 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -38,7 +38,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -511,7 +516,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index e0b26e0f7812b7..27b7bb2d917cb2 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -30,7 +30,8 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import logging from .configuration_esm import EsmConfig @@ -610,7 +611,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 5d49197f8ca50e..0f85ff06f58561 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -26,7 +26,8 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -668,7 +669,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 6bc526eeebcb91..8d8de88c8f1d6b 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -43,7 +43,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -297,7 +297,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(layer_module), hidden_states) + layer_outputs = torch_custom_checkpointing(create_custom_forward(layer_module), hidden_states) else: layer_outputs = layer_module(hidden_states) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index fc327ad0b39f8c..9e8efed44388db 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -593,7 +594,7 @@ def custom_forward(*inputs): return custom_forward - stage_outputs = torch.utils.checkpoint.checkpoint( + stage_outputs = torch_custom_checkpointing( create_custom_forward(stage_module), hidden_states, input_dimensions, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 23ae6d64962fe7..83bf591fdb9885 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -34,7 +34,12 @@ CausalLMOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_git import GitConfig, GitVisionConfig @@ -457,7 +462,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -883,7 +888,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b9a8568f00e7fd..dab2c613532bde 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -35,8 +35,12 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...modeling_utils import Conv1D, PreTrainedModel, SequenceSummary +from ...pytorch_utils import ( + find_pruneable_heads_and_indices, + prune_conv1d_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -890,7 +894,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 705d07b1da257f..cf23e1ba08a512 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -28,6 +28,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -661,7 +662,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index b67f4ddbfacac3..768893cb447148 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -34,6 +34,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_gpt_neo import GPTNeoConfig @@ -613,7 +614,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7c3bfd1035f904..3f7bbcdf64e601 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -36,6 +36,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_gpt_neox import GPTNeoXConfig @@ -557,7 +558,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index de120167989d84..4969bd7fd1bb1c 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -31,6 +31,7 @@ SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index c19ebd13b91d6f..e5ee94adbd2699 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1037,7 +1038,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 70a8c079409b51..774c4826c9bb94 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -27,6 +27,7 @@ from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -353,7 +354,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -738,7 +739,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -828,7 +829,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 539119fabf281d..31b911431f92ca 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -32,7 +32,7 @@ SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_imagegpt import ImageGPTConfig @@ -826,7 +826,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, None, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 2bf3f208a903fd..4774f1d91d6655 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -30,6 +30,7 @@ Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_informer import InformerConfig @@ -1217,14 +1218,14 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), ) if conv_layer is not None: - output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) + output = torch_custom_checkpointing(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1440,7 +1441,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 410f76509422f3..614bebe121961c 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -33,7 +33,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_layoutlm import LayoutLMConfig @@ -492,7 +497,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 5a6f39ce31a6e1..0e0f2c1bd82361 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -31,7 +31,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -455,7 +455,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index db6618caaeaf30..31fb1f6fb5728f 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -32,7 +32,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_layoutlmv3 import LayoutLMv3Config @@ -671,7 +671,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index a11659e3893389..8fa8c00aadf736 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -35,6 +35,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1884,7 +1885,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2150,7 +2151,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 74454d244e8d31..1953992d058fb3 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -31,7 +31,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_lilt import LiltConfig @@ -519,7 +524,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layout_inputs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c9debdd252dc7a..2468f7088ba455 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_llama import LlamaConfig @@ -568,7 +569,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 665e2cb56421b6..809d889eed47b4 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -25,7 +25,12 @@ from ...activations import ACT2FN, gelu from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -1311,7 +1316,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 1a49444e8a509c..d1358a78d8f536 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -23,7 +23,6 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,12 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1517,7 +1521,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index ba21d3deb32e8d..0f217909b0ca31 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -26,7 +26,7 @@ from ...activations import ACT2FN, gelu from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward +from ...pytorch_utils import apply_chunking_to_forward, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -795,7 +795,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), word_hidden_states, entity_hidden_states, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f8f9e1d3a8ee3d..db8e017d17f31f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -32,6 +32,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -827,7 +828,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1074,7 +1075,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a75f833fb5cb87..15d58baeda6ac7 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -35,6 +35,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -790,7 +791,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1039,7 +1040,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 0c6847b47815ce..0792ff1b723eec 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -43,6 +43,7 @@ find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_markuplm import MarkupLMConfig @@ -653,7 +654,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 4cb2493e58c8bb..2fd61a179b1bc4 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -36,6 +36,7 @@ ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_mask2former import Mask2FormerConfig @@ -1875,7 +1876,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 830f8b23c81602..2b91e975ce1ece 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -776,7 +777,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 7016b598e8535b..e22f466edce82e 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -28,7 +28,7 @@ from ...file_utils import ModelOutput from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils.backbone_utils import BackboneMixin from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -695,7 +695,7 @@ def custom_forward(*inputs): return custom_forward - layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint( + layer_hidden_states, output_dimensions, layer_all_hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask ) else: diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 67750ab42f7118..660177708a835e 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -34,6 +34,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -831,7 +832,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1089,7 +1090,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/mctct/modeling_mctct.py b/src/transformers/models/mctct/modeling_mctct.py index 08e280b3ccf9b2..22838d4e28d0e1 100755 --- a/src/transformers/models/mctct/modeling_mctct.py +++ b/src/transformers/models/mctct/modeling_mctct.py @@ -33,6 +33,7 @@ find_pruneable_heads_and_indices, prune_linear_layer, ) +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_mctct import MCTCTConfig @@ -623,7 +624,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index bba7e7369cb8a1..9a24f41e70b0dc 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -40,7 +40,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -556,7 +561,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 3503e86c9c75c2..e68357e6d37dc0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -33,7 +33,7 @@ SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -633,7 +633,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index b8c071a74f4b1e..bd2f2bd9cf6ff2 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -32,6 +32,7 @@ SemanticSegmenterOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -589,7 +590,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index a3cfce8ffc4a3f..ce5c81f63425b5 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -23,7 +23,6 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1046,7 +1045,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 6a44768d8eec86..4f905a7b51ed60 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -34,6 +34,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -953,7 +954,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1231,7 +1232,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 97c5b5a90ec3b5..68a78a64faeb14 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -38,7 +38,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -584,7 +589,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 06b61c7497dbe3..d67032d119bcb0 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,7 +22,6 @@ import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled @@ -33,6 +32,7 @@ Seq2SeqMoEOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -1155,7 +1155,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1428,7 +1428,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index b859b0db1d4f4f..6bbd95e7091da6 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -33,7 +33,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_nystromformer import NystromformerConfig @@ -375,7 +380,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index a874611acde892..1e2c59a717d5fd 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -2619,7 +2620,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states) + hidden_states = torch_custom_checkpointing(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index 16ad554dc31344..07b19a808de316 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_open_llama import OpenLlamaConfig @@ -603,7 +604,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index bd64630c6200f5..79b555a5e3e0a2 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -29,6 +29,7 @@ SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -700,7 +701,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, causal_attention_mask, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index f65a0688578e2d..6f43ea603e3f35 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -754,7 +755,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index a2bd3f3812e550..3af479971e1893 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -34,6 +34,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -805,7 +806,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1089,7 +1090,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 8e380a4de5f0a0..94fc2d25ddcba8 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -33,6 +33,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_end_docstrings, add_start_docstrings, @@ -1072,7 +1073,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, global_hidden_states, @@ -1330,7 +1331,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 2db104a5a112af..0834bbeaaf7bdd 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -20,7 +20,6 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -31,7 +30,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, torch_custom_checkpointing from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -350,7 +349,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1502,7 +1501,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 365429360af508..23a9f928d193fe 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -33,6 +33,7 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_end_docstrings, @@ -810,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1067,7 +1068,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 9160d5e1eb462d..007d9aadf268ba 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1336,7 +1337,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1577,7 +1578,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 47a34e959072fa..d4371c0efbb250 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -39,7 +39,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -586,7 +586,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index f68fc04105de6a..09d6bb7325b4ec 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -31,7 +31,12 @@ ModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_realm import RealmConfig @@ -591,7 +596,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index da4ad9608514c7..06da821d4dd37a 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -36,7 +36,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -548,7 +553,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/retribert/modeling_retribert.py b/src/transformers/models/retribert/modeling_retribert.py index 240d9476e70b01..e1397d39ceae2c 100644 --- a/src/transformers/models/retribert/modeling_retribert.py +++ b/src/transformers/models/retribert/modeling_retribert.py @@ -21,10 +21,10 @@ from typing import Optional import torch -import torch.utils.checkpoint as checkpoint from torch import nn from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging from ..bert.modeling_bert import BertModel from .configuration_retribert import RetriBertConfig @@ -141,7 +141,7 @@ def partial_encode(*inputs): for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)): b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] - pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask) + pooled_output = torch_custom_checkpointing(partial_encode, b_embedding_output, b_attention_mask) pooled_output_list.append(pooled_output) return torch.cat(pooled_output_list, dim=0) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index b0f13692460166..f86fa4aa80820c 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -515,7 +520,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index b1e02e27f13890..01276cd07119a6 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -517,7 +522,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 7647c14a9ea3d4..63abc9d4aa1876 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -649,7 +654,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b966bf4490a9fd..586ecbd2dad690 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -36,7 +36,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -585,7 +590,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index c3cbaa9176f0bf..0e4177d90b0a1a 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig @@ -1049,7 +1050,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index dd854c49f5c9d2..75ad1f97dffe8c 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -28,6 +28,7 @@ from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sew import SEWConfig @@ -367,7 +368,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -680,7 +681,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 7f7c1977d69248..b7acb306bb91c5 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -29,7 +29,7 @@ from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import softmax_backward_data +from ...pytorch_utils import softmax_backward_data, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_sew_d import SEWDConfig @@ -460,7 +460,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -1141,7 +1141,7 @@ def custom_forward(*inputs): return custom_forward - output_states = torch.utils.checkpoint.checkpoint( + output_states = torch_custom_checkpointing( create_custom_forward(layer_module), next_kv, attention_mask, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index d8a19084eb3847..3e2024dc69ecf1 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -31,6 +31,7 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_speech_to_text import Speech2TextConfig @@ -820,7 +821,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1068,7 +1069,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index c13b04642d9d54..12e8d4592adb65 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_speech_to_text_2 import Speech2Text2Config @@ -677,7 +678,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 3e8ce5a23b7e6b..5988607f1cb4f2 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -35,6 +35,7 @@ Seq2SeqSpectrogramOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig @@ -528,7 +529,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -1394,7 +1395,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1723,7 +1724,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 6e636fb695daef..88d6a480b70557 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -27,7 +27,12 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_splinter import SplinterConfig @@ -464,7 +469,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index b324cfdcd9354c..93144c66a9134a 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -832,7 +832,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index cd58b706505865..6b1b803345557e 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -753,7 +753,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 97b460479d6d5d..07dd0a79b7ae5a 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -28,7 +28,7 @@ from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -908,7 +908,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 008e23531ac1a9..1378ec9a98d5ac 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -23,7 +23,6 @@ import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -33,7 +32,12 @@ Seq2SeqMoEOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1075,7 +1079,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 050309fa9a3367..4531214e19cd16 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -24,7 +24,6 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -34,7 +33,12 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -1074,7 +1078,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 733ff7b9b453df..998f21a286109b 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1074,7 +1075,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, combined_attention_mask, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 1621653f3ee08b..4f736a367e3022 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -34,6 +34,7 @@ find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_12, prune_linear_layer, + torch_custom_checkpointing, ) from ...utils import ( ModelOutput, @@ -653,7 +654,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 8986ef6729caaf..e3e0b3055d8b78 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -31,6 +31,7 @@ Seq2SeqTSPredictionOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -949,7 +950,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -1166,7 +1167,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 9f886b6ece5371..5ff5bd7fd19d89 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_timesformer import TimesformerConfig @@ -446,7 +447,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, ) diff --git a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py index e8ecedccb5ea50..1027bd73f3fe96 100644 --- a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py @@ -26,6 +26,7 @@ from torch.nn import functional as F from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -556,7 +557,7 @@ def custom_forward(*inputs): return custom_forward - outputs = torch.utils.checkpoint.checkpoint( + outputs = torch_custom_checkpointing( create_custom_forward(block), hidden_states, layer_past, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 6276c68a425d10..e8ee10f7defdb3 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_start_docstrings, logging, replace_return_docstrings from .configuration_trocr import TrOCRConfig @@ -709,7 +710,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index 3725c5e7728be9..4b990cdb03ebee 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -567,7 +567,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -884,7 +884,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index e068fa59e5792e..5bd1af95c75abe 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -29,6 +29,7 @@ from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -391,7 +392,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -774,7 +775,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -864,7 +865,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2ed8a5d57204e7..f603d2712f6d25 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -36,6 +36,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -405,7 +406,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -788,7 +789,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -878,7 +879,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index c62d0c4632cb68..5f44a5e4b3aade 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -441,7 +441,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -724,7 +724,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 6ee1e396a625e3..5499a26cc748ab 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -38,6 +38,7 @@ find_pruneable_heads_and_indices, meshgrid, prune_linear_layer, + torch_custom_checkpointing, ) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vilt import ViltConfig @@ -536,7 +537,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 0bef6e4af9d918..a73d6ac720725b 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -32,7 +32,12 @@ SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( ModelOutput, add_start_docstrings, @@ -423,7 +428,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index bfd440caae2b01..28ea5740ca59f6 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -32,7 +32,7 @@ MaskedImageModelingOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -404,7 +404,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 051d431946a852..ba3bbddf563b0d 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from ..auto import AutoBackbone from .configuration_vit_hybrid import ViTHybridConfig @@ -422,7 +422,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index ef0c7c9f36869e..5a9c539fbc0b47 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -29,7 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -543,7 +543,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, @@ -800,7 +800,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, None, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 46639e7d622cb7..4f7b412fecb83b 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_vit_msn import ViTMSNConfig @@ -394,7 +394,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 43ab2408bb2309..9705a51e488a09 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -37,6 +37,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -458,7 +459,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -810,7 +811,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -899,7 +900,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 3e37a4a504b0b4..86c0cbe5e2d6cb 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -35,6 +35,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -523,7 +524,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -916,7 +917,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index e4072d93724fd8..35dc46bac1f942 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -36,6 +36,7 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_wavlm import WavLMConfig @@ -361,7 +362,7 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( + hidden_states = torch_custom_checkpointing( create_custom_forward(conv_layer), hidden_states, ) @@ -720,7 +721,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, @@ -811,7 +812,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer), hidden_states, attention_mask, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 515b886f98db62..703607ad4a69ce 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -34,6 +34,7 @@ SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -853,7 +854,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, None, @@ -1085,7 +1086,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 8db4ee0fd19480..be6b2818900330 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -708,7 +709,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -955,7 +956,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4a72b785a02412..61b51d51fcbcfb 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_xglm import XGLMConfig @@ -683,7 +684,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index 2d14bfb6a7b548..fd90086672cf40 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import ( ModelOutput, add_start_docstrings, @@ -1356,7 +1357,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, extended_attention_mask, @@ -1600,7 +1601,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, extended_attention_mask, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index ae8d51a3f8eb63..e00574239c9296 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -35,7 +35,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -516,7 +521,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index fb86717e1d7fa4..71f8de5a7277b9 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -34,7 +34,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -504,7 +509,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index d99b77fedda38f..44e50bed3b21ca 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -34,7 +34,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_xmod import XmodConfig @@ -578,7 +583,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, lang_ids, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index e3cb02ceae6ec0..4b4aa012416780 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -27,7 +27,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer, torch_custom_checkpointing from ...utils import ( ModelOutput, add_code_sample_docstrings, @@ -499,7 +499,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, layer_head_mask, diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index 8c2ff9fa4e0753..1b1e6b13add8a3 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -34,7 +34,12 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, + torch_custom_checkpointing, +) from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_yoso import YosoConfig @@ -566,7 +571,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index 4723c43035e67c..520eeb89393aea 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -285,3 +285,18 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: non-overlapping lifetimes may have the same id. """ return tensor.device, storage_ptr(tensor), storage_size(tensor) + + +def torch_custom_checkpointing(*args): + r""" + A correct usage of `torch.utils.checkpoint.checkpoint` as the default call leads to silent bugs that leads to the + gradients of the last layers not being updated. For more in depth detail of the issue, please have a look at: + https://github.com/huggingface/transformers/pull/24247 + """ + kwargs = {} + if "use_reentrant" in list(inspect.signature(torch.utils.checkpoint.checkpoint).parameters): + kwargs["use_reentrant"] = False + return torch.utils.checkpoint.checkpoint( + *args, + **kwargs, + ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 4899e195986fd2..c5d141b1f4b839 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -43,6 +43,7 @@ TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...pytorch_utils import torch_custom_checkpointing from ...pytorch_utils import ( apply_chunking_to_forward, find_pruneable_heads_and_indices, @@ -550,7 +551,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(layer_module), hidden_states, attention_mask, @@ -1585,6 +1586,7 @@ def forward( CausalLMOutputWithCrossAttentions ) from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import torch_custom_checkpointing from ...utils import logging from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config @@ -2318,7 +2320,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(encoder_layer), hidden_states, attention_mask, @@ -2557,7 +2559,7 @@ def custom_forward(*inputs): return custom_forward - layer_outputs = torch.utils.checkpoint.checkpoint( + layer_outputs = torch_custom_checkpointing( create_custom_forward(decoder_layer), hidden_states, attention_mask, diff --git a/tests/models/align/test_modeling_align.py b/tests/models/align/test_modeling_align.py index 2357c20e213a8d..c8ac69840f77f7 100644 --- a/tests/models/align/test_modeling_align.py +++ b/tests/models/align/test_modeling_align.py @@ -352,6 +352,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="ALIGN does not use inputs_embeds") def test_inputs_embeds(self): pass diff --git a/tests/models/altclip/test_modeling_altclip.py b/tests/models/altclip/test_modeling_altclip.py index 28213de84df63c..266e0c47b6bba6 100755 --- a/tests/models/altclip/test_modeling_altclip.py +++ b/tests/models/altclip/test_modeling_altclip.py @@ -186,6 +186,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="AltCLIPVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/autoformer/test_modeling_autoformer.py b/tests/models/autoformer/test_modeling_autoformer.py index 9f0434689c4be8..ad006d9d0794ce 100644 --- a/tests/models/autoformer/test_modeling_autoformer.py +++ b/tests/models/autoformer/test_modeling_autoformer.py @@ -238,6 +238,12 @@ def test_encoder_decoder_model_standalone(self): def test_resize_tokens_embeddings(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # # Input is 'static_categorical_features' not 'input_ids' def test_model_main_input_name(self): model_signature = inspect.signature(getattr(AutoformerModel, "forward")) diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index f9aa7339f7e0c2..149820023a6940 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -227,6 +227,12 @@ def test_inputs_embeds(self): def test_multi_gpu_data_parallel_forward(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index f86c6d0ac70ab8..45bff430bfdb8e 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -609,6 +609,12 @@ def test_for_change_to_full_attn(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # overwrite from common in order to skip the check on `attentions` def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 7d9c6b5ba58b05..a34efc026474d9 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -789,6 +789,12 @@ def test_retain_grad_hidden_states_attentions(self): def test_model_common_attributes(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/canine/test_modeling_canine.py b/tests/models/canine/test_modeling_canine.py index d612a02bf47c67..6e6d7ce3836a28 100644 --- a/tests/models/canine/test_modeling_canine.py +++ b/tests/models/canine/test_modeling_canine.py @@ -499,6 +499,12 @@ def test_inputs_embeds(self): # ViT does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip("CANINE does not have a get_input_embeddings() method.") def test_model_common_attributes(self): pass diff --git a/tests/models/chinese_clip/test_modeling_chinese_clip.py b/tests/models/chinese_clip/test_modeling_chinese_clip.py index 57f532da863515..cf2668f4d8b5ca 100644 --- a/tests/models/chinese_clip/test_modeling_chinese_clip.py +++ b/tests/models/chinese_clip/test_modeling_chinese_clip.py @@ -395,6 +395,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass @@ -469,6 +475,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index d16241ab2f22a0..82592d8452f5f8 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -227,6 +227,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index b54861d8d8d045..387a2e1c8f3454 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -202,6 +202,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="CLIPSegVisionModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass @@ -448,6 +454,12 @@ def test_model_for_image_segmentation(self): def test_hidden_states_output(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") def test_inputs_embeds(self): pass diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index b4c391fea17e64..90786de24978f2 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -310,6 +310,12 @@ def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 76790ee795026e..5889653991cac2 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -182,6 +182,12 @@ def test_config(self): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/dpt/test_modeling_dpt_hybrid.py b/tests/models/dpt/test_modeling_dpt_hybrid.py index 6d4a75c80da120..04ba8c0289bed0 100644 --- a/tests/models/dpt/test_modeling_dpt_hybrid.py +++ b/tests/models/dpt/test_modeling_dpt_hybrid.py @@ -196,6 +196,12 @@ def test_config(self): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 2544b7ee93f6ca..b6d71f33a684fd 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -185,6 +185,12 @@ def test_inputs_embeds(self): # FLAVA does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -462,6 +468,12 @@ def test_inputs_embeds(self): # FLAVA does not use inputs_embeds pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # skip this test as FlavaTextModel has no base class and is # not available in MODEL_MAPPING def test_save_load_fast_init_from_base(self): @@ -624,6 +636,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -731,6 +749,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -1156,6 +1180,12 @@ class FlavaForPreTrainingTest(FlavaModelTest): class_for_tester = FlavaForPreTrainingTester test_torchscript = False + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/fnet/test_modeling_fnet.py b/tests/models/fnet/test_modeling_fnet.py index e7e592d5b62ff5..96821842736522 100644 --- a/tests/models/fnet/test_modeling_fnet.py +++ b/tests/models/fnet/test_modeling_fnet.py @@ -444,6 +444,12 @@ def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in FNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 65542b49549742..620bb30b265713 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -562,6 +562,12 @@ def test_gpt2_weight_initialization(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_batch_generation(self): model = GPT2LMHeadModel.from_pretrained("gpt2") diff --git a/tests/models/graphormer/test_modeling_graphormer.py b/tests/models/graphormer/test_modeling_graphormer.py index e874ebf0f44a2b..f1c63729e00063 100644 --- a/tests/models/graphormer/test_modeling_graphormer.py +++ b/tests/models/graphormer/test_modeling_graphormer.py @@ -356,6 +356,12 @@ def test_inputs_embeds(self): def test_feed_forward_chunking(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Graphormer does not share input and output embeddings") def test_model_common_attributes(self): pass diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index 27d83f3eb8c1e9..1f4ea02f8d2002 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -304,6 +304,12 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_imagegpt_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_imagegpt_model(*config_and_inputs) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index f3c8539d845049..2202d62242cab9 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -216,6 +216,12 @@ def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): model = model_class(config) diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py index 0535fbf4e1f4c8..b88d0c4b50d87a 100644 --- a/tests/models/layoutlm/test_modeling_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_layoutlm.py @@ -279,6 +279,12 @@ def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def prepare_layoutlm_batch_inputs(): # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on: diff --git a/tests/models/lilt/test_modeling_lilt.py b/tests/models/lilt/test_modeling_lilt.py index 1bb92300c3db91..4032504b8b2587 100644 --- a/tests/models/lilt/test_modeling_lilt.py +++ b/tests/models/lilt/test_modeling_lilt.py @@ -275,6 +275,12 @@ def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in LILT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/luke/test_modeling_luke.py b/tests/models/luke/test_modeling_luke.py index 35bdb6b6d5fa6a..4e1ef3d173b47c 100644 --- a/tests/models/luke/test_modeling_luke.py +++ b/tests/models/luke/test_modeling_luke.py @@ -697,6 +697,12 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in LUKE_PRETRAINED_MODEL_ARCHIVE_LIST: diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 6cbcd55d3f7687..933383d2929ad9 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -263,6 +263,12 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_save_load_strict(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() for model_class in self.all_model_classes: diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py index acf078ffe80075..83fb86ba0e9319 100644 --- a/tests/models/owlvit/test_modeling_owlvit.py +++ b/tests/models/owlvit/test_modeling_owlvit.py @@ -155,6 +155,12 @@ def test_config(self): def test_inputs_embeds(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_common_attributes(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -633,6 +639,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index bde7477f945040..1f409d1b004bfe 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -280,6 +280,12 @@ def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() input_ids = input_dict["input_ids"] diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 8ec023676d6327..1eba4cb10c287f 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -332,6 +332,12 @@ def test_model(self): def test_training(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`") def test_training_gradient_checkpointing(self): pass diff --git a/tests/models/regnet/test_modeling_regnet.py b/tests/models/regnet/test_modeling_regnet.py index e7c33699fda7db..9b260845287b10 100644 --- a/tests/models/regnet/test_modeling_regnet.py +++ b/tests/models/regnet/test_modeling_regnet.py @@ -161,6 +161,12 @@ def test_inputs_embeds(self): def test_model_common_attributes(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/roformer/test_modeling_roformer.py b/tests/models/roformer/test_modeling_roformer.py index 357e126a047a4d..6d54b7c1286e50 100644 --- a/tests/models/roformer/test_modeling_roformer.py +++ b/tests/models/roformer/test_modeling_roformer.py @@ -452,6 +452,12 @@ def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/models/sam/test_modeling_sam.py b/tests/models/sam/test_modeling_sam.py index a0f39a40135577..8a4772138647c8 100644 --- a/tests/models/sam/test_modeling_sam.py +++ b/tests/models/sam/test_modeling_sam.py @@ -421,6 +421,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING") def test_save_load_fast_init_from_base(self): pass diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 16ad704fd51043..1524ce24d26273 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -324,6 +324,12 @@ def test_training(self): def test_training_gradient_checkpointing(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_generate_fp16(self): config, input_dict = self.model_tester.prepare_config_and_inputs() input_features = input_dict["input_features"] diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index f8730d899329ff..4ff4554fbc3506 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -613,6 +613,12 @@ def test_decoder_model_past_with_attn_mask(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_beam_sample_generate_dict_output(self): r""" diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index 42319a1dd0a242..44962267feea64 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -200,6 +200,12 @@ def setUp(self): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_save_load_strict(self): config, _ = self.model_tester.prepare_config_and_inputs() for model_class in self.all_model_classes: diff --git a/tests/models/van/test_modeling_van.py b/tests/models/van/test_modeling_van.py index 49df30a828a61e..7ec941dbc8851f 100644 --- a/tests/models/van/test_modeling_van.py +++ b/tests/models/van/test_modeling_van.py @@ -243,6 +243,12 @@ def test_model_from_pretrained(self): model = VanModel.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/vilt/test_modeling_vilt.py b/tests/models/vilt/test_modeling_vilt.py index 772091d5b976d5..17447acf680d52 100644 --- a/tests/models/vilt/test_modeling_vilt.py +++ b/tests/models/vilt/test_modeling_vilt.py @@ -340,6 +340,12 @@ def test_determinism(self): def test_model_outputs_equivalence(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/visual_bert/test_modeling_visual_bert.py b/tests/models/visual_bert/test_modeling_visual_bert.py index cf48fd7ffbec31..5dae4ebe1f9439 100644 --- a/tests/models/visual_bert/test_modeling_visual_bert.py +++ b/tests/models/visual_bert/test_modeling_visual_bert.py @@ -549,6 +549,12 @@ def test_model_for_flickr(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_flickr() self.model_tester.create_and_check_for_flickr(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/models/vit_mae/test_modeling_vit_mae.py b/tests/models/vit_mae/test_modeling_vit_mae.py index c58e2e94802e6b..77c36bef8babe2 100644 --- a/tests/models/vit_mae/test_modeling_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_vit_mae.py @@ -208,6 +208,12 @@ def test_for_pretraining(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + # overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict): diff --git a/tests/models/x_clip/test_modeling_x_clip.py b/tests/models/x_clip/test_modeling_x_clip.py index 2efece44caebeb..7fd65d871dbfc3 100644 --- a/tests/models/x_clip/test_modeling_x_clip.py +++ b/tests/models/x_clip/test_modeling_x_clip.py @@ -202,6 +202,12 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @unittest.skip( + reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247" + ) + def test_training_gradient_checkpointing_autocast(self): + pass + @slow def test_model_from_pretrained(self): for model_name in XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 07a8b16bfef758..7c02141f057e94 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import copy import gc @@ -549,6 +548,41 @@ def test_training_gradient_checkpointing(self): loss = model(**inputs).loss loss.backward() + @slow + @require_torch_gpu + def test_training_gradient_checkpointing_autocast(self): + if not self.model_tester.is_training: + return + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + + if ( + model_class.__name__ + in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)] + or not model_class.supports_gradient_checkpointing + ): + continue + model = model_class(config) + model.to(torch_device) + + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + model.gradient_checkpointing_enable() + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + with torch.cuda.amp.autocast(True, dtype=torch.float16): + output = model(**inputs)[0] + loss = output.mean() + + loss.backward() + optimizer.step() + + for n, param in model.named_parameters(): + self.assertTrue(param.grad is not None, f"None gradient in param {n}") + def test_attention_outputs(self): if not self.has_attentions: self.skipTest(reason="Model does not output attentions")