Skip to content

Commit

Permalink
Skip layer guidance now works on hydit model.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 24, 2024
1 parent 3d80271 commit b4526d3
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions comfy/ldm/hydit/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def forward(self,
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
Expand Down Expand Up @@ -315,8 +315,7 @@ def forward(self,
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
Expand Down Expand Up @@ -364,6 +363,8 @@ def forward(self,
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]

blocks_replace = patches_replace.get("dit", {})

controls = None
if control:
controls = control.get("output", None)
Expand All @@ -375,9 +376,20 @@ def forward(self,
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None

if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out

out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)


if layer < (self.depth // 2 - 1):
skips.append(x)
Expand Down

0 comments on commit b4526d3

Please sign in to comment.