IP adapter support for most pipelines #5900

Merged Dec 10, 2023 (23 commits; the diff below reflects changes from 17 of them, for three of the changed files)

Commits
aed5ff2  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 21, 2023)
5b546d7  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 22, 2023)
5ceb2fb  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 22, 2023)
da5d81d  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 23, 2023)
2538820  update tests (a-r-r-o-w, Nov 23, 2023)
960a1b8  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 23, 2023)
0223705  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 25, 2023)
18d3d2d  support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/p… (a-r-r-o-w, Nov 25, 2023)
6a5d5e7  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 25, 2023)
f3a76d1  support ip-adapter in src/diffusers/pipelines/latent_consistency_mode… (a-r-r-o-w, Nov 25, 2023)
bf62310  support ip-adapter in src/diffusers/pipelines/latent_consistency_mode… (a-r-r-o-w, Nov 25, 2023)
a606916  support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeli… (a-r-r-o-w, Nov 27, 2023)
a185099  revert changes to sd_attend_and_excite and sd_upscale (a-r-r-o-w, Nov 27, 2023)
368b21d  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 28, 2023)
a820d4f  make style (a-r-r-o-w, Nov 28, 2023)
f35726f  fix broken tests (a-r-r-o-w, Nov 28, 2023)
ff602e1  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Nov 28, 2023)
3b405f7  Merge branch 'main' into ip-adapter-txt2img (yiyixuxu, Dec 7, 2023)
3ac975d  update ip-adapter implementation to latest (a-r-r-o-w, Dec 7, 2023)
b915ed2  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Dec 7, 2023)
23e94f8  Merge branch 'main' into ip-adapter-txt2img (sayakpaul, Dec 8, 2023)
a3ac5ce  apply suggestions from review (a-r-r-o-w, Dec 9, 2023)
0d64ce3  Merge branch 'main' into ip-adapter-txt2img (a-r-r-o-w, Dec 9, 2023)
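Before the per-file diffs, a minimal usage sketch of the surface this PR adds: `IPAdapterMixin.load_ip_adapter` on an LCM pipeline, plus the new `ip_adapter_image` argument to `__call__`. The checkpoint repo ids, subfolder, weight name, and URLs below are illustrative assumptions, not pinned down by this diff:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Any LCM checkpoint that resolves to LatentConsistencyModelPipeline; this repo id is an assumption.
pipe = DiffusionPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

# New for LCM pipelines in this PR: IPAdapterMixin provides load_ip_adapter.
# Repo/subfolder/weight names follow the h94/IP-Adapter layout (assumed here).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

reference = load_image("https://example.com/reference.png")  # placeholder URL

image = pipe(
    prompt="a photo of a fluffy cat, best quality",
    ip_adapter_image=reference,  # the new argument wired through __call__ below
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
image.save("lcm_ip_adapter.png")
```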
--- Diff 1/3: LatentConsistencyModelImg2ImgPipeline ---

@@ -20,10 +20,10 @@

import PIL.Image
import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
@@ -129,7 +129,7 @@ def retrieve_timesteps(


class LatentConsistencyModelImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for image-to-image generation using a latent consistency model.
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

Args:
vae ([`AutoencoderKL`]):
@@ -166,7 +167,7 @@ class LatentConsistencyModelImg2ImgPipeline(
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

@@ -179,6 +180,7 @@ def __init__(
scheduler: LCMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -191,6 +193,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
)

if safety_checker is None and requires_safety_checker:
@@ -449,6 +452,20 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds
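As orientation (not part of the diff): `encode_image` produces the pooled, projected CLIP image embedding, paired with an all-zeros tensor that stands in for the unconditional branch. A hypothetical shape trace, assuming the ViT-H image encoder used by the SD 1.5 IP-Adapter (projection dim 1024) and `num_images_per_prompt=2`:

```python
# Assumed shapes; actual dims depend on the feature_extractor and image_encoder.
# feature_extractor(pil_image, return_tensors="pt").pixel_values -> (1, 3, 224, 224)
# image_encoder(pixel_values).image_embeds                       -> (1, 1024)  pooled + projected
# image_embeds.repeat_interleave(2, dim=0)                       -> (2, 1024)  one row per generated image
# torch.zeros_like(image_embeds)                                 -> (2, 1024)  unconditional ("negative") embeds
```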

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -647,6 +664,7 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -695,6 +713,8 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -758,6 +778,13 @@ def __call__(
device = self._execution_device
# do_classifier_free_guidance = guidance_scale > 1.0

+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            # Negative embeds are not supported in LCM yet.
+            # if do_classifier_free_guidance:
+            if False:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
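A note on the `if False:` guard above (my reading, not stated in the diff): LCM pipelines skip classifier-free guidance at inference, as the commented-out `do_classifier_free_guidance` earlier in `__call__` hints. Guidance strength is instead injected through the `w_embedding` passed as `timestep_cond`, so there is no unconditional UNet pass for the negative image embeds to join. Roughly:

```python
# Classifier-free guidance (SD-style): two stacked passes, uncond + cond embeds.
#   noise_uncond, noise_text = unet(latents_2x, t, encoder_hidden_states=cat([uncond, cond])).chunk(2)
#   noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
#
# LCM: a single conditional pass; guidance enters as an embedding instead.
#   w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=...)  # w derived from guidance_scale
#   noise_pred = unet(latents, t, timestep_cond=w_embedding, encoder_hidden_states=cond_embeds)
```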

# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -815,6 +842,9 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 8. LCM Multistep Sampling Loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -829,6 +859,7 @@
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]

--- Diff 2/3: LatentConsistencyModelPipeline ---

@@ -19,10 +19,10 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import LCMScheduler
@@ -107,7 +107,7 @@ def retrieve_timesteps(


class LatentConsistencyModelPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using a latent consistency model.
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

Args:
vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline(
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]

@@ -157,6 +158,7 @@ def __init__(
scheduler: LCMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -185,6 +187,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -433,6 +436,20 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -581,6 +598,7 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -629,6 +647,8 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +717,13 @@ def __call__(
device = self._execution_device
# do_classifier_free_guidance = guidance_scale > 1.0

+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            # Negative embeds are not supported in LCM yet.
+            # if do_classifier_free_guidance:
+            if False:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])

# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -748,6 +775,9 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)

+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 8. LCM MultiStep Sampling Loop:
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -762,6 +792,7 @@ def __call__(
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=self.cross_attention_kwargs,
+                added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]

--- Diff 3/3: StableDiffusionInstructPix2PixPipeline ---

@@ -18,10 +18,10 @@
import numpy as np
import PIL.Image
import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging
@@ -72,7 +72,9 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
+):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).

@@ -83,6 +85,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

Args:
vae ([`AutoencoderKL`]):
@@ -105,7 +108,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
"""

model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]

@@ -118,6 +121,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -146,6 +150,7 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -166,6 +171,7 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -213,6 +219,8 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -293,6 +301,13 @@ def __call__(
self._guidance_scale = guidance_scale
self._image_guidance_scale = image_guidance_scale

+        device = self._execution_device
+
+        if ip_adapter_image is not None:
+            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
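Orientation note (my gloss, grounded in the `_encode_prompt` comment later in this diff): InstructPix2Pix stacks three UNet passes per step, with `prompt_embeds` ordered `[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]`, so the IP-Adapter embeds are concatenated in the matching order:

```python
# Assumed correspondence between the three stacked passes and the cat above:
#   index 0: text + image-conditioned pass -> real image_embeds
#   index 1: image-conditioned-only pass   -> zero ("negative") image_embeds
#   index 2: fully unconditional pass      -> zero ("negative") image_embeds
```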

if image is None:
raise ValueError("`image` input cannot be undefined.")

@@ -367,6 +382,9 @@ def __call__(
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        # 8.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -383,7 +401,11 @@ def __call__(

# predict the noise residual
noise_pred = self.unet(
-                scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
+                scaled_latent_model_input,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
)[0]

# Hack:
@@ -598,11 +620,25 @@ def _encode_prompt(
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds])

return prompt_embeds

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+
+        uncond_image_embeds = torch.zeros_like(image_embeds)
+        return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
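To close the loop on the InstructPix2Pix diff, a minimal end-to-end sketch of the new argument there; again, the checkpoint ids, weight name, and URLs are illustrative assumptions:

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
# Repo/subfolder/weight names follow the h94/IP-Adapter layout (assumed here).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

source = load_image("https://example.com/room.png")  # image to edit (placeholder URL)
style = load_image("https://example.com/style.png")  # IP-Adapter reference (placeholder URL)

edited = pipe(
    prompt="make it look like a watercolor painting",
    image=source,
    ip_adapter_image=style,  # the new argument wired through __call__ above
    num_inference_steps=30,
    image_guidance_scale=1.5,
).images[0]
edited.save("pix2pix_ip_adapter.png")
```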