From e185084a5df12f6cf23ba455828808a3d8e3fda6 Mon Sep 17 00:00:00 2001 From: Levi McCallum Date: Mon, 4 Dec 2023 03:04:15 -0800 Subject: [PATCH 1/7] Add variant argument to dreambooth lora sdxl advanced (#6021) Co-authored-by: Sayak Paul --- .../train_dreambooth_lora_sdxl_advanced.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index abd169b8bc97..29fe2744ad7a 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -225,6 +225,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -1064,6 +1070,7 @@ def main(args): args.pretrained_model_name_or_path, torch_dtype=torch_dtype, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -1102,10 +1109,18 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + variant=args.variant, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + variant=args.variant, + use_fast=False, ) # import correct text encoder classes @@ -1119,10 +1134,10 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -1130,10 +1145,13 @@ def main(args): else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) if args.train_text_encoder_ti: @@ -1843,10 +1861,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # create pipeline if freeze_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( - 
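The first patch simply threads a new `--variant` CLI flag into every `from_pretrained` call in the advanced DreamBooth LoRA SDXL script. A minimal sketch of what the flag maps to at load time, with an illustrative model id and dtype (not taken from the patch):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Passing --variant fp16 to the training script is equivalent to loading with
# variant="fp16": from_pretrained then resolves the fp16 weight files
# (e.g. *.fp16.safetensors) instead of the full-precision ones.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative model id
    variant="fp16",
    torch_dtype=torch.float16,
)
```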
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1855,6 +1879,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_2=accelerator.unwrap_model(text_encoder_two), unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) @@ -1932,10 +1957,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it From 0a08d41961220887c97074dcd585e52bba9f6220 Mon Sep 17 00:00:00 2001 From: takuoko Date: Mon, 4 Dec 2023 20:43:34 +0900 Subject: [PATCH 2/7] [Feature] Support IP-Adapter Plus (#5915) * Support IP-Adapter Plus * fix format * restore before black format * restore before black format * generic * Refactor PerceiverAttention * format * fix test and refactor PerceiverAttention * generic encode_image * keep attention implementation * merge tests * encode_image backward compatible * code quality * fix controlnet inpaint pipeline * refactor FFN * refactor FFN --------- Co-authored-by: YiYi Xu --- src/diffusers/loaders/unet.py | 106 +++++++++++--- src/diffusers/models/__init__.py | 2 + src/diffusers/models/activations.py | 15 +- src/diffusers/models/attention.py | 12 +- src/diffusers/models/embeddings.py | 89 ++++++++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 28 +++- .../pipeline_alt_diffusion_img2img.py | 28 +++- .../animatediff/pipeline_animatediff.py | 28 +++- .../controlnet/pipeline_controlnet.py | 28 +++- .../controlnet/pipeline_controlnet_inpaint.py | 28 +++- .../controlnet/pipeline_controlnet_sd_xl.py | 28 +++- .../pipeline_stable_diffusion.py | 28 +++- .../pipeline_stable_diffusion_img2img.py | 28 +++- .../pipeline_stable_diffusion_inpaint.py | 28 +++- .../pipeline_stable_diffusion_xl.py | 28 +++- .../pipeline_stable_diffusion_xl_img2img.py | 28 +++- .../pipeline_stable_diffusion_xl_inpaint.py | 28 +++- tests/models/test_models_unet_2d_condition.py | 132 +++++++++++++++++- .../test_ip_adapter_stable_diffusion.py | 114 ++++++++++++++- 19 files changed, 683 insertions(+), 123 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 6c805672c9cd..9d559a4b4af8 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from collections import defaultdict +from collections import OrderedDict, defaultdict from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -21,7 +21,7 @@ import torch.nn.functional as F from torch import nn -from ..models.embeddings import ImageProjection +from ..models.embeddings import ImageProjection, Resampler from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta from ..utils import ( DIFFUSERS_CACHE, @@ -672,6 +672,17 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0, ) + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + num_image_text_embeds = 4 + else: + # IP-Adapter Plus + num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1] + + # Set encoder_hid_proj after loading ip_adapter weights, + # because `Resampler` also has `attn_processors`. + self.encoder_hid_proj = None + # set ip-adapter cross-attention processors & load state_dict attn_procs = {} key_id = 1 @@ -695,7 +706,10 @@ def _load_ip_adapter_weights(self, state_dict): IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) value_dict = {} @@ -708,26 +722,76 @@ def _load_ip_adapter_weights(self, state_dict): self.set_attn_processor(attn_procs) # create image projection layers. - clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] - cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + if "proj.weight" in state_dict["image_proj"]: + # IP-Adapter + clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1] + cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4 + + image_projection = ImageProjection( + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim, + num_image_text_embeds=num_image_text_embeds, + ) + image_projection.to(dtype=self.dtype, device=self.device) - image_projection = ImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4 - ) - image_projection.to(dtype=self.dtype, device=self.device) - - # load image projection layer weights - image_proj_state_dict = {} - image_proj_state_dict.update( - { - "image_embeds.weight": state_dict["image_proj"]["proj.weight"], - "image_embeds.bias": state_dict["image_proj"]["proj.bias"], - "norm.weight": state_dict["image_proj"]["norm.weight"], - "norm.bias": state_dict["image_proj"]["norm.bias"], - } - ) + # load image projection layer weights + image_proj_state_dict = {} + image_proj_state_dict.update( + { + "image_embeds.weight": state_dict["image_proj"]["proj.weight"], + "image_embeds.bias": state_dict["image_proj"]["proj.bias"], + "norm.weight": state_dict["image_proj"]["norm.weight"], + "norm.bias": state_dict["image_proj"]["norm.bias"], + } + ) + + image_projection.load_state_dict(image_proj_state_dict) + + else: + # IP-Adapter Plus + embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dims = state_dict["image_proj"]["latents"].shape[2] + heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64 + + image_projection = Resampler( + 
embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + num_queries=num_image_text_embeds, + ) + + image_proj_state_dict = state_dict["image_proj"] + + new_sd = OrderedDict() + for k, v in image_proj_state_dict.items(): + if "0.to" in k: + k = k.replace("0.to", "2.to") + elif "1.0.weight" in k: + k = k.replace("1.0.weight", "3.0.weight") + elif "1.0.bias" in k: + k = k.replace("1.0.bias", "3.0.bias") + elif "1.1.weight" in k: + k = k.replace("1.1.weight", "3.1.net.0.proj.weight") + elif "1.3.weight" in k: + k = k.replace("1.3.weight", "3.1.net.2.weight") + + if "norm1" in k: + new_sd[k.replace("0.norm1", "0")] = v + elif "norm2" in k: + new_sd[k.replace("0.norm2", "1")] = v + elif "to_kv" in k: + v_chunk = v.chunk(2, dim=0) + new_sd[k.replace("to_kv", "to_k")] = v_chunk[0] + new_sd[k.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in k: + new_sd[k.replace("to_out", "to_out.0")] = v + else: + new_sd[k] = v - image_projection.load_state_dict(image_proj_state_dict) + image_projection.load_state_dict(new_sd) + del image_proj_state_dict self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) self.config.encoder_hid_dim_type = "ip_image_proj" diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 1b76b4e03341..49ee3ee6af6b 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -34,6 +34,7 @@ _import_structure["controlnet"] = ["ControlNetModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["modeling_utils"] = ["ModelMixin"] + _import_structure["embeddings"] = ["ImageProjection"] _import_structure["prior_transformer"] = ["PriorTransformer"] _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] @@ -63,6 +64,7 @@ from .consistency_decoder_vae import ConsistencyDecoderVAE from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel + from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 8b75162ba597..47570eca8443 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -55,11 +55,12 @@ class GELU(nn.Module): dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: @@ -81,13 +82,14 @@ class GEGLU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
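The loader changes above branch on the checkpoint layout: a plain IP-Adapter state dict carries a small `proj.weight` projection and always yields 4 image tokens, while an IP-Adapter Plus state dict carries the Resampler's learned `latents`, whose second dimension gives the token count, plus a fused `to_kv` projection that has to be split for diffusers' `Attention` module. A minimal sketch of that logic, assuming the key names shown in the hunk:

```python
import torch

def describe_ip_adapter_checkpoint(state_dict):
    """Illustrative mirror of the branching in _load_ip_adapter_weights above."""
    if "proj.weight" in state_dict["image_proj"]:
        # plain IP-Adapter: a fixed 4 image tokens from a small linear projection
        return "ip-adapter", 4
    # IP-Adapter Plus: the token count is the number of learned latent queries
    return "ip-adapter-plus", state_dict["image_proj"]["latents"].shape[1]

# The Plus checkpoint fuses key and value projections into one to_kv weight; the
# loader chunks it in two because diffusers' Attention keeps to_k and to_v separate.
to_kv = torch.randn(2 * 768, 768)   # illustrative shape: (2 * inner_dim, hidden_dims)
to_k, to_v = to_kv.chunk(2, dim=0)  # each (inner_dim, hidden_dims)
```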
""" - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - self.proj = linear_cls(dim_in, dim_out * 2) + self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": @@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module): Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f02b5e249eee..08faaaf3e5bf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -501,6 +501,7 @@ class FeedForward(nn.Module): dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. """ def __init__( @@ -511,6 +512,7 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + bias: bool = True, ): super().__init__() inner_dim = int(dim * mult) @@ -518,13 +520,13 @@ def __init__( linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) + act_fn = GELU(dim, inner_dim, bias=bias) if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) + act_fn = GEGLU(dim, inner_dim, bias=bias) elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) # project in @@ -532,7 +534,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(linear_cls(inner_dim, dim_out)) + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a377ae267411..bdd2930d20f9 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .activations import get_activation +from .attention_processor import Attention from .lora import LoRACompatibleLinear @@ -790,3 +791,91 @@ def forward(self, caption, force_drop_ids=None): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + + +class Resampler(nn.Module): + """Resampler of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. Defaults to 768. 
+ output_dims (int): The number of output channels, that is the same + number of the channels in the + `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): The number of hidden channels. Defaults to 1280. + depth (int): The number of blocks. Defaults to 8. + dim_head (int): The number of head channels. Defaults to 64. + heads (int): Parallel attention heads. Defaults to 16. + num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 1024, + hidden_dims: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_queries: int = 8, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward # Lazy import to avoid circular import + + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) + + self.proj_in = nn.Linear(embed_dims, hidden_dims) + + self.proj_out = nn.Linear(hidden_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + nn.LayerNorm(hidden_dims), + nn.LayerNorm(hidden_dims), + Attention( + query_dim=hidden_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ), + nn.Sequential( + nn.LayerNorm(hidden_dims), + FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ), + ] + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + x (torch.Tensor): Input Tensor. + + Returns: + ------- + torch.Tensor: Output Tensor. + """ + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + for ln0, ln1, attn, ff in self.layers: + residual = latents + + encoder_hidden_states = ln0(x) + latents = ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + residual + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index b5c7aee4b4de..2121e9b81509 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -494,18 +494,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = 
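The new `Resampler` maps a sequence of CLIP vision hidden states to a fixed number of image tokens in the UNet's cross-attention width. A small shape check, assuming the class is importable from `diffusers.models.embeddings` as added above; the constructor values are illustrative, not taken from any particular checkpoint:

```python
import torch
from diffusers.models.embeddings import Resampler  # added in the hunk above

resampler = Resampler(
    embed_dims=1280,   # width of the incoming CLIP hidden states
    output_dims=768,   # unet.config.cross_attention_dim of the target UNet
    hidden_dims=768,
    depth=4,
    dim_head=64,
    heads=12,
    num_queries=16,
)
clip_hidden_states = torch.randn(2, 257, 1280)  # (batch, sequence, embed_dims)
image_tokens = resampler(clip_hidden_states)
print(image_tokens.shape)                       # torch.Size([2, 16, 768])
```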
image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -875,7 +886,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 4272fa124755..401e6aef82b1 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -505,18 +505,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, 
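Across the pipelines, the only substantive change to `encode_image` is that, for IP-Adapter Plus, it asks the CLIP vision encoder for its penultimate hidden states instead of the pooled `image_embeds`. A standalone sketch of the two calls; the image encoder repo and subfolder are an assumption here (they mirror the layout the IP-Adapter checkpoints ship with), and the input image is a dummy:

```python
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# assumed repo layout; swap in whichever CLIP vision encoder your checkpoint expects
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"
)
feature_extractor = CLIPImageProcessor()

pixel_values = feature_extractor(Image.new("RGB", (512, 512)), return_tensors="pt").pixel_values

with torch.no_grad():
    # plain IP-Adapter: pooled projection, shape (1, projection_dim)
    pooled_embeds = image_encoder(pixel_values).image_embeds
    # IP-Adapter Plus: penultimate hidden states, shape (1, num_patches + 1, hidden_size)
    hidden_states = image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
```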
uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -919,7 +930,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 28dc220545dc..32a08a0264bc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder from ...models.unet_motion_model import MotionAdapter from ...schedulers import ( @@ -320,18 +320,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): @@ 
-651,7 +662,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + ) if do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 1e19678b221d..bf6ef2125446 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -479,18 +479,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1067,7 +1078,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, 
num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 72c2250dd5ac..71e237ce4e02 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -25,7 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -597,18 +597,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1284,7 +1295,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4696781dce0c..8c8399809228 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py 
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -37,7 +37,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1169,7 +1180,10 @@ def __call__( # 3.2 Encode ip_adapter_image if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bf43c043490b..f7f4a16f0aa4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers 
import KarrasDiffusionSchedulers from ...utils import ( @@ -489,18 +489,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: @@ -871,7 +882,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e3a1a0ed3660..c80178152a6e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -503,18 +503,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, 
return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -923,7 +934,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3570eaa6fd3d..375197cc9e4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -574,18 +574,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if 
output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -1103,7 +1114,10 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 40c981a46d48..12d52aa076d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -31,7 +31,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -524,18 +524,29 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return 
image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -1087,7 +1098,10 @@ def __call__( add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 436d816e5eb3..729924ec2e20 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -32,7 +32,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -741,18 +741,29 @@ def prepare_latents( return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds def _get_add_time_ids( self, @@ -1259,7 +1270,10 @@ def denoising_value_valid(dnv): 
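All of the `__call__` hunks in these pipeline files make the same choice: if the UNet's `encoder_hid_proj` is a plain `ImageProjection`, use pooled embeddings; otherwise (the Resampler of IP-Adapter Plus) request hidden states. Condensed into a sketch, assuming `pipe`, `ip_adapter_image`, `device` and `num_images_per_prompt` are already in scope:

```python
import torch
from diffusers.models import ImageProjection

# ImageProjection means a plain IP-Adapter was loaded; anything else (the Resampler)
# means IP-Adapter Plus, which needs the hidden-state variant of encode_image.
output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
image_embeds, negative_image_embeds = pipe.encode_image(
    ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
if pipe.do_classifier_free_guidance:
    image_embeds = torch.cat([negative_image_embeds, image_embeds])
```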
add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index f54b680dfd7c..7195b5f2521a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -33,7 +33,7 @@ StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) -from ...models import AutoencoderKL, UNet2DConditionModel +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, @@ -462,18 +462,29 @@ def disable_vae_tiling(self): self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image - def encode_image(self, image, device, num_images_per_prompt): + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -1568,7 +1579,10 @@ def denoising_value_valid(dnv): add_time_ids = add_time_ids.to(device) if ip_adapter_image is not None: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = 
image_embeds.to(device) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 06bf2685560d..9ccd78f1fe47 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -18,6 +18,7 @@ import os import tempfile import unittest +from collections import OrderedDict import torch from parameterized import parameterized @@ -25,7 +26,7 @@ from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor -from diffusers.models.embeddings import ImageProjection +from diffusers.models.embeddings import ImageProjection, Resampler from diffusers.utils import logging from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( @@ -97,6 +98,85 @@ def create_ip_adapter_state_dict(model): return ip_state_dict +def create_ip_adapter_plus_state_dict(model): + # "ip_adapter" (cross-attention weights) + ip_cross_attn_state_dict = {} + key_id = 1 + + for name in model.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.config.block_out_channels[block_id] + if cross_attention_dim is not None: + sd = IPAdapterAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0 + ).state_dict() + ip_cross_attn_state_dict.update( + { + f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"], + f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"], + } + ) + + key_id += 2 + + # "image_proj" (ImageProjection layer weights) + cross_attention_dim = model.config["cross_attention_dim"] + image_projection = Resampler( + embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4 + ) + + ip_image_projection_state_dict = OrderedDict() + for k, v in image_projection.state_dict().items(): + if "2.to" in k: + k = k.replace("2.to", "0.to") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.0.bias" in k: + k = k.replace("3.0.bias", "1.0.bias") + elif "3.0.weight" in k: + k = k.replace("3.0.weight", "1.0.weight") + elif "3.1.net.0.proj.weight" in k: + k = k.replace("3.1.net.0.proj.weight", "1.1.weight") + elif "3.net.2.weight" in k: + k = k.replace("3.net.2.weight", "1.3.weight") + elif "layers.0.0" in k: + k = k.replace("layers.0.0", "layers.0.0.norm1") + elif "layers.0.1" in k: + k = k.replace("layers.0.1", "layers.0.0.norm2") + elif "layers.1.0" in k: + k = k.replace("layers.1.0", "layers.1.0.norm1") + elif "layers.1.1" in k: + k = k.replace("layers.1.1", "layers.1.0.norm2") + elif "layers.2.0" in k: + k = k.replace("layers.2.0", "layers.2.0.norm1") + elif "layers.2.1" in k: + k = k.replace("layers.2.1", "layers.2.0.norm2") + + if "norm_cross" in k: + ip_image_projection_state_dict[k.replace("norm_cross", "norm1")] = v + elif "layer_norm" in k: + ip_image_projection_state_dict[k.replace("layer_norm", "norm2")] = v + elif "to_k" in k: + ip_image_projection_state_dict[k.replace("to_k", "to_kv")] = torch.cat([v, v], dim=0) + elif "to_v" in k: + continue + elif "to_out.0" in k: + 
ip_image_projection_state_dict[k.replace("to_out.0", "to_out")] = v + else: + ip_image_projection_state_dict[k] = v + + ip_state_dict = {} + ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) + return ip_state_dict + + def create_custom_diffusion_layers(model, mock_weights: bool = True): train_kv = True train_q_out = True @@ -724,6 +804,56 @@ def test_ip_adapter(self): assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + def test_ip_adapter_plus(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + + # forward pass without ip-adapter + with torch.no_grad(): + sample1 = model(**inputs_dict).sample + + # update inputs_dict for ip-adapter + batch_size = inputs_dict["encoder_hidden_states"].shape[0] + image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device) + inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds} + + # make ip_adapter_1 and ip_adapter_2 + ip_adapter_1 = create_ip_adapter_plus_state_dict(model) + + image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()} + cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()} + ip_adapter_2 = {} + ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2}) + + # forward pass ip_adapter_1 + model._load_ip_adapter_weights(ip_adapter_1) + assert model.config.encoder_hid_dim_type == "ip_image_proj" + assert model.encoder_hid_proj is not None + assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in ( + "IPAdapterAttnProcessor", + "IPAdapterAttnProcessor2_0", + ) + with torch.no_grad(): + sample2 = model(**inputs_dict).sample + + # forward pass with ip_adapter_2 + model._load_ip_adapter_weights(ip_adapter_2) + with torch.no_grad(): + sample3 = model(**inputs_dict).sample + + # forward pass with ip_adapter_1 again + model._load_ip_adapter_weights(ip_adapter_1) + with torch.no_grad(): + sample4 = model(**inputs_dict).sample + + assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4) + assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4) + assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4) + @slow class UNet2DConditionModelIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 57eb49013c1f..7c6349ce2600 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -116,7 +116,17 @@ def test_text_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0]) + expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139]) assert 
np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -132,7 +142,17 @@ def test_image_to_image(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935]) + expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -148,7 +168,17 @@ def test_inpainting(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997]) + expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin") + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -173,7 +203,30 @@ def test_text_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923]) + expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs() + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -194,7 +247,31 @@ def test_image_to_image_sdxl(self): images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752]) + expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = 
StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_image_to_image=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + + expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) @@ -216,6 +293,31 @@ def test_inpainting_sdxl(self): image_slice = images[0, :3, :3, -1].flatten() image_slice.tolist() - expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516]) + expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") + + pipeline = StableDiffusionXLInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + feature_extractor=feature_extractor, + torch_dtype=self.dtype, + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter( + "h94/IP-Adapter", + subfolder="sdxl_models", + weight_name="ip-adapter-plus_sdxl_vit-h.bin", + ) + + inputs = self.get_dummy_inputs(for_inpainting=True) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + image_slice.tolist() + + expected_slice = np.array([0.1398, 0.1476, 0.1407, 0.1442, 0.1470, 0.1480, 0.1449, 0.1481, 0.1494]) assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) From c36f1c316018eceb183bc81f69d40e753813c383 Mon Sep 17 00:00:00 2001 From: RuoyiDu <61931443+RuoyiDu@users.noreply.github.com> Date: Mon, 4 Dec 2023 14:14:57 +0000 Subject: [PATCH 3/7] [Community Pipeline] DemoFusion: Democratising High-Resolution Image Generation With No $$$ (#6022) * Add files via upload * Update README.md * Update pipeline_demofusion_sdxl.py * Update pipeline_demofusion_sdxl.py * Update examples/community/README.md Co-authored-by: Sayak Paul --------- Co-authored-by: Sayak Paul --- examples/community/README.md | 81 +- .../community/pipeline_demofusion_sdxl.py | 1412 +++++++++++++++++ 2 files changed, 1492 insertions(+), 1 deletion(-) create mode 100644 examples/community/pipeline_demofusion_sdxl.py diff --git a/examples/community/README.md b/examples/community/README.md index 37b51c8c4139..9fad6ecbf690 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -50,7 +50,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap | Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | | Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | 
[hako-mikan](https://github.com/hako-mikan) | | LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) | -| +| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -2842,3 +2842,82 @@ The Pipeline supports `compel` syntax. Input prompts using the `compel` structur * ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13) * Reconstructed image: * ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f) + +### DemoFusion +This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973). +The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion). +- `view_batch_size` (`int`, defaults to 16): + The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements. + +- `stride` (`int`, defaults to 64): + The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time. + +- `cosine_scale_1` (`float`, defaults to 3): + Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper. + +- `cosine_scale_2` (`float`, defaults to 1): + Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper. + +- `cosine_scale_3` (`float`, defaults to 1): + Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper. + +- `sigma` (`float`, defaults to 1): + The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing. + +- `multi_decoder` (`bool`, defaults to True): + Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary. + +- `show_image` (`bool`, defaults to False): + Determine whether to show intermediate results during generation. +``` +from pipeline_demofusion_sdxl import DemoFusionSDXLPipeline + +model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0" +pipe = DemoFusionSDXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) +pipe = pipe.to("cuda") + +prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified." 
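+# Optional, not part of the original example: resolutions such as 3072x3072 are VRAM-heavy.
+# A smaller `view_batch_size` and/or tiled VAE decoding can lower peak memory usage.
+# pipe.enable_vae_tiling()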
+negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic" + +images = pipe( + prompt, + negative_prompt=negative_prompt, + height=3072, + width=3072, + view_batch_size=16, + stride=64, + num_inference_steps=50, + guidance_scale=7.5, + cosine_scale_1=3, + cosine_scale_2=1, + cosine_scale_3=1, + sigma=0.8, + multi_decoder=True, + show_image=True +) +``` +You can display and save the generated images as: +``` +def image_grid(imgs, save_path=None): + + w = 0 + for i, img in enumerate(imgs): + h_, w_ = imgs[i].size + w += w_ + h = h_ + grid = Image.new('RGB', size=(w, h)) + grid_w, grid_h = grid.size + + w = 0 + for i, img in enumerate(imgs): + h_, w_ = imgs[i].size + grid.paste(img, box=(w, h - h_)) + if save_path != None: + img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024)) + w += w_ + + return grid + +image_grid(images, save_path="./outputs/") +``` + ![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png) diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py new file mode 100644 index 000000000000..5a81320219a5 --- /dev/null +++ b/examples/community/pipeline_demofusion_sdxl.py @@ -0,0 +1,1412 @@ +import inspect +import os +import random +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + is_invisible_watermark_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... 
) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3): + x_coord = torch.arange(kernel_size) + gaussian_1d = torch.exp(-((x_coord - (kernel_size - 1) / 2) ** 2) / (2 * sigma**2)) + gaussian_1d = gaussian_1d / gaussian_1d.sum() + gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :] + kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1) + + return kernel + + +def gaussian_filter(latents, kernel_size=3, sigma=1.0): + channels = latents.shape[1] + kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype) + blurred_latents = F.conv2d(latents, kernel, padding=kernel_size // 2, groups=channels) + + return blurred_latents + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class DemoFusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
+ tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + num_images_per_prompt=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # DemoFusion specific checks + if max(height, width) % 1024 != 0: + raise ValueError( + f"the larger one of `height` and `width` has to be divisible by 1024 but are {height} and {width}." + ) + + if num_images_per_prompt != 1: + warnings.warn("num_images_per_prompt != 1 is not supported by DemoFusion and will be ignored.") + num_images_per_prompt = 1 + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def get_views(self, height, width, window_size=128, stride=64, random_jitter=False): + height //= self.vae_scale_factor + width //= self.vae_scale_factor + num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1 + num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + + if h_end > height: + h_start = int(h_start + height - h_end) + h_end = int(height) + if w_end > width: + w_start = int(w_start + width - w_end) + w_end = int(width) + if h_start < 0: + h_end = int(h_end - h_start) + h_start = 0 + if w_start < 0: + w_end = int(w_end - w_start) + w_start = 0 + + if random_jitter: + jitter_range = (window_size - stride) // 4 + w_jitter = 0 + h_jitter = 0 + if (w_start != 0) and (w_end != width): + w_jitter = random.randint(-jitter_range, jitter_range) + elif (w_start == 0) and (w_end != width): + w_jitter = random.randint(-jitter_range, 0) + elif (w_start != 0) and (w_end == width): + w_jitter = random.randint(0, jitter_range) + if (h_start != 0) and (h_end != height): + h_jitter = random.randint(-jitter_range, jitter_range) + elif (h_start == 0) and (h_end != height): + h_jitter = random.randint(-jitter_range, 0) + elif (h_start != 0) and (h_end == height): + h_jitter = random.randint(0, jitter_range) + h_start += h_jitter + jitter_range + h_end += h_jitter + jitter_range + w_start += w_jitter + jitter_range + w_end += w_jitter + jitter_range + + views.append((h_start, h_end, w_start, w_end)) + return views + + def tiled_decode(self, latents, current_height, current_width): + core_size = self.unet.config.sample_size // 4 + core_stride = core_size + pad_size = self.unet.config.sample_size // 4 * 3 + decoder_view_batch_size = 1 + + views = self.get_views(current_height, current_width, stride=core_stride, window_size=core_size) + views_batch = [views[i : i + decoder_view_batch_size] for i in range(0, len(views), decoder_view_batch_size)] + latents_ = F.pad(latents, (pad_size, pad_size, pad_size, pad_size), "constant", 0) + image = torch.zeros(latents.size(0), 3, current_height, current_width).to(latents.device) + count = torch.zeros_like(image).to(latents.device) + # get the latents corresponding to the current view coordinates + with self.progress_bar(total=len(views_batch)) as progress_bar: + for j, batch_view in enumerate(views_batch): + len(batch_view) + latents_for_view = torch.cat( + [ + latents_[:, :, h_start : h_end + pad_size * 2, w_start : w_end + pad_size * 2] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + image_patch = self.vae.decode(latents_for_view / self.vae.config.scaling_factor, return_dict=False)[0] + h_start, h_end, w_start, w_end = views[j] + h_start, h_end, w_start, w_end = ( + h_start * self.vae_scale_factor, + h_end * self.vae_scale_factor, + w_start * self.vae_scale_factor, + w_end * self.vae_scale_factor, + ) + p_h_start, p_h_end, p_w_start, p_w_end = ( + pad_size * self.vae_scale_factor, + image_patch.size(2) - pad_size * self.vae_scale_factor, + pad_size * self.vae_scale_factor, + image_patch.size(3) - pad_size * self.vae_scale_factor, + ) + image[:, :, h_start:h_end, w_start:w_end] += image_patch[:, :, 
p_h_start:p_h_end, p_w_start:p_w_end] + count[:, :, h_start:h_end, w_start:w_end] += 1 + progress_bar.update() + image = image / count + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = False, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + ################### DemoFusion specific parameters #################### + view_batch_size: int = 16, + multi_decoder: bool = True, + stride: Optional[int] = 64, + cosine_scale_1: Optional[float] = 3.0, + cosine_scale_2: Optional[float] = 1.0, + cosine_scale_3: Optional[float] = 1.0, + sigma: Optional[float] = 0.8, + show_image: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + ################### DemoFusion specific parameters #################### + view_batch_size (`int`, defaults to 16): + The batch size for multiple denoising paths. Typically, a larger batch size can result in higher + efficiency but comes with increased GPU memory requirements. + multi_decoder (`bool`, defaults to True): + Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, + a tiled decoder becomes necessary. + stride (`int`, defaults to 64): + The stride of moving local patches. A smaller stride is better for alleviating seam issues, + but it also introduces additional computational overhead and inference time. + cosine_scale_1 (`float`, defaults to 3): + Control the strength of skip-residual. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + cosine_scale_2 (`float`, defaults to 1): + Control the strength of dilated sampling. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + cosine_scale_3 (`float`, defaults to 1): + Control the strength of the gaussion filter. For specific impacts, please refer to Appendix C + in the DemoFusion paper. + sigma (`float`, defaults to 1): + The standerd value of the gaussian filter. + show_image (`bool`, defaults to False): + Determine whether to show intermediate results during generation. + + Examples: + + Returns: + a `list` with the generated images at each phase. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + x1_size = self.default_sample_size * self.vae_scale_factor + + height_scale = height / x1_size + width_scale = width / x1_size + scale_num = int(max(height_scale, width_scale)) + aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale) + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + num_images_per_prompt, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height // scale_num, + width // scale_num, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + output_images = [] + + ############################################################### Phase 1 ################################################################# + + print("### Phase 1 Denoising ###") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents_for_view = latents + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat_interleave(2, dim=0) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + anchor_mean = latents.mean() + anchor_std = latents.std() + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + print("### Phase 1 Decoding ###") + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + image = self.image_processor.postprocess(image, output_type=output_type) + if show_image: + plt.figure(figsize=(10, 10)) + plt.imshow(image[0]) + plt.axis("off") # Turn off axis numbers and ticks + plt.show() + output_images.append(image[0]) + + ####################################################### Phase 2+ ##################################################### + + for current_scale_num in range(2, scale_num + 1): + print("### Phase {} Denoising ###".format(current_scale_num)) + current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num + current_width = 
self.unet.config.sample_size * self.vae_scale_factor * current_scale_num + if height > width: + current_width = int(current_width * aspect_ratio) + else: + current_height = int(current_height * aspect_ratio) + + latents = F.interpolate( + latents, + size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)), + mode="bicubic", + ) + + noise_latents = [] + noise = torch.randn_like(latents) + for timestep in timesteps: + noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0)) + noise_latents.append(noise_latent) + latents = noise_latents[0] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + cosine_factor = ( + 0.5 + * ( + 1 + + torch.cos( + torch.pi + * (self.scheduler.config.num_train_timesteps - t) + / self.scheduler.config.num_train_timesteps + ) + ).cpu() + ) + + c1 = cosine_factor**cosine_scale_1 + latents = latents * (1 - c1) + noise_latents[i] * c1 + + ############################################# MultiDiffusion ############################################# + + views = self.get_views( + current_height, + current_width, + stride=stride, + window_size=self.unet.config.sample_size, + random_jitter=True, + ) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + + jitter_range = (self.unet.config.sample_size - stride) // 4 + latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), "constant", 0) + + count_local = torch.zeros_like(latents_) + value_local = torch.zeros_like(latents_) + + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) + + # get the latents corresponding to the current view coordinates + latents_for_view = torch.cat( + [ + latents_[:, :, h_start:h_end, w_start:w_end] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents_for_view + latent_model_input = ( + latent_model_input.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latent_model_input + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + add_text_embeds_input = torch.cat([add_text_embeds] * vb_size) + add_time_ids_input = [] + for h_start, h_end, w_start, w_end in batch_view: + add_time_ids_ = add_time_ids.clone() + add_time_ids_[:, 2] = h_start * self.vae_scale_factor + add_time_ids_[:, 3] = w_start * self.vae_scale_factor + add_time_ids_input.append(add_time_ids_) + add_time_ids_input = torch.cat(add_time_ids_input) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + self.scheduler._init_step_index(t) + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False + )[0] + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count_local[:, :, h_start:h_end, w_start:w_end] += 1 + + value_local = value_local[ + :, + :, + jitter_range : jitter_range + current_height // self.vae_scale_factor, + jitter_range : jitter_range + current_width // self.vae_scale_factor, + ] + count_local = count_local[ + :, + :, + jitter_range : jitter_range + current_height // self.vae_scale_factor, + jitter_range : jitter_range + current_width // self.vae_scale_factor, + ] + + c2 = cosine_factor**cosine_scale_2 + + value += value_local / count_local * (1 - c2) + count += torch.ones_like(value_local) * (1 - c2) + + ############################################# Dilated Sampling ############################################# + + views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)] + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + + h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num + w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num + latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), "constant", 0) + + count_global = torch.zeros_like(latents_) + value_global = torch.zeros_like(latents_) + + c3 = 0.99 * cosine_factor**cosine_scale_3 + 1e-2 + std_, mean_ = latents_.std(), latents_.mean() + latents_gaussian = gaussian_filter( + latents_, kernel_size=(2 * current_scale_num - 1), sigma=sigma * c3 + ) + latents_gaussian = ( + latents_gaussian - latents_gaussian.mean() + ) / latents_gaussian.std() * std_ + mean_ + + for j, batch_view in enumerate(views_batch): + latents_for_view = torch.cat( + [latents_[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view] + ) + latents_for_view_gaussian = torch.cat( + [latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view] + ) + + vb_size = latents_for_view.size(0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = latents_for_view_gaussian + latent_model_input = ( + latent_model_input.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latent_model_input + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + add_text_embeds_input = torch.cat([add_text_embeds] * vb_size) + add_time_ids_input = torch.cat([add_time_ids] * vb_size) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if 
do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + self.scheduler._init_step_index(t) + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False + )[0] + + # extract value from batch + for latents_view_denoised, (h, w) in zip(latents_denoised_batch.chunk(vb_size), batch_view): + value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised + count_global[:, :, h::current_scale_num, w::current_scale_num] += 1 + + c2 = cosine_factor**cosine_scale_2 + + value_global = value_global[:, :, h_pad:, w_pad:] + + value += value_global * c2 + count += torch.ones_like(value_global) * c2 + + ########################################################### + + latents = torch.where(count > 0, value / count, value) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + ######################################################################################################################################### + + latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + print("### Phase {} Decoding ###".format(current_scale_num)) + if multi_decoder: + image = self.tiled_decode(latents, current_height, current_width) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + if show_image: + plt.figure(figsize=(10, 10)) + plt.imshow(image[0]) + plt.axis("off") # Turn off axis numbers and ticks + plt.show() + output_images.append(image[0]) + + # Offload all models + self.maybe_free_model_hooks() + + return output_images + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + # Remove any existing hooks. 
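The `__call__` implementation above exposes several DemoFusion-specific arguments (`view_batch_size`, `stride`, the cosine scales, `sigma`, `multi_decoder`). For reference, a minimal usage sketch follows; it is illustrative only and not part of this patch. The checkpoint id and the `custom_pipeline` identifier are assumptions about how this community pipeline file is registered, and the argument values simply mirror the documented defaults.

```py
# Illustrative sketch only; the custom_pipeline name and checkpoint below are assumptions.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    custom_pipeline="pipeline_demofusion_sdxl",  # assumed registration name of this file
    torch_dtype=torch.float16,
).to("cuda")

# Requesting 3072x3072 with a native 1024 base gives scale_num = 3, so the returned list
# holds one image per phase (1024, 2048 and 3072).
images = pipe(
    "a photo of an astronaut riding a horse on mars",
    height=3072,
    width=3072,
    num_inference_steps=50,
    view_batch_size=16,  # more parallel denoising paths, higher VRAM use
    multi_decoder=True,  # tiled VAE decoding, needed at very high resolutions
    stride=64,           # smaller stride reduces seams but costs compute
    cosine_scale_1=3,    # skip-residual strength
    cosine_scale_2=1,    # dilated sampling strength
    cosine_scale_3=1,    # Gaussian filter strength
    sigma=1,
    show_image=False,
)
final_image = images[-1]
```

Phase 1 runs plain SDXL denoising at the base resolution; the later phases add the MultiDiffusion and dilated-sampling passes implemented earlier in this file.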
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." 
+ ) + + if unet_lora_layers: + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) From 880c0fdd365f3c91c2d65cbbf97df7d2ab98bd92 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:38:44 +0200 Subject: [PATCH 4/7] [advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040) * improve help tags * style fix * changes token_abstraction type to string. support multiple concepts for pivotal using a comma separated string. * style fixup * changed logger to warning (not yet available) * moved the token_abstraction parsing to be in the same block as where we create the mapping of identifier to token --------- Co-authored-by: Linoy --- .../train_dreambooth_lora_sdxl_advanced.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 29fe2744ad7a..2bf3cc8f7c9c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -300,16 +300,18 @@ def parse_args(input_args=None): ) parser.add_argument( "--token_abstraction", + type=str, default="TOK", help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " - "captions - e.g. TOK", + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. " + "'TOK,TOK2,TOK3' etc.", ) parser.add_argument( "--num_new_tokens_per_abstraction", type=int, default=2, - help="number of new tokens inserted to the tokenizers per token_abstraction value when " + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " "tokens - ", ) @@ -660,17 +662,6 @@ def parse_args(input_args=None): "inversion training check `--train_text_encoder_ti`" ) - if args.train_text_encoder_ti: - if isinstance(args.token_abstraction, str): - args.token_abstraction = [args.token_abstraction] - elif isinstance(args.token_abstraction, List): - args.token_abstraction = args.token_abstraction - else: - raise ValueError( - f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. " - f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)" - ) - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -1155,9 +1146,14 @@ def main(args): ) if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. 
- "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + logger.info(f"list of token identifiers: {token_abstraction_list}") + token_abstraction_dict = {} token_idx = 0 - for i, token in enumerate(args.token_abstraction): + for i, token in enumerate(token_abstraction_list): token_abstraction_dict[token] = [ f"" for j in range(args.num_new_tokens_per_abstraction) ] From b64f835ea73b2a25bf81e31a91fdd1925669c290 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 4 Dec 2023 10:11:15 -0800 Subject: [PATCH 5/7] [docs] Add Kandinsky 3 (#5988) * add * fix api docs * edits --- docs/source/en/using-diffusers/kandinsky.md | 45 +++++++++++++++++++ .../kandinsky3/pipeline_kandinsky3.py | 4 +- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/kandinsky.md b/docs/source/en/using-diffusers/kandinsky.md index 05be2e1ee289..0fbec32a5296 100644 --- a/docs/source/en/using-diffusers/kandinsky.md +++ b/docs/source/en/using-diffusers/kandinsky.md @@ -20,6 +20,8 @@ The Kandinsky models are a series of multilingual text-to-image generation model [Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes. +[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet. + This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more. Before you begin, make sure you have the following libraries installed: @@ -33,6 +35,10 @@ Before you begin, make sure you have the following libraries installed: Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding. +
+ +Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl). + ## Text-to-image @@ -91,6 +97,23 @@ image + + + +Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image: + +```py +from diffusers import Kandinsky3Pipeline +import torch + +pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16) +pipeline.enable_model_cpu_offload() + +prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting" +image = pipeline(prompt).images[0] +image +``` + @@ -161,6 +184,20 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda") ``` + + + +Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline: + +```py +from diffusers import Kandinsky3Img2ImgPipeline +from diffusers.utils import load_image +import torch + +pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16) +pipeline.enable_model_cpu_offload() +``` + @@ -218,6 +255,14 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r + + + +```py +image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0] +image +``` + diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py index 4d14fc637b05..fcf7ddcb9966 100644 --- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py @@ -110,7 +110,7 @@ def encode_prompt( Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on @@ -365,7 +365,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - num_inference_steps (`int`, *optional*, defaults to 50): + num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
timesteps (`List[int]`, *optional*): From 4684ea2fe8d568f44c491068c3eb94aac27045f3 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 4 Dec 2023 10:12:52 -0800 Subject: [PATCH 6/7] [docs] `#Copied from` mechanism (#6007) * copied from section * feedback --- docs/source/en/conceptual/contribution.md | 34 ++++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/source/en/conceptual/contribution.md b/docs/source/en/conceptual/contribution.md index dc942a24c42e..d2b45cac7362 100644 --- a/docs/source/en/conceptual/contribution.md +++ b/docs/source/en/conceptual/contribution.md @@ -297,17 +297,37 @@ if you don't know yet what specific component you would like to add: - [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) - [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) -Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that -we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy -as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please -open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design -pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us. +Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy +as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us. -Please make sure to add links to the original codebase/paper to the PR and ideally also ping the -original author directly on the PR so that they can follow the progress and potentially help with questions. +Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions. If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help. +#### Copied from mechanism + +A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. 
This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`. + +For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`. + +```py +# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`List[bool]`) + List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ +``` + +To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post. + ## How to write a good issue **The better your issue is written, the higher the chances that it will be quickly resolved.** From f9487783228cd500a21555da3346db40e8f05992 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 5 Dec 2023 15:12:37 +0530 Subject: [PATCH 7/7] Move kandinsky convert script (#6047) move kandinsky convert script --- .../diffusers/pipelines/kandinsky3}/convert_kandinsky3_unet.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {tests => src/diffusers/pipelines/kandinsky3}/convert_kandinsky3_unet.py (100%) diff --git a/tests/convert_kandinsky3_unet.py b/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py similarity index 100% rename from tests/convert_kandinsky3_unet.py rename to src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
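A closing note on the `# Copied from` section above: the `with Old->New` suffix is what keeps the copy mechanical. The snippet below is a deliberately simplified sketch of that idea, not the repository's actual consistency script; it only shows how the expected target block can be derived from the source block by plain string replacement before `make fix-copies` rewrites anything that has drifted.

```py
# Simplified sketch of the `# Copied from ... with Old->New` idea; illustrative only.
def expected_copy(source_block: str, replacement_spec: str = "") -> str:
    """Derive the expected copied block, e.g. replacement_spec='Stable->Alt'."""
    expected = source_block
    if replacement_spec:
        for rule in replacement_spec.split(","):
            old, new = (part.strip() for part in rule.split("->"))
            expected = expected.replace(old, new)
    return expected


source = "class StableDiffusionPipelineOutput(BaseOutput):"
print(expected_copy(source, "Stable->Alt"))
# -> class AltDiffusionPipelineOutput(BaseOutput):
```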