diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4039fbfcc67a..06421305c301 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin -from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, deprecate, scale_lora_layers, unscale_lora_layers from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -778,6 +778,7 @@ def forward( added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: @@ -822,6 +823,13 @@ def forward( added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks + for example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -1000,15 +1008,28 @@ def forward( scale_lora_layers(self, lora_scale) is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate("T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} - if is_adapter and len(down_block_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1021,9 +1042,8 @@ def forward( ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - - if is_adapter and len(down_block_additional_residuals) > 0: - sample += down_block_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) down_block_res_samples += res_samples @@ -1051,10 +1071,10 @@ def forward( # To support T2I-Adapter-XL if ( is_adapter - and len(down_block_additional_residuals) > 0 - and sample.shape == down_block_additional_residuals[0].shape + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape ): - sample += down_block_additional_residuals.pop(0) + sample += down_intrablock_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 0c7120c5b3ec..dca9e5fc3de2 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -813,7 +813,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=[state.clone() for state in adapter_state], + down_intrablock_additional_residuals=[state.clone() for state in adapter_state], ).sample # perform guidance diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index b31d478a9d67..d4272696c23b 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -975,9 +975,9 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if i < int(num_inference_steps * adapter_conditioning_factor): - down_block_additional_residuals = [state.clone() for state in adapter_state] + down_intrablock_additional_residuals = [state.clone() for state in adapter_state] else: - down_block_additional_residuals = None + down_intrablock_additional_residuals = None noise_pred = self.unet( latent_model_input, @@ -986,7 +986,7 @@ def __call__( cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, - down_block_additional_residuals=down_block_additional_residuals, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, )[0] # perform guidance