From db6ad243f2f919cb25ce2f38b817fdfc93f1e787 Mon Sep 17 00:00:00 2001 From: takuoko Date: Mon, 4 Dec 2023 20:43:34 +0900 Subject: [PATCH] [Feature] Support IP-Adapter Plus (#5915) * Support IP-Adapter Plus * fix format * restore before black format * restore before black format * generic * Refactor PerceiverAttention * format * fix test and refactor PerceiverAttention * generic encode_image * keep attention implementation * merge tests * encode_image backward compatible * code quality * fix controlnet inpaint pipeline * refactor FFN * refactor FFN --------- Co-authored-by: YiYi Xu --- loaders/unet.py | 106 ++++++++++++++---- models/__init__.py | 2 + models/activations.py | 15 ++- models/attention.py | 12 +- models/embeddings.py | 89 +++++++++++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 28 +++-- .../pipeline_alt_diffusion_img2img.py | 28 +++-- pipelines/animatediff/pipeline_animatediff.py | 28 +++-- pipelines/controlnet/pipeline_controlnet.py | 28 +++-- .../controlnet/pipeline_controlnet_inpaint.py | 28 +++-- .../controlnet/pipeline_controlnet_sd_xl.py | 28 +++-- .../pipeline_stable_diffusion.py | 28 +++-- .../pipeline_stable_diffusion_img2img.py | 28 +++-- .../pipeline_stable_diffusion_inpaint.py | 28 +++-- .../pipeline_stable_diffusion_xl.py | 28 +++-- .../pipeline_stable_diffusion_xl_img2img.py | 28 +++-- .../pipeline_stable_diffusion_xl_inpaint.py | 28 +++-- 17 files changed, 444 insertions(+), 116 deletions(-) diff --git a/loaders/unet.py b/loaders/unet.py index 6c805672c9cd..9d559a4b4af8 100644 --- a/loaders/unet.py +++ b/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import defaultdict +from collections import OrderedDict, defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -21,7 +21,7 @@ import torch.nn.functional as F from torch import nn -from ..models.embeddings import ImageProjection +from ..models.embeddings import ImageProjection, Resampler from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, @@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0, ) + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds = 4 + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] + + # Set encoder_hid_proj after loading ip_adapter weights, + # because `Resampler` also has `attn_processors`. + self.encoder_hid_proj = None + # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 @@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) value_dict = {} @@ -708,26 +722,76 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) # create image projection layers. 
- clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] - cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + image_projection.to(dtype=self.dtype, device=self.device) - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 - ) - image_projection.to(dtype=self.dtype, device=self.device) - - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "image_embeds.weight": state_dict["image_proj"]["proj.weight"], - "image_embeds.bias": state_dict["image_proj"]["proj.bias"], - "norm.weight": state_dict["image_proj"]["norm.weight"], - "norm.bias": state_dict["image_proj"]["norm.bias"], - } - ) + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) + + image_projection.load_state_dict(image_proj_state_dict) + + else: + # IP-Adapter Plus + embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dims = state_dict["image_proj"]["latents"].shape[2] + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + + image_projection = Resampler( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) + + image_proj_state_dict = state_dict["image_proj"] + + new_sd = OrderedDict() + for k, v in image_proj_state_dict.items(): + if "0.to" in k: + k = k.replace("0.to", "2.to") + elif "1.0.weight" in k: + k = k.replace("1.0.weight", "3.0.weight") + elif "1.0.bias" in k: + k = k.replace("1.0.bias", "3.0.bias") + elif "1.1.weight" in k: + k = k.replace("1.1.weight", "3.1.net.0.proj.weight") + elif "1.3.weight" in k: + k = k.replace("1.3.weight", "3.1.net.2.weight") + + if "norm1" in k: + new_sd[k.replace("0.norm1", "0")] = v + elif "norm2" in k: + new_sd[k.replace("0.norm2", "1")] = v + elif "to_kv" in k: + v_chunk = v.chunk(2, dim=0) + new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] + new_sd[k.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in k: + new_sd[k.replace("to_out", "to_out.0")] = v + else: + new_sd[k] = v - image_projection.load_state_dict(image_proj_state_dict) + image_projection.load_state_dict(new_sd) + del image_proj_state_dict self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/models/__init__.py b/models/__init__.py index 1b76b4e03341..49ee3ee6af6b 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -34,6 +34,7 @@ _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["embeddings"] = ["ImageProjection"] 
_import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] @@ -63,6 +64,7 @@ from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel + from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/models/activations.py b/models/activations.py index 8b75162ba597..47570eca8443 100644 --- a/models/activations.py +++ b/models/activations.py @@ -55,11 +55,12 @@ class GELU(nn.Module): dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: @@ -81,13 +82,14 @@ class GEGLU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - self.proj = linear_cls(dim_in, dim_out * 2) + self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": @@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) diff --git a/models/attention.py b/models/attention.py index f02b5e249eee..08faaaf3e5bf 100644 --- a/models/attention.py +++ b/models/attention.py @@ -501,6 +501,7 @@ class FeedForward(nn.Module): dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
""" def __init__( @@ -511,6 +512,7 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + bias: bool = True, ): super().__init__() inner_dim = int(dim * mult) @@ -518,13 +520,13 @@ def __init__( linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) + act_fn = GELU(dim, inner_dim, bias=bias) if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) + act_fn = GEGLU(dim, inner_dim, bias=bias) elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) # project in @@ -532,7 +534,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(linear_cls(inner_dim, dim_out)) + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) diff --git a/models/embeddings.py b/models/embeddings.py index a377ae267411..bdd2930d20f9 100644 --- a/models/embeddings.py +++ b/models/embeddings.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation +from .attention_processor import Attention from .lora import LoRACompatibleLinear @@ -790,3 +791,91 @@ def forward(self, caption, force_drop_ids=None): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + + +class Resampler(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. Defaults to 768. + output_dims (int): The number of output channels, that is the same + number of the channels in the + `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): The number of hidden channels. Defaults to 1280. + depth (int): The number of blocks. Defaults to 8. + dim_head (int): The number of head channels. Defaults to 64. + heads (int): Parallel attention heads. Defaults to 16. + num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward # Lazy import to avoid circular import + + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + nn.LayerNorm(hidden_dims), + nn.LayerNorm(hidden_dims), + Attention( + query_dim=hidden_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ), + nn.Sequential( + nn.LayerNorm(hidden_dims), + FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ), + ] + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + x (torch.Tensor): Input Tensor. 
+ + Returns: + ------- + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for ln0, ln1, attn, ff in self.layers: + residual = latents + + encoder_hidden_states = ln0(x) + latents = ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + residual + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/pipelines/alt_diffusion/pipeline_alt_diffusion.py index b5c7aee4b4de..2121e9b81509 100644 --- a/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -494,18 +494,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -875,7 +886,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 
4272fa124755..401e6aef82b1 100644 --- a/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -505,18 +505,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -919,7 +930,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/animatediff/pipeline_animatediff.py b/pipelines/animatediff/pipeline_animatediff.py index 28dc220545dc..32a08a0264bc 100644 --- a/pipelines/animatediff/pipeline_animatediff.py +++ b/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -320,18 +320,29 @@ def 
encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): @@ -651,7 +662,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/controlnet/pipeline_controlnet.py b/pipelines/controlnet/pipeline_controlnet.py index 1e19678b221d..bf6ef2125446 100644 --- a/pipelines/controlnet/pipeline_controlnet.py +++ b/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -479,18 +479,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = 
image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1067,7 +1078,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/controlnet/pipeline_controlnet_inpaint.py b/pipelines/controlnet/pipeline_controlnet_inpaint.py index 72c2250dd5ac..71e237ce4e02 100644 --- a/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -25,7 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -597,18 +597,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + 
torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1284,7 +1295,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4696781dce0c..8c8399809228 100644 --- a/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return 
image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1169,7 +1180,10 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bf43c043490b..f7f4a16f0aa4 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -871,7 +882,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, 
output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e3a1a0ed3660..c80178152a6e 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -503,18 +503,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -923,7 +934,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3570eaa6fd3d..375197cc9e4d 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ 
b/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -574,18 +574,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1103,7 +1114,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 40c981a46d48..12d52aa076d4 100644 --- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, 
UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -524,18 +524,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1087,7 +1098,10 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 436d816e5eb3..729924ec2e20 100644 --- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -741,18 +741,29 @@ def prepare_latents( return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = 
next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def _get_add_time_ids( self, @@ -1259,7 +1270,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index f54b680dfd7c..7195b5f2521a 100644 --- a/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -462,18 +462,29 @@ def disable_vae_tiling(self): self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + 
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -1568,7 +1579,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device)
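
A minimal shape check for the new Resampler (a sketch with toy dimensions, not the real checkpoint sizes, assuming the class stays importable from diffusers.models.embeddings as added above): a (batch, sequence, embed_dims) stack of CLIP hidden states is resampled down to num_queries tokens of width output_dims, which is what the IP-Adapter attention processors consume as extra cross-attention context.

import torch
from diffusers.models.embeddings import Resampler

# Toy dimensions only; the real IP-Adapter Plus checkpoint infers these from the
# state dict shapes in _load_ip_adapter_weights.
resampler = Resampler(
    embed_dims=64, output_dims=32, hidden_dims=64, depth=2, dim_head=16, heads=4, num_queries=8
)
clip_hidden_states = torch.randn(2, 257, 64)  # (batch, sequence, embed_dims)
image_tokens = resampler(clip_hidden_states)
print(image_tokens.shape)  # torch.Size([2, 8, 32]) -> (batch, num_queries, output_dims)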
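
An end-to-end sketch of how the new code path might be exercised from a pipeline. This is illustrative only: the "h94/IP-Adapter" repository, its "models/image_encoder" subfolder, the "ip-adapter-plus_sd15.bin" weight name, and the example image URL are assumptions about the publicly hosted checkpoint layout and should be adjusted to the checkpoint actually in use. Because the image_proj state dict of an IP-Adapter Plus checkpoint contains a "latents" key, _load_ip_adapter_weights builds a Resampler, and the pipeline then calls encode_image with output_hidden_states=True since encoder_hid_proj is not an ImageProjection.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

# IP-Adapter Plus conditions on penultimate CLIP vision hidden states, so the pipeline
# needs an image encoder alongside the usual Stable Diffusion components.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter() routes through _load_ip_adapter_weights(); the "latents" entry in the
# image_proj state dict selects the Resampler branch added in this patch.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")

ip_image = load_image("https://example.com/reference.png")  # hypothetical reference image URL
result = pipe(
    prompt="best quality, high quality",
    ip_adapter_image=ip_image,
    num_inference_steps=50,
).images[0]
result.save("ip_adapter_plus_out.png")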