diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 36fbf48d6493..5ec35ddf07ab 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -18,11 +18,22 @@
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(
 
 
 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" - model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] def __init__( self, @@ -225,6 +248,8 @@ def __init__( adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]], scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() @@ -237,6 +262,8 @@ def __init__( unet=unet, adapter=adapter, scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -511,6 +538,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -768,7 +820,7 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -785,6 +837,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -876,6 +929,7 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -991,7 +1045,7 @@ def __call__( device = self._execution_device - # 3. Encode input prompt + # 3.1 Encode input prompt ( prompt_embeds, negative_prompt_embeds, @@ -1012,6 +1066,15 @@ def __call__( clip_skip=clip_skip, ) + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -1028,10 +1091,10 @@ def __call__( latents, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.5 Optionally get Guidance Scale Embedding + # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) @@ -1090,8 +1153,7 @@ def __call__( # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - - # 7.1 Apply denoising_end + # Apply denoising_end if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( @@ -1109,9 +1171,12 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual if i < int(num_inference_steps * adapter_conditioning_factor): down_intrablock_additional_residuals = [state.clone() for state in adapter_state] else: @@ -1123,9 +1188,9 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, added_cond_kwargs=added_cond_kwargs, return_dict=False, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, )[0] # perform guidance diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index d1920d59b447..2296e092e7bd 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -159,7 +159,8 @@ def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_di "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, # "safety_checker": None, - # "feature_extractor": None, + "feature_extractor": None, + "image_encoder": None, } return components @@ -265,7 +266,8 @@ def get_dummy_components_with_full_downscaling(self, adapter_type="full_adapter_ "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, # "safety_checker": None, - # "feature_extractor": None, + "feature_extractor": None, + "image_encoder": None, } return components
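
For reviewers, here is a minimal usage sketch of the path this diff adds. It is not part of the change itself, and the checkpoint names and image paths are illustrative assumptions; the only new API surface exercised is `load_ip_adapter` (available through the newly inherited `IPAdapterMixin`) and the `ip_adapter_image` argument to `__call__`:

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Illustrative checkpoints; any SDXL base model, SDXL T2I-Adapter, and
# SDXL IP-Adapter weights should work the same way.
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# New with this PR: the pipeline inherits IPAdapterMixin, so IP-Adapter
# weights (and their matching image encoder) can be loaded directly.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

canny_image = load_image("canny_edges.png")      # structural control for the T2I-Adapter
style_image = load_image("style_reference.png")  # appearance reference for the IP-Adapter

image = pipe(
    prompt="a futuristic city at dusk",
    image=canny_image,             # T2I-Adapter conditioning (unchanged API)
    ip_adapter_image=style_image,  # new argument introduced by this diff
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]
image.save("out.png")
```

The two conditioning signals compose independently: the T2I-Adapter residuals steer structure through `down_intrablock_additional_residuals`, while the IP-Adapter image embeddings reach the UNet via `added_cond_kwargs["image_embeds"]`, mirroring how IP-Adapter support is wired into the other Stable Diffusion pipelines.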