[Kandinsky 3.0] Follow-up TODOs #5944

Merged on Dec 1, 2023 (24 commits)
Changes from 18 commits
4 changes: 2 additions & 2 deletions src/diffusers/models/__init__.py
@@ -35,8 +35,8 @@
_import_structure["unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_2d_condition_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]

@@ -63,8 +63,8 @@
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_2d_condition_kandi3 import Kandinsky3UNet
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

53 changes: 9 additions & 44 deletions src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@

import torch
import torch.nn.functional as F
from torch import einsum, nn
from torch import nn

from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -109,15 +109,19 @@ def __init__(
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
scale_mask_factor: float = 1.0,
out_dim: int = None,
Contributor:
Is out_dim different from query_dim here?

yiyixuxu (Collaborator, Author), Nov 29, 2023:
@patrickvonplaten The only difference is the to_out layer here - Kandinsky attention output does not change the dimension from inner_dim, while our Attention class projects the output back to query_dim. I added out_dim for this purpose, but we can add a different config if it makes more sense!

self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))

Contributor:
That works! Makes sense.
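Editor's note: to make the shape change concrete, here is a minimal sketch of how the new out_dim argument alters the layer layout, assuming the Attention constructor as shown in this diff (not code from the PR itself):

# Sketch, assuming diffusers.models.attention_processor.Attention
# accepts the new out_dim argument introduced in this diff.
from diffusers.models.attention_processor import Attention

# Default behaviour: inner_dim = heads * dim_head, output projected back to query_dim.
attn_default = Attention(query_dim=320, heads=8, dim_head=64)
# to_out[0] is Linear(512 -> 320)

# Kandinsky-style: inner_dim = out_dim, heads = out_dim // dim_head,
# and to_out keeps that width instead of projecting back to query_dim.
attn_kandinsky = Attention(query_dim=320, dim_head=64, out_dim=512)
# to_out[0] is Linear(512 -> 512)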

):
super().__init__()
self.inner_dim = dim_head * heads
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.scale_mask_factor = scale_mask_factor

# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +130,7 @@ def __init__(
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0

self.heads = heads
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -193,7 +197,7 @@ def __init__(
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
@@ -593,7 +597,7 @@ def get_attention_scores(
beta = 0
else:
baddbmm_input = attention_mask
beta = 1
beta = self.scale_mask_factor
yiyixuxu (Collaborator, Author), Nov 29, 2023:
Added a new config for Attention here.

Setting this to a large negative number helps a lot with numerical stability. In Kandinsky they "fill" the empty tokens in attention_matrix with the largest possible negative number; see their code:

attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)

I set this config to -60000.0 for simplicity - not exactly the same, but it seems to be sufficient.

Contributor:
Hmm, is beta supposed to be used to control mask precision?

yiyixuxu (Collaborator, Author), Nov 29, 2023:
Actually, I think I should do this instead:

attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0

This beta here is essentially trying to do the same thing: zero out the padding tokens' attention scores in the softmax operation. I did not realize I was missing this step because Kandinsky cuts off (all except one of) the zero tokens from prompt_embeds, so skipping this step, or doing it wrong, still generates accurate output for the most part - except when batch_size > 1. In that case prompt_embeds will contain some zero tokens for the shorter prompts, and attention_mask needs to be applied correctly.
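Editor's note: as a quick illustration (not code from the PR), here is a minimal sketch of what this additive mask does - padded key positions get a large negative bias, so their softmax weights collapse to roughly zero:

# Minimal sketch (assumption, not from the PR): a binary key-padding mask
# turned into an additive bias zeroes out padded tokens after softmax.
import torch

scores = torch.randn(2, 3, 4)                  # (batch, query_len, key_len) attention scores
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])  # 1 = real token, 0 = padding

bias = (1 - attention_mask.to(scores.dtype)) * -10000.0
probs = (scores + bias[:, None, :]).softmax(dim=-1)

print(probs[1, :, 2:])  # ~0 everywhere: padded keys receive no attention weight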

yiyixuxu (Collaborator, Author):
Refactored. With this script (one scenario where the attention_mask actually needs to be applied), main and this branch now produce not exactly identical, but similar, outputs:

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
        
prompt = ["A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background.",
          "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background. A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."]

generator = [torch.Generator(device="cpu").manual_seed(0),torch.Generator(device="cpu").manual_seed(1)]
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
(Comparison images attached: output on main vs. this branch.)


attention_scores = torch.baddbmm(
baddbmm_input,
@@ -2219,44 +2223,6 @@ def __call__(
return hidden_states


# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
Contributor:
Nice!

# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
r"""
Default kandinsky3 proccesor for performing attention-related computations.
"""

@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)

def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)

attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)

if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)

out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2282,7 +2248,6 @@ def __call__(
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
)

AttentionProcessor = Union[
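Editor's note: as an aside on the deleted Kandi3AttnProcessor above, here is a minimal sketch (an illustration under my own assumptions, not code from the PR) of why the einsum path with masked_fill and the standard additive-bias path give effectively the same attention output, which is what makes a stock processor sufficient:

# Sketch (assumption): masking scores with the most negative finite value
# (Kandinsky-style) vs. adding a -10000.0 bias (stock-style) yields
# near-identical softmax weights and therefore the same attention output.
import torch

b, h, q_len, k_len, d = 2, 4, 5, 7, 8
query = torch.randn(b, h, q_len, d)
key = torch.randn(b, h, k_len, d)
value = torch.randn(b, h, k_len, d)
context_mask = torch.randint(0, 2, (b, k_len), dtype=torch.bool)
context_mask[:, 0] = True  # keep at least one real token per sample

scores = torch.einsum("bhid,bhjd->bhij", query, key) * d**-0.5

# Kandinsky-style: fill padded keys with the largest possible negative value.
max_neg = -torch.finfo(scores.dtype).max
filled = scores.masked_fill(~context_mask[:, None, None, :], max_neg)
out_fill = torch.einsum("bhij,bhjd->bhid", filled.softmax(dim=-1), value)

# Stock-style: additive bias on padded keys, as in the refactor discussed above.
bias = (~context_mask[:, None, None, :]).to(scores.dtype) * -10000.0
out_bias = torch.einsum("bhij,bhjd->bhid", (scores + bias).softmax(dim=-1), value)

print(torch.allclose(out_fill, out_bias, atol=1e-6))  # True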