Skip to content

Commit

Permalink
[Kolors] Add IP Adapter (#8901)
Browse files Browse the repository at this point in the history
* initial draft

* apply suggestions

* fix failing test

* added ipa to img2img

* add docs

* apply suggestions
  • Loading branch information
asomoza authored Jul 26, 2024
1 parent ca0747a commit 73acebb
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 16 deletions.
58 changes: 58 additions & 0 deletions docs/source/en/api/pipelines/kolors.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,64 @@ image = pipe(
image.save("kolors_sample.png")
```

### IP Adapter

Kolors needs a different IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as an image encoder.

<Tip>

Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs.

</Tip>

<Tip>

While Kolors is integrated in Diffusers, you need to load the image encoder from a revision to use the safetensor files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.

</Tip>

```python
import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"Kwai-Kolors/Kolors-IP-Adapter-Plus",
subfolder="image_encoder",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
revision="refs/pr/4",
)

pipe = KolorsPipeline.from_pretrained(
"Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)

pipe.load_ip_adapter(
"Kwai-Kolors/Kolors-IP-Adapter-Plus",
subfolder="",
weight_name="ip_adapter_plus_general.safetensors",
revision="refs/pr/4",
image_encoder_folder=None,
)
pipe.enable_model_cpu_offload()

ipa_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/cat_square.png")

image = pipe(
prompt="best quality, high quality",
negative_prompt="",
guidance_scale=6.5,
num_inference_steps=25,
ip_adapter_image=ipa_image,
).images[0]

image.save("kolors_ipa_sample.png")
```

## KolorsPipeline

[[autodoc]] KolorsPipeline
Expand Down
11 changes: 9 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def load_ip_adapter(

# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
feature_extractor = CLIPImageProcessor()
clip_image_size = self.image_encoder.config.image_size
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)

# load ip-adapter into unet
Expand Down Expand Up @@ -319,7 +320,13 @@ def unload_ip_adapter(self):

# remove hidden encoder
self.unet.encoder_hid_proj = None
self.config.encoder_hid_dim_type = None
self.unet.config.encoder_hid_dim_type = None

# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
self.unet.text_encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = "text_proj"

# restore original Unet attention processors layers
attn_procs = {}
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]

# Kolors Unet already has a `encoder_hid_proj`
if (
self.encoder_hid_proj is not None
and self.config.encoder_hid_dim_type == "text_proj"
and not hasattr(self, "text_encoder_hid_proj")
):
self.text_encoder_hid_proj = self.encoder_hid_proj

# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,10 @@ def process_encoder_hidden_states(
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)

if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)

image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
encoder_hidden_states = (encoder_hidden_states, image_embeds)
Expand Down
146 changes: 140 additions & 6 deletions src/diffusers/pipelines/kolors/pipeline_kolors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
Expand Down Expand Up @@ -120,7 +121,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for text-to-image generation using Kolors.
Expand All @@ -130,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Expand All @@ -148,7 +150,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
`Kwai-Kolors/Kolors-diffusers`.
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = [
"image_encoder",
"feature_extractor",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
Expand All @@ -166,11 +172,21 @@ def __init__(
tokenizer: ChatGLMTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()

self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
Expand Down Expand Up @@ -343,6 +359,77 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)

image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)

return ip_adapter_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
Expand All @@ -364,16 +451,25 @@ def prepare_extra_step_kwargs(self, generator, eta):
def check_inputs(
self,
prompt,
num_inference_steps,
height,
width,
negative_prompt=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
negative_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
raise ValueError(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)

if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Expand Down Expand Up @@ -420,6 +516,21 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

if max_sequence_length is not None and max_sequence_length > 256:
raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")

Expand Down Expand Up @@ -563,6 +674,8 @@ def __call__(
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -649,6 +762,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
Expand Down Expand Up @@ -719,13 +838,16 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
num_inference_steps,
height,
width,
negative_prompt,
prompt_embeds,
pooled_prompt_embeds,
negative_prompt_embeds,
negative_pooled_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
Expand Down Expand Up @@ -815,6 +937,15 @@ def __call__(
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)

# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

Expand Down Expand Up @@ -856,6 +987,9 @@ def __call__(
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
added_cond_kwargs["image_embeds"] = image_embeds

noise_pred = self.unet(
latent_model_input,
t,
Expand Down
Loading

0 comments on commit 73acebb

Please sign in to comment.