Skip to content

Commit

Permalink
make style
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Oct 16, 2023
1 parent de12776 commit 57239da
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
16 changes: 9 additions & 7 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
Expand Down Expand Up @@ -824,8 +824,8 @@ def forward(
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks
for example from ControlNet side model(s)
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
Expand Down Expand Up @@ -1014,12 +1014,14 @@ def forward(
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
deprecate("T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
standard_warn=False)
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True

Expand Down
41 changes: 32 additions & 9 deletions src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ def forward(
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
Expand Down Expand Up @@ -1031,6 +1032,13 @@ def forward(
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
example from ControlNet side model(s)
mid_block_additional_residual (`torch.Tensor`, *optional*):
additional residual to be added to UNet mid block output, for example from ControlNet side model
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
Expand Down Expand Up @@ -1216,15 +1224,31 @@ def forward(
scale_lora_layers(self, lora_scale)

is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
# maintain backward compatibility for legacy usage, where
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
# but can only use one or the other
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
deprecate(
"T2I should not use down_block_additional_residuals",
"1.3.0",
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
" and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
" be used for ControlNet. Please make sure use"
" `down_intrablock_additional_residuals` instead. ",
standard_warn=False,
)
down_intrablock_additional_residuals = down_block_additional_residuals
is_adapter = True

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlockFlat
additional_residuals = {}
if is_adapter and len(down_block_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

sample, res_samples = downsample_block(
hidden_states=sample,
Expand All @@ -1237,9 +1261,8 @@ def forward(
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
if is_adapter and len(down_intrablock_additional_residuals) > 0:
sample += down_intrablock_additional_residuals.pop(0)

down_block_res_samples += res_samples

Expand Down Expand Up @@ -1267,10 +1290,10 @@ def forward(
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_block_additional_residuals) > 0
and sample.shape == down_block_additional_residuals[0].shape
and len(down_intrablock_additional_residuals) > 0
and sample.shape == down_intrablock_additional_residuals[0].shape
):
sample += down_block_additional_residuals.pop(0)
sample += down_intrablock_additional_residuals.pop(0)

if is_controlnet:
sample = sample + mid_block_additional_residual
Expand Down

0 comments on commit 57239da

Please sign in to comment.