From 5b972fbd6a6c50cf1afdf1ba34c34d84fc67861c Mon Sep 17 00:00:00 2001
From: Michael Tkachuk <61463055+MikeTkachuk@users.noreply.github.com>
Date: Fri, 8 Nov 2024 14:03:26 -0500
Subject: [PATCH] Enabling gradient checkpointing in eval() mode (#9878)

* refactored
---
 examples/community/matryoshka.py | 8 +++---
 .../pixart/controlnet_pixart_alpha.py | 2 +-
 .../autoencoders/autoencoder_kl_allegro.py | 4 +--
 .../autoencoders/autoencoder_kl_cogvideox.py | 10 +++---
 .../autoencoders/autoencoder_kl_mochi.py | 10 +++---
 .../autoencoder_kl_temporal_decoder.py | 2 +-
 src/diffusers/models/autoencoders/vae.py | 10 +++---
 .../models/controlnets/controlnet_flux.py | 4 +--
 .../models/controlnets/controlnet_sd3.py | 2 +-
 .../models/controlnets/controlnet_xs.py | 6 ++---
 .../transformers/auraflow_transformer_2d.py | 4 +--
 .../transformers/cogvideox_transformer_3d.py | 2 +-
 .../models/transformers/dit_transformer_2d.py | 2 +-
 .../transformers/latte_transformer_3d.py | 4 +--
 .../transformers/pixart_transformer_2d.py | 2 +-
 .../transformers/stable_audio_transformer.py | 2 +-
 .../models/transformers/transformer_2d.py | 2 +-
 .../transformers/transformer_allegro.py | 2 +-
 .../transformers/transformer_cogview3plus.py | 2 +-
 .../models/transformers/transformer_flux.py | 4 +--
 .../models/transformers/transformer_mochi.py | 2 +-
 .../models/transformers/transformer_sd3.py | 2 +-
 .../transformers/transformer_temporal.py | 2 +-
 src/diffusers/models/unets/unet_2d_blocks.py | 26 +++++++++----------
 src/diffusers/models/unets/unet_3d_blocks.py | 10 +++----
 .../models/unets/unet_motion_model.py | 10 +++----
 .../models/unets/unet_stable_cascade.py | 4 +--
 src/diffusers/models/unets/uvit_2d.py | 2 +-
 .../pipelines/audioldm2/modeling_audioldm2.py | 6 ++---
 .../blip_diffusion/modeling_blip2.py | 2 +-
 .../versatile_diffusion/modeling_text_unet.py | 10 +++---
 .../pipelines/kolors/text_encoder.py | 4 +--
 .../pipeline_latent_diffusion.py | 2 +-
 .../wuerstchen/modeling_wuerstchen_prior.py | 2 +-
 34 files changed, 84 insertions(+), 84 deletions(-)

diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 7ac0ab542910..0c85ad118752 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -868,7 +868,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ def forward(

         # Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py
index b7f5a427e52e..f825719a1364 100644
--- a/examples/research_projects/pixart/controlnet_pixart_alpha.py
+++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py
@@ -215,7 +215,7 @@ def forward(

         # 2. Blocks
         for block_index, block in enumerate(self.transformer.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 # rc todo: for training and gradient checkpointing
                 print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
                 exit(1)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index 922fd15c08fb..b62ed67ade29 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = self.temp_conv_in(sample)
         sample = sample + residual

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

-        if self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 8575c7658605..d9ee15062daf 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -420,7 +420,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -522,7 +522,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -636,7 +636,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -773,7 +773,7 @@ def forward(

         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
@@ -939,7 +939,7 @@ def forward(

         hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index 57e8b8f647ba..0eabf3a26d7c 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -206,7 +206,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -311,7 +311,7 @@ def forward(
         for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -392,7 +392,7 @@ def forward(
         for i, resnet in enumerate(self.resnets):
             conv_cache_key = f"resnet_{i}"

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def create_forward(*inputs):
@@ -529,7 +529,7 @@ def forward(
         hidden_states = self.proj_in(hidden_states)
         hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(
         hidden_states = self.conv_in(hidden_states)

         # 1. Mid
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def create_forward(*inputs):
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index 55449644ed03..4e3902ae6dbe 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -95,7 +95,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index bb80ce8605ba..2f3f4f2fc35c 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

         sample = self.conv_in(sample)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -291,7 +291,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -544,7 +544,7 @@ def forward(
         sample = self.conv_in(sample)

         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -876,7 +876,7 @@ def __init__(

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Clamp.
         x = torch.tanh(x / 3) * 3

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index e6a3eceed9b4..76a97847ef9a 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -329,7 +329,7 @@ def forward(

         block_samples = ()
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -363,7 +363,7 @@ def custom_forward(*inputs):

         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 911d65e03d88..209aad93244e 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -324,7 +324,7 @@ def forward(

         block_res_samples = ()
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module, return_dict=None):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index 06e0eda3c3b0..11ad676ec92b 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -1466,7 +1466,7 @@ def custom_forward(*inputs):
             h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

             # apply base subblock
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 h_base = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(b_res),
@@ -1489,7 +1489,7 @@ def custom_forward(*inputs):

             # apply ctrl subblock
             if apply_control:
-                if self.training and self.gradient_checkpointing:
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                     h_ctrl = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(c_res),
@@ -1898,7 +1898,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)

            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet),
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index ad64df0c0790..b3f29e6b6224 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -466,7 +466,7 @@ def forward(

         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -497,7 +497,7 @@ def custom_forward(*inputs):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 821da6d032d5..01c54ef090bd 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -452,7 +452,7 @@ def forward(

         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
index 9f8957737dbc..f787c5279499 100644
--- a/src/diffusers/models/transformers/dit_transformer_2d.py
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -184,7 +184,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 71d19216e5ff..7e2b1273687d 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -238,7 +238,7 @@ def forward(
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     spatial_block,
                     hidden_states,
@@ -271,7 +271,7 @@ def forward(
             if i == 0 and num_frame > 1:
                 hidden_states = hidden_states + self.temp_pos_embed

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     temp_block,
                     hidden_states,
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 1e5cd5794517..7f145edf16fb 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -386,7 +386,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index e3462b51a412..d687dbabf317 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -414,7 +414,7 @@ def forward(
             attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index c7c19e4582c6..e208a1c10ed4 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -415,7 +415,7 @@ def forward(

         # 2. Blocks
         for block in self.transformer_blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py
index f756399a378a..fe9c7290b063 100644
--- a/src/diffusers/models/transformers/transformer_allegro.py
+++ b/src/diffusers/models/transformers/transformer_allegro.py
@@ -371,7 +371,7 @@ def forward(
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
-            if self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 962cbbff7c1b..94d852f6df4b 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -341,7 +341,7 @@ def forward(
         hidden_states = hidden_states[:, text_seq_length:]

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index f078cace0f3e..0ad3be866019 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -480,7 +480,7 @@ def forward(
         image_rotary_emb = self.pos_embed(ids)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -525,7 +525,7 @@ def custom_forward(*inputs):
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 7f4ad2b328fa..8ac8b5dababa 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -350,7 +350,7 @@ def forward(
         )

         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index b28350b8ed9c..f39a102c7256 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -317,7 +317,7 @@ def forward(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index c0c5467050dd..6ca42b9745fd 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -340,7 +340,7 @@ def forward(

         # 2. Blocks
         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     block,
                     hidden_states,
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index 93a0a82cdcff..b9d186ac1aa6 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -859,7 +859,7 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1257,7 +1257,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1371,7 +1371,7 @@ def forward(
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1859,7 +1859,7 @@ def forward(
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2011,7 +2011,7 @@ def forward(
             mask = attention_mask

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2106,7 +2106,7 @@ def forward(
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2215,7 +2215,7 @@ def forward(
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2520,7 +2520,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2653,7 +2653,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3183,7 +3183,7 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3341,7 +3341,7 @@ def forward(
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3444,7 +3444,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3572,7 +3572,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py
index 8b472a89e13d..9c9fd7555899 100644
--- a/src/diffusers/models/unets/unet_3d_blocks.py
+++ b/src/diffusers/models/unets/unet_3d_blocks.py
@@ -1078,7 +1078,7 @@ def forward(
         )

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1168,7 +1168,7 @@ def forward(
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -1281,7 +1281,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions))

         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1383,7 +1383,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1493,7 +1493,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 6125feba5899..ddc3e41c340d 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -323,7 +323,7 @@ def forward(
         blocks = zip(self.resnets, self.motion_modules)

         for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -513,7 +513,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))

         for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -732,7 +732,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -895,7 +895,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1079,7 +1079,7 @@ def forward(
                 return_dict=False,
             )[0]

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module, return_dict=None):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index 7deea9a714d4..238e6b411356 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -455,7 +455,7 @@ def _down_encode(self, x, r_embed, clip):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -504,7 +504,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py
index 8a379bf5f9c3..2f0b3eb19508 100644
--- a/src/diffusers/models/unets/uvit_2d.py
+++ b/src/diffusers/models/unets/uvit_2d.py
@@ -181,7 +181,7 @@ def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds
         hidden_states = self.project_to_hidden(hidden_states)

         for layer in self.transformer_layers:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def layer_(*args):
                     return checkpoint(layer, *args)
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 2af3078f7412..63d3957ae17d 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -1112,7 +1112,7 @@ def forward(
         )

         for i in range(num_layers):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1290,7 +1290,7 @@ def forward(
         )

         for i in range(len(self.resnets[1:])):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1464,7 +1464,7 @@ def forward(
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
index 1be4761a9987..0d78b987ce77 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -167,7 +167,7 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 3937e87f63c9..107a5a45bfa2 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1595,7 +1595,7 @@ def forward(
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1732,7 +1732,7 @@ def forward(
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1874,7 +1874,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2033,7 +2033,7 @@ def forward(

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2352,7 +2352,7 @@ def forward(
         hidden_states = self.resnets[0](hidden_states, temb)

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py
index 6fb6f18a907a..5eb8d4c43d02 100644
--- a/src/diffusers/pipelines/kolors/text_encoder.py
+++ b/src/diffusers/pipelines/kolors/text_encoder.py
@@ -590,7 +590,7 @@ def forward(
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
             if use_cache:
                 logger.warning_once(
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -604,7 +604,7 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer = self._get_layer(index)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
                 layer_ret = torch.utils.checkpoint.checkpoint(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index f6f3531a8835..cd63637b6c2f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -675,7 +675,7 @@ def forward(
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index edb0c1ec45de..f90fc82a98ad 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -158,7 +158,7 @@ def forward(self, x, r, c):
         c_embed = self.cond_mapper(c)
         r_embed = self.gen_r_embedding(r)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
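The whole patch applies one mechanical change: gradient checkpointing used to be gated on `self.training and self.gradient_checkpointing`, so it never fired in eval() mode; it is now gated on `torch.is_grad_enabled() and self.gradient_checkpointing`, so it fires whenever autograd is recording (including eval-mode passes that need gradients) and is skipped under `torch.no_grad()`. The sketch below is a minimal illustration of that gate only; `ToyBlock`, its layer sizes, and the `gradient_checkpointing` wiring are hypothetical stand-ins for the diffusers blocks touched by this diff, not code from the patch.

# Minimal sketch (assumed example, not from the patch) of the new gate.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            # Old gate: `self.training and self.gradient_checkpointing`
            # (never checkpoints in eval()). New gate: checkpoint whenever
            # autograd is recording, regardless of train/eval mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


block = ToyBlock()
block.gradient_checkpointing = True
block.eval()  # e.g. computing gradients w.r.t. inputs at inference time

x = torch.randn(2, 64, requires_grad=True)
with torch.enable_grad():
    block(x).sum().backward()  # checkpointing is active in eval() mode
with torch.no_grad():
    _ = block(x)  # autograd disabled -> checkpointing is skipped

The behavioral difference only shows up when `gradient_checkpointing` is enabled on a model in eval() mode with gradients enabled: the old gate kept all activations, while the new gate recomputes them during backward, trading compute for memory exactly as in training; plain `torch.no_grad()` inference is unaffected.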