From 9d8d133033c9b518980de312e1beb7b91ee18ca9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 25 Jan 2023 13:29:49 +0200 Subject: [PATCH] Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting (#2071) * add text embeds to sd * add text embeds to sd * finish tests * finish * finish * make style * fix tests * make style * make style * up * better docs * fix * fix * new try * up * up * finish --- .../alt_diffusion/pipeline_alt_diffusion.py | 206 +++++++++++++----- .../pipeline_alt_diffusion_img2img.py | 188 +++++++++++----- .../pipeline_latent_diffusion.py | 10 +- .../pipeline_paint_by_example.py | 8 +- .../pipeline_cycle_diffusion.py | 181 ++++++++++----- .../pipeline_flax_stable_diffusion.py | 6 +- .../pipeline_flax_stable_diffusion_img2img.py | 6 +- .../pipeline_flax_stable_diffusion_inpaint.py | 6 +- .../pipeline_onnx_stable_diffusion.py | 20 +- .../pipeline_onnx_stable_diffusion_img2img.py | 24 +- .../pipeline_onnx_stable_diffusion_inpaint.py | 24 +- ...ne_onnx_stable_diffusion_inpaint_legacy.py | 20 +- .../pipeline_stable_diffusion.py | 206 +++++++++++++----- .../pipeline_stable_diffusion_depth2img.py | 192 +++++++++++----- ...peline_stable_diffusion_image_variation.py | 4 +- .../pipeline_stable_diffusion_img2img.py | 188 +++++++++++----- .../pipeline_stable_diffusion_inpaint.py | 198 ++++++++++++----- ...ipeline_stable_diffusion_inpaint_legacy.py | 182 +++++++++++----- ...eline_stable_diffusion_instruct_pix2pix.py | 162 +++++++++----- .../pipeline_stable_diffusion_k_diffusion.py | 158 +++++++++----- .../pipeline_stable_diffusion_upscale.py | 161 +++++++++----- .../pipeline_stable_diffusion_safe.py | 76 +++++-- .../pipelines/unclip/pipeline_unclip.py | 30 +-- .../unclip/pipeline_unclip_image_variation.py | 24 +- src/diffusers/pipelines/unclip/text_proj.py | 10 +- ...ipeline_versatile_diffusion_dual_guided.py | 44 ++-- ...ine_versatile_diffusion_image_variation.py | 14 +- ...eline_versatile_diffusion_text_to_image.py | 72 ++++-- .../vq_diffusion/pipeline_vq_diffusion.py | 30 ++- .../stable_diffusion/test_stable_diffusion.py | 78 ++++++- 30 files changed, 1738 insertions(+), 790 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 6978ab8e28b21..d074c54634a4e 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -220,12 +220,21 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -233,47 +242,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -293,7 +322,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -307,23 +336,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: @@ -360,10 +393,16 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -375,6 +414,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -396,7 +461,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -406,6 +471,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -415,8 +482,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -431,8 +499,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -445,6 +514,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -472,10 +548,18 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -483,8 +567,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps @@ -498,7 +588,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -516,7 +606,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -536,7 +626,7 @@ def __call__( image = self.decode_latents(latents) # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 522a499de942b..6b64edadfc7ce 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -242,12 +242,21 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -255,47 +264,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -315,7 +344,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -329,23 +358,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: @@ -382,7 +415,9 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, strength, callback_steps): + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -397,6 +432,32 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -462,7 +523,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, @@ -471,6 +532,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -481,8 +544,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. @@ -502,8 +566,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -512,6 +577,13 @@ def __call__( generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -537,8 +609,8 @@ def __call__( init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) image = init_image or image - # 1. Check inputs - self.check_inputs(prompt, strength, callback_steps) + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) @@ -549,8 +621,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Preprocess image @@ -563,7 +641,7 @@ def __call__( # 6. Prepare latent variables latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -578,7 +656,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -598,7 +676,7 @@ def __call__( image = self.decode_latents(latents) # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 11. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 1f0ae4be12852..0a2448fd9dd63 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -129,11 +129,11 @@ def __call__( uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" ) - uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0] + negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0] # get prompt text embeddings text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") - text_embeddings = self.bert(text_input.input_ids.to(self.device))[0] + prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0] # get the initial random noise unless the user supplied it latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) @@ -144,7 +144,7 @@ def __call__( ) if latents is None: - latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype) + latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") @@ -163,13 +163,13 @@ def __call__( if guidance_scale == 1.0: # guidance_scale of 1 means no guidance latents_input = latents - context = text_embeddings + context = prompt_embeds else: # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = torch.cat([latents] * 2) - context = torch.cat([uncond_embeddings, text_embeddings]) + context = torch.cat([negative_prompt_embeds, prompt_embeds]) # predict the noise residual noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 1fb4521d2c036..f24f3547c0baa 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -364,7 +364,7 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free image = self.feature_extractor(images=image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) - image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True) + image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape @@ -372,13 +372,13 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: - uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 1b006b3ab4389..d242a9f917d09 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -261,12 +261,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -274,47 +283,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -334,7 +363,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -348,26 +377,32 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs - def check_inputs(self, prompt, strength, callback_steps): + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -382,6 +417,32 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -492,6 +553,7 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -533,6 +595,13 @@ def __call__( generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -569,8 +638,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None) - source_text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + source_prompt_embeds = self._encode_prompt( source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None ) @@ -584,7 +659,7 @@ def __call__( # 6. Prepare latent variables latents, clean_latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) source_latents = latents @@ -612,17 +687,17 @@ def __call__( ], dim=0, ) - concat_text_embeddings = torch.stack( + concat_prompt_embeds = torch.stack( [ - source_text_embeddings[0], - text_embeddings[0], - source_text_embeddings[1], - text_embeddings[1], + source_prompt_embeds[0], + prompt_embeds[0], + source_prompt_embeds[1], + prompt_embeds[1], ], dim=0, ) concat_noise_pred = self.unet( - concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings + concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds ).sample # perform guidance @@ -662,7 +737,7 @@ def __call__( image = self.decode_latents(latents) # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 11. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index d5add17107b9b..307299a0f33f6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -196,7 +196,7 @@ def _generate( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings - text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` @@ -210,8 +210,8 @@ def _generate( ).input_ids else: uncond_input = neg_prompt_ids - uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([uncond_embeddings, text_embeddings]) + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) latents_shape = ( batch_size, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 16ac2ba155d20..5214b6476ff12 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -192,7 +192,7 @@ def _generate( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings - text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` @@ -206,8 +206,8 @@ def _generate( ).input_ids else: uncond_input = neg_prompt_ids - uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([uncond_embeddings, text_embeddings]) + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) latents_shape = ( batch_size, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index 672bef442f344..5939595413d2b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -224,7 +224,7 @@ def _generate( raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings - text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` @@ -238,8 +238,8 @@ def _generate( ).input_ids else: uncond_input = neg_prompt_ids - uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([uncond_embeddings, text_embeddings]) + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) latents_shape = ( batch_size, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index 6dda632bb65a6..aacca71f63692 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -117,7 +117,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt @@ -147,8 +147,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -179,15 +179,15 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida truncation=True, return_tensors="np", ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def __call__( self, @@ -232,12 +232,12 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # get the initial random noise unless the user supplied it - latents_dtype = text_embeddings.dtype + latents_dtype = prompt_embeds.dtype latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) @@ -271,7 +271,7 @@ def __call__( # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) noise_pred = noise_pred[0] # perform guidance diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 19f41f5e7bd7a..5fccbe7163c36 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -167,7 +167,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt @@ -197,8 +197,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -229,15 +229,15 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida truncation=True, return_tensors="np", ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def __call__( self, @@ -345,11 +345,11 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - latents_dtype = text_embeddings.dtype + latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) # encode the init image into latents and scale the latents init_latents = self.vae_encoder(sample=image)[0] @@ -417,9 +417,9 @@ def __call__( # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet( - sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings - )[0] + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 8f4320e2300e2..5726e21b16a97 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -168,7 +168,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt @@ -198,8 +198,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -230,15 +230,15 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida truncation=True, return_tensors="np", ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds @torch.no_grad() def __call__( @@ -351,13 +351,13 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) num_channels_latents = NUM_LATENT_CHANNELS latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) - latents_dtype = text_embeddings.dtype + latents_dtype = prompt_embeds.dtype if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) else: @@ -424,9 +424,9 @@ def __call__( # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) - noise_pred = self.unet( - sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings - )[0] + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 98914eaa25c72..25b86e8e17a67 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -153,7 +153,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt @@ -183,8 +183,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -215,15 +215,15 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida truncation=True, return_tensors="np", ) - uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] - uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def __call__( self, @@ -338,11 +338,11 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) - latents_dtype = text_embeddings.dtype + latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) # encode the init image into latents and scale the latents @@ -399,7 +399,7 @@ def __call__( # predict the noise residual noise_pred = self.unet( - sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=prompt_embeds )[0] # perform guidance diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b38ca866d58df..238b34e8fd9ed 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -217,12 +217,21 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -230,47 +239,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -290,7 +319,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -304,23 +333,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: @@ -357,10 +390,16 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -372,6 +411,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -393,7 +458,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -403,6 +468,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -412,8 +479,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -428,8 +496,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -442,6 +511,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -469,10 +545,18 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -480,8 +564,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps @@ -495,7 +585,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -513,7 +603,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -533,7 +623,7 @@ def __call__( image = self.decode_latents(latents) # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index fca9cb9e37324..e3211da986ff0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -159,12 +159,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -172,47 +181,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -232,7 +261,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -246,23 +275,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -302,12 +335,15 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, strength, callback_steps): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) @@ -317,6 +353,32 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -424,8 +486,8 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, depth_map: Optional[torch.FloatTensor] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, @@ -434,6 +496,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -443,8 +507,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. @@ -464,8 +529,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -474,6 +540,13 @@ def __call__( generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -520,6 +593,9 @@ def __call__( # 1. Check inputs self.check_inputs(prompt, strength, callback_steps) + if image is None: + raise ValueError("`image` input cannot be undefined.") + # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device @@ -529,8 +605,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare depth mask @@ -539,7 +621,7 @@ def __call__( depth_map, batch_size * num_images_per_prompt, do_classifier_free_guidance, - text_embeddings.dtype, + prompt_embeds.dtype, device, ) @@ -553,7 +635,7 @@ def __call__( # 7. Prepare latent variables latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -569,7 +651,7 @@ def __call__( latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 37d4a50efc459..51ace48557e4e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -173,12 +173,12 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: - uncond_embeddings = torch.zeros_like(image_embeddings) + negative_prompt_embeds = torch.zeros_like(image_embeddings) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 19e4af180e11e..5f103b36d25b6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -249,12 +249,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -262,47 +271,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -322,7 +351,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -336,23 +365,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -392,7 +425,9 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, strength, callback_steps): + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -407,6 +442,32 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) @@ -472,7 +533,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, @@ -481,6 +542,8 @@ def __call__( num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -491,8 +554,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. @@ -512,8 +576,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -522,6 +587,13 @@ def __call__( generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -547,8 +619,8 @@ def __call__( init_image = deprecate("init_image", "0.13.0", message, take_from=kwargs) image = init_image or image - # 1. Check inputs - self.check_inputs(prompt, strength, callback_steps) + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) @@ -559,8 +631,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Preprocess image @@ -573,7 +651,7 @@ def __call__( # 6. Prepare latent variables latents = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -588,7 +666,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -608,7 +686,7 @@ def __call__( image = self.decode_latents(latents) # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 11. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 140aa8da2a77b..263cc77375102 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -297,12 +297,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -310,47 +319,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -370,7 +399,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -384,23 +413,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -441,10 +474,16 @@ def decode_latents(self, latents): return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -456,6 +495,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -528,9 +593,9 @@ def prepare_mask_latents( @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image], - mask_image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -540,6 +605,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + nrompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -549,8 +616,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will be masked out with `mask_image` and repainted according to `prompt`. @@ -573,8 +641,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -587,6 +656,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -645,6 +721,12 @@ def __call__( # 1. Check inputs self.check_inputs(prompt, height, width, callback_steps) + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask_image is None: + raise ValueError("`mask_image` input cannot be undefined.") + # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device @@ -654,7 +736,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) @@ -672,7 +754,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -685,7 +767,7 @@ def __call__( batch_size * num_images_per_prompt, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, do_classifier_free_guidance, @@ -718,7 +800,7 @@ def __call__( latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -738,7 +820,7 @@ def __call__( image = self.decode_latents(latents) # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 13. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 79679b00c0f7c..9f032daab4c08 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -216,12 +216,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -229,47 +238,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -289,7 +318,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -303,23 +332,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -360,7 +393,9 @@ def prepare_extra_step_kwargs(self, generator, eta): return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs - def check_inputs(self, prompt, strength, callback_steps): + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): if not isinstance(prompt, str) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") @@ -375,6 +410,32 @@ def check_inputs(self, prompt, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -415,6 +476,8 @@ def __call__( add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -425,8 +488,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. This is the image whose masked region will be inpainted. @@ -450,8 +514,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. add_predicted_noise (`bool`, *optional*, defaults to True): @@ -463,6 +528,13 @@ def __call__( generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -499,8 +571,14 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Preprocess image and mask @@ -518,7 +596,7 @@ def __call__( # 6. Prepare latent variables # encode the init image into latents and scale the latents latents, init_latents_orig, noise = self.prepare_latents( - image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare mask latent @@ -537,7 +615,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -566,7 +644,7 @@ def __call__( image = self.decode_latents(latents) # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 12. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index d03a4d40db4ad..7923d3a73c4d4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -127,8 +127,8 @@ def __init__( @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image], + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, num_inference_steps: int = 100, guidance_scale: float = 7.5, image_guidance_scale: float = 1.5, @@ -137,6 +137,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -146,8 +148,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be repainted according to `prompt`. num_inference_steps (`int`, *optional*, defaults to 100): @@ -165,8 +168,9 @@ def __call__( generate images that are closely linked to the source image `image`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -179,6 +183,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -231,6 +242,9 @@ def __call__( # 0. Check inputs self.check_inputs(prompt, callback_steps) + if image is None: + raise ValueError("`image` input cannot be undefined.") + # 1. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device @@ -242,8 +256,14 @@ def __call__( scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 2. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 3. Preprocess image @@ -259,7 +279,7 @@ def __call__( image, batch_size, num_images_per_prompt, - text_embeddings.dtype, + prompt_embeds.dtype, device, do_classifier_free_guidance, generator, @@ -272,7 +292,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -306,7 +326,7 @@ def __call__( scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -348,7 +368,7 @@ def __call__( image = self.decode_latents(latents) # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 12. Convert to PIL if output_type == "pil": @@ -398,12 +418,21 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -411,47 +440,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -471,7 +520,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -485,23 +534,28 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([text_embeddings, uncond_embeddings, uncond_embeddings]) + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 03bcea6f4f7bf..05229c5a1b71b 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -162,12 +162,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -175,47 +184,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -235,7 +264,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -249,23 +278,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): @@ -317,7 +350,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -327,6 +360,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -336,8 +371,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -352,8 +388,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -366,6 +403,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -404,14 +448,20 @@ def __call__( raise ValueError("has to use guidance_scale") # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device) + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) sigmas = self.scheduler.sigmas - sigmas = sigmas.to(text_embeddings.dtype) + sigmas = sigmas.to(prompt_embeds.dtype) # 5. Prepare latent variables num_channels_latents = self.unet.in_channels @@ -420,7 +470,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -434,7 +484,7 @@ def model_fn(x, t): latent_model_input = torch.cat([x] * 2) t = torch.cat([t] * 2) - noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings) + noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -447,7 +497,7 @@ def model_fn(x, t): image = self.decode_latents(latents) # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL if output_type == "pil": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index e58ae1d7632f3..f2bf6a8bdc7eb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -136,12 +136,21 @@ def _execution_device(self): return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device @@ -149,47 +158,67 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + negative_ prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - text_embeddings = text_embeddings[0] + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -209,7 +238,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -223,23 +252,27 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -325,8 +358,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, @@ -335,6 +368,8 @@ def __call__( eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, @@ -344,8 +379,9 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): `Image`, or tensor representing an image batch which will be upscaled. * num_inference_steps (`int`, *optional*, defaults to 50): @@ -358,8 +394,9 @@ def __call__( 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): @@ -372,6 +409,13 @@ def __call__( Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -422,6 +466,9 @@ def __call__( # 1. Check inputs self.check_inputs(prompt, image, noise_level, callback_steps) + if image is None: + raise ValueError("`image` input cannot be undefined.") + # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device @@ -431,13 +478,19 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) # 4. Preprocess image image = preprocess(image) - image = image.to(dtype=text_embeddings.dtype, device=device) + image = image.to(dtype=prompt_embeds.dtype, device=device) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -445,7 +498,7 @@ def __call__( # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = randn_tensor(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 @@ -460,7 +513,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -493,7 +546,7 @@ def __call__( # predict the noise residual noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level + latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level ).sample # perform guidance diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index ff4b41a9dc634..fbfc44a8c374e 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -217,7 +217,7 @@ def _encode_prompt( Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device @@ -253,16 +253,16 @@ def _encode_prompt( else: attention_mask = None - text_embeddings = self.text_encoder( + prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) - text_embeddings = text_embeddings[0] + prompt_embeds = prompt_embeds[0] # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -299,16 +299,16 @@ def _encode_prompt( else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = negative_prompt_embeds[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # Encode the safety concept text if enable_safety_guidance: @@ -329,15 +329,15 @@ def _encode_prompt( # For classifier free guidance + sld, we need to do three forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing three forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) else: # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def run_safety_checker(self, image, device, dtype, enable_safety_guidance): if self.safety_checker is not None: @@ -390,10 +390,16 @@ def prepare_extra_step_kwargs(self, generator, eta): return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -405,6 +411,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -583,7 +615,7 @@ def __call__( warnings.warn("Safety checker disabled!") # 3. Encode input prompt - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance ) @@ -598,7 +630,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -621,7 +653,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: @@ -680,7 +712,7 @@ def __call__( # 9. Run safety checker image, has_nsfw_concept, flagged_images = self.run_safety_checker( - image, device, text_embeddings.dtype, enable_safety_guidance + image, device, prompt_embeds.dtype, enable_safety_guidance ) # 10. Convert to PIL diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index abf2ec2223e53..b68c6626709a6 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -153,15 +153,15 @@ def _encode_prompt( text_encoder_output = self.text_encoder(text_input_ids.to(device)) - text_embeddings = text_encoder_output.text_embeds + prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] - text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) @@ -176,16 +176,16 @@ def _encode_prompt( return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) - uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) @@ -199,12 +199,12 @@ def _encode_prompt( # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) - return text_embeddings, text_encoder_hidden_states, text_mask + return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -336,7 +336,7 @@ def __call__( do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 - text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) @@ -349,7 +349,7 @@ def __call__( prior_latents = self.prepare_latents( (batch_size, embedding_dim), - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, prior_latents, @@ -363,7 +363,7 @@ def __call__( predicted_image_embedding = self.prior( latent_model_input, timestep=t, - proj_embedding=text_embeddings, + proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding @@ -397,7 +397,7 @@ def __call__( text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, - text_embeddings=text_embeddings, + prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, ) diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index c99157a0afb26..c5af8f6471fb5 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -136,10 +136,10 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr text_mask = text_inputs.attention_mask.bool().to(device) text_encoder_output = self.text_encoder(text_input_ids.to(device)) - text_embeddings = text_encoder_output.text_embeds + prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) @@ -155,16 +155,16 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) - uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) - uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds - uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) @@ -178,12 +178,12 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) - return text_embeddings, text_encoder_hidden_states, text_mask + return prompt_embeds, text_encoder_hidden_states, text_mask def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None): dtype = next(self.image_encoder.parameters()).dtype @@ -314,7 +314,7 @@ def __call__( do_classifier_free_guidance = decoder_guidance_scale > 1.0 - text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance ) @@ -323,7 +323,7 @@ def __call__( # decoder text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, - text_embeddings=text_embeddings, + prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, ) diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py index 4482010519a2b..a98cfbebdb906 100644 --- a/src/diffusers/pipelines/unclip/text_proj.py +++ b/src/diffusers/pipelines/unclip/text_proj.py @@ -52,7 +52,7 @@ def __init__( self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) - def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance): + def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): if do_classifier_free_guidance: # Add the classifier free guidance embeddings to the image embeddings image_embeddings_batch_size = image_embeddings.shape[0] @@ -63,15 +63,15 @@ def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_stat image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) # The image embeddings batch size and the text embeddings batch size are equal - assert image_embeddings.shape[0] == text_embeddings.shape[0] + assert image_embeddings.shape[0] == prompt_embeds.shape[0] - batch_size = text_embeddings.shape[0] + batch_size = prompt_embeds.shape[0] # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and # adding CLIP embeddings to the existing timestep embedding, ... - time_projected_text_embeddings = self.embedding_proj(text_embeddings) + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) - additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds # ... and by projecting CLIP embeddings into four # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 460244854222a..1d34937877e08 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -187,7 +187,7 @@ def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifi Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device @@ -227,16 +227,16 @@ def normalize_embeddings(encoder_output): else: attention_mask = None - text_embeddings = self.text_encoder( + prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) - text_embeddings = normalize_embeddings(text_embeddings) + prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -255,30 +255,30 @@ def normalize_embeddings(encoder_output): else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = normalize_embeddings(uncond_embeddings) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device @@ -313,18 +313,18 @@ def normalize_embeddings(encoder_output): uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) - uncond_embeddings = self.image_encoder(pixel_values) - uncond_embeddings = normalize_embeddings(uncond_embeddings) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes - image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings @@ -524,9 +524,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompts - text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) - dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1) + dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) prompt_types = ("text", "image") # 4. Prepare timesteps diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index b08d9bb143a0b..5acf6f3300a0b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -114,7 +114,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device @@ -173,18 +173,18 @@ def normalize_embeddings(encoder_output): uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) - uncond_embeddings = self.image_encoder(pixel_values) - uncond_embeddings = normalize_embeddings(uncond_embeddings) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes - image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 06d8773eaf4a2..494b9f14f785b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -138,7 +138,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `list(int)`): + prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device @@ -181,16 +181,16 @@ def normalize_embeddings(encoder_output): else: attention_mask = None - text_embeddings = self.text_encoder( + prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) - text_embeddings = normalize_embeddings(text_embeddings) + prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: @@ -227,23 +227,23 @@ def normalize_embeddings(encoder_output): else: attention_mask = None - uncond_embeddings = self.text_encoder( + negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) - uncond_embeddings = normalize_embeddings(uncond_embeddings) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): @@ -273,10 +273,16 @@ def prepare_extra_step_kwargs(self, generator, eta): return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -288,6 +294,32 @@ def check_inputs(self, prompt, height, width, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -412,7 +444,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) @@ -427,7 +459,7 @@ def __call__( num_channels_latents, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, @@ -443,7 +475,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index bd63eda030cc6..5436f33000a1b 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -120,7 +120,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. # While CLIP does normalize the pooled output of the text transformer when combining @@ -128,15 +128,15 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida # # CLIP normalizing the pooled output. # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 - text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: if self.learned_classifier_free_sampling_embeddings.learnable: - uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings - uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1) + negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) else: uncond_tokens = [""] * batch_size @@ -148,21 +148,21 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # See comment for normalizing text embeddings - uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True) + negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds @torch.no_grad() def __call__( @@ -234,7 +234,7 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 - text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) @@ -273,9 +273,7 @@ def __call__( # predict the un-noised image # model_output == `log_p_x_0` - model_output = self.transformer( - latent_model_input, encoder_hidden_states=text_embeddings, timestep=t - ).sample + model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample if do_classifier_free_guidance: model_output_uncond, model_output_text = model_output.chunk(2) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 3c2c3fcb136a7..a637e1957a647 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -120,7 +120,7 @@ def test_stable_diffusion_ddim(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) - sd_pipe = sd_pipe.to(device) + sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) @@ -134,6 +134,82 @@ def test_stable_diffusion_ddim(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_prompt_embeds(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = sd_pipe.tokenizer( + prompt, + padding="max_length", + max_length=sd_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputs = text_inputs["input_ids"].to(torch_device) + + prompt_embeds = sd_pipe.text_encoder(text_inputs)[0] + + inputs["prompt_embeds"] = prompt_embeds + + # forward + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + + def test_stable_diffusion_negative_prompt_embeds(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + embeds = [] + for p in [prompt, negative_prompt]: + text_inputs = sd_pipe.tokenizer( + p, + padding="max_length", + max_length=sd_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputs = text_inputs["input_ids"].to(torch_device) + + embeds.append(sd_pipe.text_encoder(text_inputs)[0]) + + inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds + + # forward + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_ddim_factor_8(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator