IP adapter support for most pipelines #5900

Merged: 23 commits into main from ip-adapter-txt2img, Dec 10, 2023
Changes from all commits (23):
aed5ff2  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 21, 2023)
5b546d7  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 22, 2023)
5ceb2fb  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 22, 2023)
da5d81d  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 23, 2023)
2538820  update tests (a-r-r-o-w, Nov 23, 2023)
960a1b8  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 23, 2023)
0223705  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 25, 2023)
18d3d2d  support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/p… (a-r-r-o-w, Nov 25, 2023)
6a5d5e7  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 25, 2023)
f3a76d1  support ip-adapter in src/diffusers/pipelines/latent_consistency_mode… (a-r-r-o-w, Nov 25, 2023)
bf62310  support ip-adapter in src/diffusers/pipelines/latent_consistency_mode… (a-r-r-o-w, Nov 25, 2023)
a606916  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 27, 2023)
a185099  revert changes to sd_attend_and_excite and sd_upscale (a-r-r-o-w, Nov 27, 2023)
368b21d  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 28, 2023)
a820d4f  make style (a-r-r-o-w, Nov 28, 2023)
f35726f  fix broken tests (a-r-r-o-w, Nov 28, 2023)
ff602e1  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 28, 2023)
3b405f7  Merge branch 'main' into ip-adapter-txt2img (yiyixuxu, Dec 7, 2023)
3ac975d  update ip-adapter implementation to latest (a-r-r-o-w, Dec 7, 2023)
b915ed2  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Dec 7, 2023)
23e94f8  Merge branch 'main' into ip-adapter-txt2img (sayakpaul, Dec 8, 2023)
a3ac5ce  apply suggestions from review (a-r-r-o-w, Dec 9, 2023)
0d64ce3  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Dec 9, 2023)
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -20,11 +20,11 @@
 
 import PIL.Image
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -129,7 +129,7 @@ def retrieve_timesteps(
 
 
 class LatentConsistencyModelImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using a latent consistency model.
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -166,7 +167,7 @@ class LatentConsistencyModelImg2ImgPipeline(
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
 
@@ -179,6 +180,7 @@ def __init__(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -191,6 +193,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
 
         if safety_checker is None and requires_safety_checker:
@@ -449,6 +452,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -647,6 +675,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -695,6 +724,8 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -758,6 +789,12 @@ def __call__(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
 
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+
         # 3. Encode input prompt
         lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -815,6 +852,9 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
         # 8. LCM Multistep Sampling Loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -829,6 +869,7 @@ def __call__(
                 timestep_cond=w_embedding,
                 encoder_hidden_states=prompt_embeds,
                 cross_attention_kwargs=self.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
                 return_dict=False,
             )[0]
 
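The same pattern applies to the image-to-image pipeline modified above. A sketch under the same assumptions (checkpoint and image names are placeholders), where the IP-Adapter image supplies style conditioning while `image` provides the img2img starting point:

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Resolves the LCM checkpoint to LatentConsistencyModelImg2ImgPipeline
pipe = AutoPipelineForImage2Image.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

init_image = load_image("init.png")    # starting image for img2img
style_image = load_image("style.png")  # image prompt for the IP-Adapter

image = pipe(
    prompt="best quality, high quality",
    image=init_image,
    ip_adapter_image=style_image,
    strength=0.6,
    num_inference_steps=4,
).images[0]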
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -19,11 +19,11 @@
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -107,7 +107,7 @@ def retrieve_timesteps(
 
 
 class LatentConsistencyModelPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using a latent consistency model.
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline(
     """
 
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
 
@@ -157,6 +158,7 @@ def __init__(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -185,6 +187,7 @@ def __init__(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -433,6 +436,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -581,6 +609,7 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,6 +658,8 @@ def __call__(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +728,12 @@ def __call__(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
 
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -748,6 +785,9 @@ def __call__(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+
         # 8. LCM MultiStep Sampling Loop:
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -762,6 +802,7 @@ def __call__(
                 timestep_cond=w_embedding,
                 encoder_hidden_states=prompt_embeds,
                 cross_attention_kwargs=self.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
                 return_dict=False,
             )[0]
 
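One detail worth noting in both `__call__` bodies: the `output_hidden_state` flag is derived from the type of `unet.encoder_hid_proj`. A plain IP-Adapter installs an `ImageProjection` there and consumes the pooled CLIP `image_embeds`, while the "plus" variants use a resampler and consume penultimate hidden states, which is why `encode_image` returns `hidden_states[-2]` in that branch. A condensed restatement of the dispatch (illustrative only; `ip_adapter_embeds` is a hypothetical helper, not part of the diff):

from diffusers.models import ImageProjection

def ip_adapter_embeds(pipe, ip_adapter_image, device, num_images_per_prompt):
    # ImageProjection -> pooled CLIP embeds; anything else -> penultimate hidden states
    use_hidden_states = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)
    return pipe.encode_image(ip_adapter_image, device, num_images_per_prompt, use_hidden_states)

Since the LCM pipelines run without classifier-free guidance (`do_classifier_free_guidance` is commented out above), only the positive `image_embeds` end up in `added_cond_kwargs`; the `negative_image_embeds` returned by `encode_image` go unused here.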