Enabling gradient checkpointing in eval() mode (#9878)
* refactored
MikeTkachuk authored Nov 8, 2024
1 parent 0be52c0 commit 5b972fb
Showing 34 changed files with 84 additions and 84 deletions.
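The change is identical in every touched module: the guard around `torch.utils.checkpoint.checkpoint` switches from `self.training` to `torch.is_grad_enabled()`, so activation checkpointing also applies when a module is in `eval()` mode but autograd is recording. The sketch below is not part of the diff; it uses a hypothetical `TinyBlock` and passes `use_reentrant=False` directly, whereas the real modules gate that keyword on the installed PyTorch version.

# Minimal sketch (not from this diff): a hypothetical TinyBlock showing the old vs. new guard.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(3)])
        self.gradient_checkpointing = True

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            # Before: `self.training and self.gradient_checkpointing`
            # After:  checkpoint whenever autograd is recording, even in eval() mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, use_reentrant=False
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


block = TinyBlock().eval()                 # eval() mode, autograd still enabled
x = torch.randn(2, 8, requires_grad=True)
block(x).sum().backward()                  # checkpointing is active on this pass
with torch.no_grad():
    _ = block(x)                           # inference path: no checkpointing

With the old guard, calling `eval()` silently disabled checkpointing even while gradients were still being computed, which is presumably the situation this commit addresses.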
8 changes: 4 additions & 4 deletions examples/community/matryoshka.py
@@ -868,7 +868,7 @@ def forward(
blocks = list(zip(self.resnets, self.attentions))

for i, (resnet, attn) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1029,7 +1029,7 @@ def forward(

hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1191,7 +1191,7 @@ def forward(

hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1364,7 +1364,7 @@ def forward(

# Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -215,7 +215,7 @@ def forward(

# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
# rc todo: for training and gradient checkpointing
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
4 changes: 2 additions & 2 deletions src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -506,7 +506,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
sample = self.temp_conv_in(sample)
sample = sample + residual

- if self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

- if self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -420,7 +420,7 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -522,7 +522,7 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -636,7 +636,7 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -773,7 +773,7 @@ def forward(

hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -939,7 +939,7 @@ def forward(

hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -206,7 +206,7 @@ def forward(
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -311,7 +311,7 @@ def forward(
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -392,7 +392,7 @@ def forward(
for i, resnet in enumerate(self.resnets):
conv_cache_key = f"resnet_{i}"

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -529,7 +529,7 @@ def forward(
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -646,7 +646,7 @@ def forward(
hidden_states = self.conv_in(hidden_states)

# 1. Mid
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def create_forward(*inputs):
@@ -95,7 +95,7 @@ def forward(
sample = self.conv_in(sample)

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/vae.py
@@ -142,7 +142,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:

sample = self.conv_in(sample)

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -291,7 +291,7 @@ def forward(
sample = self.conv_in(sample)

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -544,7 +544,7 @@ def forward(
sample = self.conv_in(sample)

upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -876,7 +876,7 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `EncoderTiny` class."""
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -962,7 +962,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# Clamp.
x = torch.tanh(x / 3) * 3

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/controlnets/controlnet_flux.py
@@ -329,7 +329,7 @@ def forward(

block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -363,7 +363,7 @@ def custom_forward(*inputs):

single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnets/controlnet_sd3.py
@@ -324,7 +324,7 @@ def forward(
block_res_samples = ()

for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
6 changes: 3 additions & 3 deletions src/diffusers/models/controlnets/controlnet_xs.py
@@ -1466,7 +1466,7 @@ def custom_forward(*inputs):
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

# apply base subblock
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_base = torch.utils.checkpoint.checkpoint(
create_custom_forward(b_res),
@@ -1489,7 +1489,7 @@ def custom_forward(*inputs):

# apply ctrl subblock
if apply_control:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
h_ctrl = torch.utils.checkpoint.checkpoint(
create_custom_forward(c_res),
@@ -1898,7 +1898,7 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
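The controlnet_xs.py hunks above also show the call pattern shared by these modules: a small `create_custom_forward` closure handed to `torch.utils.checkpoint.checkpoint`, plus `ckpt_kwargs` that selects non-reentrant checkpointing on PyTorch >= 1.11. A rough, standalone sketch of that pattern follows; `block`, the tensors, and the module-level `gradient_checkpointing` flag are illustrative stand-ins, and the version gate from the real code is only noted in a comment.

# Sketch of the shared checkpoint-call pattern; the names below are stand-ins, not code from the diff.
from typing import Any, Dict

import torch
import torch.utils.checkpoint


def create_custom_forward(module, return_dict=None):
    # Positional tensors flow through checkpoint(); a keyword-only argument such as
    # `return_dict` is captured by the closure instead.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


# The real modules gate this on `is_torch_version(">=", "1.11.0")`; non-reentrant
# checkpointing is the more flexible variant and is assumed available here.
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

block = torch.nn.Linear(16, 16).eval()
hidden_states = torch.randn(4, 16, requires_grad=True)
gradient_checkpointing = True  # stand-in for `self.gradient_checkpointing`

if torch.is_grad_enabled() and gradient_checkpointing:
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(block), hidden_states, **ckpt_kwargs
    )
else:
    hidden_states = block(hidden_states)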
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -466,7 +466,7 @@ def forward(

# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -497,7 +497,7 @@ def custom_forward(*inputs):
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -452,7 +452,7 @@ def forward(

# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/dit_transformer_2d.py
@@ -184,7 +184,7 @@ def forward(

# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -238,7 +238,7 @@ def forward(
for i, (spatial_block, temp_block) in enumerate(
zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
spatial_block,
hidden_states,
@@ -271,7 +271,7 @@ def forward(
if i == 0 and num_frame > 1:
hidden_states = hidden_states + self.temp_pos_embed

- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
temp_block,
hidden_states,
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -386,7 +386,7 @@ def forward(

# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -414,7 +414,7 @@ def forward(
attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)

for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_2d.py
@@ -415,7 +415,7 @@ def forward(

# 2. Blocks
for block in self.transformer_blocks:
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_allegro.py
@@ -371,7 +371,7 @@ def forward(
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
# TODO(aryan): Implement gradient checkpointing
- if self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -341,7 +341,7 @@ def forward(
hidden_states = hidden_states[:, text_seq_length:]

for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -480,7 +480,7 @@ def forward(
image_rotary_emb = self.pos_embed(ids)

for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -525,7 +525,7 @@ def custom_forward(*inputs):
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

for index_block, block in enumerate(self.single_transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_mochi.py
@@ -350,7 +350,7 @@ def forward(
)

for i, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/transformer_sd3.py
@@ -317,7 +317,7 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

for index_block, block in enumerate(self.transformer_blocks):
- if self.training and self.gradient_checkpointing:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):