From 72741105a687c67137eb5d7a38840b8373d82e61 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 17:18:49 -0500 Subject: [PATCH] Remove useless code. --- .../modules/diffusionmodules/openaimodel.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 10eb68d73b5..e8f35a540fa 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -28,25 +28,6 @@ def forward(self, x, emb): Apply the module to `x` given `emb` timestep embeddings. """ - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context, transformer_options) - elif isinstance(layer, Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x - #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): for layer in ts: @@ -54,13 +35,23 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) - transformer_options["current_index"] += 1 + if "current_index" in transformer_options: + transformer_options["current_index"] += 1 elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: x = layer(x) return x +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, *args, **kwargs): + return forward_timestep_embed(self, *args, **kwargs) + class Upsample(nn.Module): """ An upsampling layer with an optional convolution.