Commit
Add ability to mix usage of T2I-Adapter(s) and ControlNet(s). (huggingface#5362)

* Add ability to mix usage of T2I-Adapter(s) and ControlNet(s).
Previously, the UNet2DConditionModel implementation only allowed use of one or the other.
Adds a new forward() arg, down_intrablock_additional_residuals, specifically for T2I-Adapters. If down_intrablock_additional_residuals is not used, backward compatibility is maintained with the prior usage of only a T2I-Adapter or only a ControlNet, but not both. (A usage sketch follows the commit message below.)

* Improving forward() arg docs in src/diffusers/models/unet_2d_condition.py

Co-authored-by: psychedelicious <[email protected]>

* Add deprecation warning if down_block_additional_residuals is used for T2I-Adapter (intrablock residuals)

Co-authored-by: Patrick von Platen <[email protected]>

* Oops my bad, fixing last commit.

* Added import of diffusers utils.deprecate

* Conform to max line length

* Modifying T2I-Adapter pipelines to reflect change to UNet forward() arg for T2I-Adapter residuals.

---------

Co-authored-by: psychedelicious <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
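
A minimal sketch of what this change enables, assuming an already-loaded UNet2DConditionModel (`unet`), ControlNetModel (`controlnet`), and T2IAdapter (`adapter`); `latents`, `t`, `prompt_embeds`, `control_image`, and `adapter_image` are illustrative stand-ins for a pipeline's denoising-loop state:

# Sketch only: after this commit, both side models can feed a single UNet call.
# T2I-Adapter features are computed once from the conditioning image.
adapter_state = adapter(adapter_image)

# ControlNet residuals are recomputed at each denoising step.
down_res, mid_res = controlnet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=control_image,
    return_dict=False,
)

noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=down_res,  # ControlNet: long skip connections (down -> up blocks)
    mid_block_additional_residual=mid_res,  # ControlNet: mid-block output
    down_intrablock_additional_residuals=[s.clone() for s in adapter_state],  # T2I-Adapter: within down blocks
).sample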
3 people authored Oct 16, 2023
1 parent cc12f3e commit de12776
Showing 3 changed files with 34 additions and 14 deletions.
40 changes: 30 additions & 10 deletions src/diffusers/models/unet_2d_condition.py
@@ -20,7 +20,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -778,6 +778,7 @@ def forward(
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
@@ -822,6 +823,13 @@ def forward(
             added_cond_kwargs: (`dict`, *optional*):
                 A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                 are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added to UNet long skip connections from down blocks to up blocks,
+                for example from ControlNet side model(s)
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                additional residual to be added to UNet mid block output, for example from ControlNet side model
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
 
         Returns:
             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1000,15 +1008,28 @@ def forward(
         scale_lora_layers(self, lora_scale)
 
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
-        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate("T2I should not use down_block_additional_residuals",
+                      "1.3.0",
+                      "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
+                      standard_warn=False)
+            down_intrablock_additional_residuals = down_block_additional_residuals
+            is_adapter = True
 
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                 # For t2i-adapter CrossAttnDownBlock2D
                 additional_residuals = {}
-                if is_adapter and len(down_block_additional_residuals) > 0:
-                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
 
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
@@ -1021,9 +1042,8 @@ def forward(
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
-
-            if is_adapter and len(down_block_additional_residuals) > 0:
-                sample += down_block_additional_residuals.pop(0)
+                if is_adapter and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
 
             down_block_res_samples += res_samples

@@ -1051,10 +1071,10 @@ def forward(
         # To support T2I-Adapter-XL
         if (
             is_adapter
-            and len(down_block_additional_residuals) > 0
-            and sample.shape == down_block_additional_residuals[0].shape
+            and len(down_intrablock_additional_residuals) > 0
+            and sample.shape == down_intrablock_additional_residuals[0].shape
         ):
-            sample += down_block_additional_residuals.pop(0)
+            sample += down_intrablock_additional_residuals.pop(0)
 
         if is_controlnet:
             sample = sample + mid_block_additional_residual
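To illustrate the backward-compatibility shim above (a sketch; `unet`, `latents`, `t`, `prompt_embeds`, and `adapter_state` are assumed to exist as in the earlier example):

# Legacy T2I-Adapter usage: intrablock residuals passed via down_block_additional_residuals
# with no mid_block_additional_residual. Still works, but now emits the deprecation
# warning above and is rerouted internally to down_intrablock_additional_residuals.
noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_block_additional_residuals=[s.clone() for s in adapter_state],
).sample

# Preferred usage after this commit:
noise_pred = unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    down_intrablock_additional_residuals=[s.clone() for s in adapter_state],
).sample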
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -813,7 +813,7 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    down_block_additional_residuals=[state.clone() for state in adapter_state],
+                    down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
                 ).sample
 
                 # perform guidance
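Note why the pipeline passes `[state.clone() for state in adapter_state]` rather than `adapter_state` itself: the UNet forward consumes the residuals with `pop(0)`, so a fresh list must be built on every denoising step. A sketch of the pattern (loop variables are illustrative):

# Adapter features are computed once, before the denoising loop.
adapter_state = adapter(adapter_image)

for i, t in enumerate(timesteps):
    # A fresh list of clones each step: UNet.forward() pops entries as it walks
    # the down blocks, so a reused list would be empty after the first step.
    noise_pred = unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        down_intrablock_additional_residuals=[s.clone() for s in adapter_state],
    ).sample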
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -975,9 +975,9 @@ def __call__(
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 if i < int(num_inference_steps * adapter_conditioning_factor):
-                    down_block_additional_residuals = [state.clone() for state in adapter_state]
+                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
-                    down_block_additional_residuals = None
+                    down_intrablock_additional_residuals = None
 
                 noise_pred = self.unet(
                     latent_model_input,
@@ -986,7 +986,7 @@ def __call__(
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_block_additional_residuals=down_block_additional_residuals,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance
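For context on the `adapter_conditioning_factor` gate above: the adapter residuals are only applied during the first fraction of the schedule. A worked example with illustrative values:

num_inference_steps = 50
adapter_conditioning_factor = 0.8  # illustrative value

for i in range(num_inference_steps):
    # True for steps 0..39, False for steps 40..49, so the T2I-Adapter guides
    # early structure while the final steps run without adapter conditioning.
    apply_adapter = i < int(num_inference_steps * adapter_conditioning_factor)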
