
Allow directly passing text embeddings to Stable Diffusion Pipeline for prompt weighting #2071

Merged
merged 19 commits on Jan 25, 2023
206 changes: 148 additions & 58 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Large diffs are not rendered by default.

188 changes: 133 additions & 55 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Large diffs are not rendered by default.

@@ -129,11 +129,11 @@ def __call__(
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self.device))[0]

# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
prompt_embeds = self.bert(text_input.input_ids.to(self.device))[0]

# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -144,7 +144,7 @@
)

if latents is None:
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -163,13 +163,13 @@
if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance
latents_input = latents
context = text_embeddings
context = prompt_embeds
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = torch.cat([latents] * 2)
context = torch.cat([uncond_embeddings, text_embeddings])
context = torch.cat([negative_prompt_embeds, prompt_embeds])

# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
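Downstream of this hunk, the concatenated batch is split back apart and recombined; a minimal sketch of that guidance step, assuming epsilon-prediction and the surrounding `guidance_scale` float (the chunk/combine lines are not part of this diff):

noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
# first half was conditioned on negative_prompt_embeds, second half on prompt_embeds
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)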
@@ -364,21 +364,21 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)

# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

if do_classifier_free_guidance:
uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

return image_embeddings
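The repeat/view pair above is the mps-friendly substitute for `torch.repeat_interleave`; a self-contained sketch of the pattern, with illustrative shapes:

import torch

embeds = torch.randn(2, 77, 768)  # (batch, seq_len, dim)
num_images_per_prompt = 3
bs, seq_len, dim = embeds.shape
# tile along the sequence axis, then fold the copies back into the batch axis
embeds = embeds.repeat(1, num_images_per_prompt, 1)             # (2, 231, 768)
embeds = embeds.view(bs * num_images_per_prompt, seq_len, dim)  # (6, 77, 768)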

181 changes: 128 additions & 53 deletions src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

@@ -261,60 +261,89 @@ def _execution_device(self):
return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.

Args:
prompt (`str` or `list(int)`):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1

text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]

if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)

text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None

prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
@@ -334,7 +363,7 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
else:
uncond_tokens = negative_prompt

max_length = text_input_ids.shape[-1]
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
@@ -348,26 +377,32 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
else:
attention_mask = None

uncond_embeddings = self.text_encoder(
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
negative_prompt_embeds = negative_prompt_embeds[0]

if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
seq_len = negative_prompt_embeds.shape[1]

negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

return text_embeddings
return prompt_embeds
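This refactor is what enables the PR's headline feature: `_encode_prompt` now accepts precomputed embeddings instead of always tokenizing `prompt`. A hypothetical end-to-end usage sketch (the checkpoint name and the uniform 1.1x up-weighting are illustrative assumptions, not from this PR):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# encode the prompt manually so its embeddings can be reweighted before sampling
text_inputs = pipe.tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = pipe.text_encoder(text_inputs.input_ids)[0]

prompt_embeds = prompt_embeds * 1.1  # naive uniform up-weighting, for illustration only

image = pipe(prompt_embeds=prompt_embeds).images[0]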

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(self, prompt, strength, callback_steps):
def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

@@ -382,6 +417,32 @@ def check_inputs(self, prompt, strength, callback_steps):
f" {type(callback_steps)}."
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)

if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
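Condensed into a standalone helper for illustration (`check_embeds` is hypothetical; it restates the three rules above):

import torch

def check_embeds(prompt, prompt_embeds, negative_prompt_embeds=None):
    # exactly one of prompt / prompt_embeds must be supplied
    if prompt is not None and prompt_embeds is not None:
        raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")
    # directly passed embeddings must agree in shape for CFG concatenation
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` must have the same shape.")

check_embeds(None, torch.randn(1, 77, 768), torch.randn(1, 77, 768))  # passes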

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -492,6 +553,7 @@ def __call__(
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -533,6 +595,13 @@
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -569,8 +638,14 @@ def __call__(
do_classifier_free_guidance = guidance_scale > 1.0

# 3. Encode input prompt
text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
source_text_embeddings = self._encode_prompt(
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
)
source_prompt_embeds = self._encode_prompt(
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
)

@@ -584,7 +659,7 @@

# 6. Prepare latent variables
latents, clean_latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
source_latents = latents

@@ -612,17 +687,17 @@
],
dim=0,
)
concat_text_embeddings = torch.stack(
concat_prompt_embeds = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
source_prompt_embeds[0],
prompt_embeds[0],
source_prompt_embeds[1],
prompt_embeds[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds
).sample

# perform guidance
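After CFG concatenation, index 0 of each embedding tensor is its unconditional half and index 1 its text half (for a batch of one image per prompt), so the interleaved stack pairs each latent stream with its matching context in the order [source uncond, target uncond, source text, target text]. A toy check of that ordering, with dummy shapes assumed:

import torch

# dummy CFG-concatenated embeddings: index 0 = uncond half, index 1 = text half
source_prompt_embeds = torch.stack([torch.zeros(77, 768), torch.ones(77, 768)])
prompt_embeds = torch.stack([torch.full((77, 768), 2.0), torch.full((77, 768), 3.0)])

concat_prompt_embeds = torch.stack(
    [source_prompt_embeds[0], prompt_embeds[0], source_prompt_embeds[1], prompt_embeds[1]],
    dim=0,
)
print(concat_prompt_embeds[:, 0, 0])  # tensor([0., 2., 1., 3.])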
@@ -662,7 +737,7 @@
image = self.decode_latents(latents)

# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

# 11. Convert to PIL
if output_type == "pil":
@@ -196,7 +196,7 @@ def _generate(
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
@@ -210,8 +210,8 @@
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

latents_shape = (
batch_size,