From 0a08d41961220887c97074dcd585e52bba9f6220 Mon Sep 17 00:00:00 2001 From: takuoko Date: Mon, 4 Dec 2023 20:43:34 +0900 Subject: [PATCH] [Feature] Support IP-Adapter Plus (#5915) * Support IP-Adapter Plus * fix format * restore before black format * restore before black format * generic * Refactor PerceiverAttention * format * fix test and refactor PerceiverAttention * generic encode_image * keep attention implementation * merge tests * encode_image backward compatible * code quality * fix controlnet inpaint pipeline * refactor FFN * refactor FFN --------- Co-authored-by: YiYi Xu --- src/diffusers/loaders/unet.py | 106 +++++++++++--- src/diffusers/models/__init__.py | 2 + src/diffusers/models/activations.py | 15 +- src/diffusers/models/attention.py | 12 +- src/diffusers/models/embeddings.py | 89 ++++++++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 28 +++- .../pipeline_alt_diffusion_img2img.py | 28 +++- .../animatediff/pipeline_animatediff.py | 28 +++- .../controlnet/pipeline_controlnet.py | 28 +++- .../controlnet/pipeline_controlnet_inpaint.py | 28 +++- .../controlnet/pipeline_controlnet_sd_xl.py | 28 +++- .../pipeline_stable_diffusion.py | 28 +++- .../pipeline_stable_diffusion_img2img.py | 28 +++- .../pipeline_stable_diffusion_inpaint.py | 28 +++- .../pipeline_stable_diffusion_xl.py | 28 +++- .../pipeline_stable_diffusion_xl_img2img.py | 28 +++- .../pipeline_stable_diffusion_xl_inpaint.py | 28 +++- tests/models/test_models_unet_2d_condition.py | 132 +++++++++++++++++- .../test_ip_adapter_stable_diffusion.py | 114 ++++++++++++++- 19 files changed, 683 insertions(+), 123 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 6c805672c9cd..9d559a4b4af8 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import defaultdict +from collections import OrderedDict, defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -21,7 +21,7 @@ import torch.nn.functional as F from torch import nn -from ..models.embeddings import ImageProjection +from ..models.embeddings import ImageProjection, Resampler from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, @@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0, ) + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds = 4 + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] + + # Set encoder_hid_proj after loading ip_adapter weights, + # because `Resampler` also has `attn_processors`. + self.encoder_hid_proj = None + # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 @@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) value_dict = {} @@ -708,26 +722,76 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) # create image projection layers. - clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] - cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + image_projection.to(dtype=self.dtype, device=self.device) - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 - ) - image_projection.to(dtype=self.dtype, device=self.device) - - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "image_embeds.weight": state_dict["image_proj"]["proj.weight"], - "image_embeds.bias": state_dict["image_proj"]["proj.bias"], - "norm.weight": state_dict["image_proj"]["norm.weight"], - "norm.bias": state_dict["image_proj"]["norm.bias"], - } - ) + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) + + image_projection.load_state_dict(image_proj_state_dict) + + else: + # IP-Adapter Plus + embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dims = state_dict["image_proj"]["latents"].shape[2] + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + + image_projection = Resampler( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) + + image_proj_state_dict = state_dict["image_proj"] + + new_sd = OrderedDict() + for k, v in image_proj_state_dict.items(): + if "0.to" in k: + k = k.replace("0.to", "2.to") + elif "1.0.weight" in k: + k = k.replace("1.0.weight", "3.0.weight") + elif "1.0.bias" in k: + k = k.replace("1.0.bias", "3.0.bias") + elif "1.1.weight" in k: + k = k.replace("1.1.weight", "3.1.net.0.proj.weight") + elif "1.3.weight" in k: + k = k.replace("1.3.weight", "3.1.net.2.weight") + + if "norm1" in k: + new_sd[k.replace("0.norm1", "0")] = v + elif "norm2" in k: + new_sd[k.replace("0.norm2", "1")] = v + elif "to_kv" in k: + v_chunk = v.chunk(2, dim=0) + new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] + new_sd[k.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in k: + new_sd[k.replace("to_out", "to_out.0")] = v + else: + new_sd[k] = v - image_projection.load_state_dict(image_proj_state_dict) + image_projection.load_state_dict(new_sd) + del image_proj_state_dict self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1b76b4e03341..49ee3ee6af6b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -34,6 +34,7 @@ _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["embeddings"] = ["ImageProjection"] _import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] @@ -63,6 +64,7 @@ from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel + from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 8b75162ba597..47570eca8443 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -55,11 +55,12 @@ class GELU(nn.Module): dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: @@ -81,13 +82,14 @@ class GEGLU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - self.proj = linear_cls(dim_in, dim_out * 2) + self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": @@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f02b5e249eee..08faaaf3e5bf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -501,6 +501,7 @@ class FeedForward(nn.Module): dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ def __init__( @@ -511,6 +512,7 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + bias: bool = True, ): super().__init__() inner_dim = int(dim * mult) @@ -518,13 +520,13 @@ def __init__( linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) + act_fn = GELU(dim, inner_dim, bias=bias) if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) + act_fn = GEGLU(dim, inner_dim, bias=bias) elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) # project in @@ -532,7 +534,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(linear_cls(inner_dim, dim_out)) + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a377ae267411..bdd2930d20f9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation +from .attention_processor import Attention from .lora import LoRACompatibleLinear @@ -790,3 +791,91 @@ def forward(self, caption, force_drop_ids=None): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + + +class Resampler(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. Defaults to 768. + output_dims (int): The number of output channels, that is the same + number of the channels in the + `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): The number of hidden channels. Defaults to 1280. + depth (int): The number of blocks. Defaults to 8. + dim_head (int): The number of head channels. Defaults to 64. + heads (int): Parallel attention heads. Defaults to 16. + num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward # Lazy import to avoid circular import + + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + nn.LayerNorm(hidden_dims), + nn.LayerNorm(hidden_dims), + Attention( + query_dim=hidden_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ), + nn.Sequential( + nn.LayerNorm(hidden_dims), + FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ), + ] + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + x (torch.Tensor): Input Tensor. + + Returns: + ------- + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for ln0, ln1, attn, ff in self.layers: + residual = latents + + encoder_hidden_states = ln0(x) + latents = ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + residual + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index b5c7aee4b4de..2121e9b81509 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -494,18 +494,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -875,7 +886,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 4272fa124755..401e6aef82b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -505,18 +505,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -919,7 +930,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 28dc220545dc..32a08a0264bc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -320,18 +320,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): @@ -651,7 +662,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 1e19678b221d..bf6ef2125446 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -479,18 +479,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1067,7 +1078,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 72c2250dd5ac..71e237ce4e02 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -25,7 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -597,18 +597,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1284,7 +1295,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4696781dce0c..8c8399809228 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1169,7 +1180,10 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bf43c043490b..f7f4a16f0aa4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -871,7 +882,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e3a1a0ed3660..c80178152a6e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -503,18 +503,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -923,7 +934,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3570eaa6fd3d..375197cc9e4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -574,18 +574,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1103,7 +1114,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 40c981a46d48..12d52aa076d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -524,18 +524,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1087,7 +1098,10 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 436d816e5eb3..729924ec2e20 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -741,18 +741,29 @@ def prepare_latents( return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def _get_add_time_ids( self, @@ -1259,7 +1270,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index f54b680dfd7c..7195b5f2521a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -462,18 +462,29 @@ def disable_vae_tiling(self): self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -1568,7 +1579,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 06bf2685560d..9ccd78f1fe47 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -18,6 +18,7 @@ import os import tempfile import unittest +from collections import OrderedDict import torch from parameterized import parameterized @@ -25,7 +26,7 @@ from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor -from diffusers.models.embeddings import ImageProjection +from diffusers.models.embeddings import ImageProjection, Resampler from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -97,6 +98,85 @@ def create_ip_adapter_state_dict(model): return ip_state_dict +def create_ip_adapter_plus_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = Resampler( + embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4 + ) + + ip_image_projection_state_dict = OrderedDict() + for k, v in image_projection.state_dict().items(): + if "2.to" in k: + k = k.replace("2.to", "0.to") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.0.bias" in k: + k = k.replace("3.0.bias", "1.0.bias") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.1.net.0.proj.weight" in k: + k = k.replace("3.1.net.0.proj.weight", "1.1.weight") + elif "3.net.2.weight" in k: + k = k.replace("3.net.2.weight", "1.3.weight") + elif "layers.0.0" in k: + k = k.replace("layers.0.0", "layers.0.0.norm1") + elif "layers.0.1" in k: + k = k.replace("layers.0.1", "layers.0.0.norm2") + elif "layers.1.0" in k: + k = k.replace("layers.1.0", "layers.1.0.norm1") + elif "layers.1.1" in k: + k = k.replace("layers.1.1", "layers.1.0.norm2") + elif "layers.2.0" in k: + k = k.replace("layers.2.0", "layers.2.0.norm1") + elif "layers.2.1" in k: + k = k.replace("layers.2.1", "layers.2.0.norm2") + + if "norm_cross" in k: + ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v + elif "layer_norm" in k: + ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v + elif "to_k" in k: + ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0) + elif "to_v" in k: + continue + elif "to_out.0" in k: + ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v + else: + ip_image_projection_state_dict[k] = v + + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True @@ -724,6 +804,56 @@ def test_ip_adapter(self): assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + def test_ip_adapter_plus(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without ip-adapter + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + # update inputs_dict for ip-adapter + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} + + # make ip_adapter_1 and ip_adapter_2 + ip_adapter_1 = create_ip_adapter_plus_state_dict(model) + + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} + ip_adapter_2 = {} + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) + + # forward pass ip_adapter_1 + model._load_ip_adapter_weights(ip_adapter_1) + assert model.config.encoder_hid_dim_type == "ip_image_proj" + assert model.encoder_hid_proj is not None + assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( + "IPAdapterAttnProcessor", + "IPAdapterAttnProcessor2_0", + ) + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + # forward pass with ip_adapter_2 + model._load_ip_adapter_weights(ip_adapter_2) + with torch.no_grad(): + sample3 = model(**inputs_dict).sample + + # forward pass with ip_adapter_1 again + model._load_ip_adapter_weights(ip_adapter_1) + with torch.no_grad(): + sample4 = model(**inputs_dict).sample + + assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4) + assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) + assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 57eb49013c1f..7c6349ce2600 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -116,7 +116,17 @@ def test_text_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0]) + expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -132,7 +142,17 @@ def test_image_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935]) + expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -148,7 +168,17 @@ def test_inpainting(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997]) + expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -173,7 +203,30 @@ def test_text_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923]) + expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -194,7 +247,31 @@ def test_image_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752]) + expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -216,6 +293,31 @@ def test_inpainting_sdxl(self): image_slice = images[0, :3, :3, -1].flatten() image_slice.tolist() - expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516]) + expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + image_slice.tolist() + + expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)