
[Kandinsky 3.0] Follow-up TODOs #5944

Merged
merged 24 commits into main from kand-3 on Dec 1, 2023

Conversation

yiyixuxu (Collaborator) commented Nov 27, 2023

Work through the remaining TODOs from #5913:

  • Address all the TODO statements in the code
  • Add better docs (added docs for Kandinsky 3.0 in #5937 (comment))
  • Add tests for img2img
  • Rename the pipeline and model files
  • Clean up all the UNet blocks and get rid of hard-to-read code
  • Publish on Discord etc.

text-2-image

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]

[output image: yiyi_test_1_out]

image-2-image

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch

pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A painting of the inside of a subway train with tiny raccoons."
image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")

generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]

[output image: yiyi_test_2_out]

yiyixuxu marked this pull request as draft November 27, 2023 09:35
HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

yiyixuxu marked this pull request as ready for review November 29, 2023 04:41
@@ -593,7 +597,7 @@ def get_attention_scores(
             beta = 0
         else:
             baddbmm_input = attention_mask
-            beta = 1
+            beta = self.scale_mask_factor
yiyixuxu (Collaborator, Author) commented Nov 29, 2023

Added a new config for Attention here: scale_mask_factor.

Setting this to a large negative number helps a lot with numerical stability. In Kandinsky, they "fill" the empty tokens in attention_matrix with the largest possible negative number (see their code):

attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)

I set this config to -60000.0 for simplicity; it is not exactly the same, but it seems to be sufficient.
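For intuition, here is a minimal sketch (toy shapes and values, not the diffusers implementation) of why filling the empty tokens' scores with a huge negative value zeroes out their attention weights after the softmax:

import torch

# toy attention scores: (batch, query_len, key_len)
attention_matrix = torch.randn(2, 4, 6)

# context_mask: 1 for real text tokens, 0 for padded ("empty") tokens
context_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 0, 0, 0, 0]])[:, None, :]

# fill the scores at the empty tokens with the largest representable negative
# value, mirroring the Kandinsky line quoted above
max_neg_value = -torch.finfo(attention_matrix.dtype).max
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)

# after softmax, the empty tokens receive ~zero attention weight
probs = attention_matrix.softmax(dim=-1)
assert probs[0, :, 4:].max().item() < 1e-6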

Contributor

Hmm, is beta supposed to be used to control mask precision?

yiyixuxu (Collaborator, Author) commented Nov 29, 2023

Actually, I think I should do this instead:

attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0

The beta here is essentially trying to do the same thing: zero out the attention scores of the zero tokens in the softmax operation. I did not realize I was missing this step, because Kandinsky cuts off the zero tokens (all except one) from prompt_embeds, so skipping this step, or doing it wrong, still generates accurate output for the most part. The exception is batch_size > 1: in that case prompt_embeds will contain some zero tokens for the shorter prompts, and the attention_mask needs to be applied correctly.
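A minimal sketch of this additive-bias conversion (toy tensors; float32 stands in for sample.dtype):

import torch

dtype = torch.float32  # stands in for sample.dtype in the line above

# binary mask from the tokenizer: 1 = real token, 0 = padded ("zero") token;
# with batch_size > 1 the shorter prompts get padded, so the mask matters
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0, 0]])

# 0 where the token is real, -10000 where it is padding
bias = (1 - attention_mask.to(dtype)) * -10000.0

scores = torch.randn(2, 6, dtype=dtype)  # toy attention scores over 6 key tokens
probs = (scores + bias).softmax(dim=-1)  # padded tokens get ~zero weight
assert probs[1, 3:].max().item() < 1e-6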

yiyixuxu (Collaborator, Author) commented:

Refactored. This script (one scenario where the attention_mask actually needs to be applied) now produces similar, though not exactly identical, outputs on main and on this branch:

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = ["A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background.",
          "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background. A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."]

generator = [torch.Generator(device="cpu").manual_seed(0), torch.Generator(device="cpu").manual_seed(1)]
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
[output images: yiyi_test_3_out_bm (main), yiyi_test_3_out (branch)]

@@ -109,15 +109,19 @@ def __init__(
         residual_connection: bool = False,
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
+        scale_mask_factor: float = 1.0,
+        out_dim: int = None,
Contributor

Is out_dim different from query_dim here?

yiyixuxu (Collaborator, Author) commented Nov 29, 2023

@patrickvonplaten
The only difference is the to_out layer here: Kandinsky's attention output keeps the dimension at inner_dim, while our Attention class projects the output back to query_dim. I added an out_dim config for this purpose, but we can add a different config if it makes more sense!

self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
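As a rough sketch of the difference (hypothetical dimensions; plain nn.Linear layers standing in for the to_out projection under discussion, not the exact diffusers internals):

import torch
from torch import nn

query_dim, inner_dim = 320, 640

# default diffusers behavior: to_out projects the attention output back to query_dim
to_out_default = nn.Linear(inner_dim, query_dim)

# Kandinsky-style behavior enabled by out_dim: keep the inner dimension,
# as in the quoted nn.Linear(out_channels, out_channels, bias=False) line
to_out_kandinsky = nn.Linear(inner_dim, inner_dim, bias=False)

hidden_states = torch.randn(1, 16, inner_dim)  # (batch, tokens, inner_dim)
print(to_out_default(hidden_states).shape)     # torch.Size([1, 16, 320])
print(to_out_kandinsky(hidden_states).shape)   # torch.Size([1, 16, 640])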

Contributor

That works! Makes sense

@@ -2219,44 +2223,6 @@ def __call__(
         return hidden_states


-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
Contributor

Nice!

@@ -1,16 +1,28 @@
import math
Contributor

Actually, can we rename this file to unet_kandinsky3.py? I don't like kandi... much.

-        image_mask = image_mask.reshape(image_mask.shape[0], -1)
-        out = self.attention(out, context, context_mask, image_mask)
+        out = self.attention(out, context, context_mask)
         out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
         x = x + out
Contributor

very nice clean-ups!

@@ -228,14 +254,19 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
Contributor

nice!


         if output_type == "pil":
             image = self.numpy_to_pil(image)
+        self.maybe_free_model_hooks()

         if not return_dict:
             return (image,)
Contributor

nice!

@@ -320,6 +349,8 @@ def __call__(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
yiyixuxu (Collaborator, Author)

Fixed this bug (#5963 (comment)) here by adding attention_mask and negative_attention_mask arguments to __call__.

You should pass attention_mask and negative_attention_mask along with prompt_embeds and negative_prompt_embeds; otherwise you will get an error.

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = pipe.encode_prompt(
    prompt,
    True,
    device=torch.device("cuda"),
)

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    attention_mask=attention_mask,
    negative_attention_mask=negative_attention_mask,
    num_inference_steps=25,
    generator=generator,
).images[0]

This works too:

image = pipe(prompt_embeds=prompt_embeds, attention_mask=attention_mask, num_inference_steps=25, generator=generator).images[0]

patrickvonplaten (Contributor)

@yiyixuxu lemme know once ready for a final review :-)

patrickvonplaten (Contributor) left a review comment

Great clean-up - thanks!

yiyixuxu merged commit b41f809 into main Dec 1, 2023
22 checks passed
yiyixuxu deleted the kand-3 branch December 1, 2023 17:14
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024