-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kandinsky 3.0] Follow-up TODOs #5944
Changes from 18 commits
293c480
e5a1f32
eca5c18
b8bb288
84ce3d6
4b5fe93
123dafc
145ddad
5a2dd24
3318034
89fdee4
d9406cb
e408cdf
7dced94
51fe17b
f5cfa5a
334cd2e
dbf5135
25c4e07
dd198cb
d60bc4e
1d170cc
c4eae7e
226f755
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,7 +16,7 @@ | |||||||||
|
||||||||||
import torch | ||||||||||
import torch.nn.functional as F | ||||||||||
from torch import einsum, nn | ||||||||||
from torch import nn | ||||||||||
|
||||||||||
from ..utils import USE_PEFT_BACKEND, deprecate, logging | ||||||||||
from ..utils.import_utils import is_xformers_available | ||||||||||
|
@@ -109,15 +109,19 @@ def __init__( | |||||||||
residual_connection: bool = False, | ||||||||||
_from_deprecated_attn_block: bool = False, | ||||||||||
processor: Optional["AttnProcessor"] = None, | ||||||||||
scale_mask_factor: float = 1.0, | ||||||||||
out_dim: int = None, | ||||||||||
): | ||||||||||
super().__init__() | ||||||||||
self.inner_dim = dim_head * heads | ||||||||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads | ||||||||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim | ||||||||||
self.upcast_attention = upcast_attention | ||||||||||
self.upcast_softmax = upcast_softmax | ||||||||||
self.rescale_output_factor = rescale_output_factor | ||||||||||
self.residual_connection = residual_connection | ||||||||||
self.dropout = dropout | ||||||||||
self.out_dim = out_dim if out_dim is not None else query_dim | ||||||||||
self.scale_mask_factor = scale_mask_factor | ||||||||||
|
||||||||||
# we make use of this private variable to know whether this class is loaded | ||||||||||
# with an deprecated state dict so that we can convert it on the fly | ||||||||||
|
@@ -126,7 +130,7 @@ def __init__( | |||||||||
self.scale_qk = scale_qk | ||||||||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0 | ||||||||||
|
||||||||||
self.heads = heads | ||||||||||
self.heads = out_dim // dim_head if out_dim is not None else heads | ||||||||||
# for slice_size > 0 the attention score computation | ||||||||||
# is split across the batch axis to save memory | ||||||||||
# You can set slice_size with `set_attention_slice` | ||||||||||
|
@@ -193,7 +197,7 @@ def __init__( | |||||||||
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) | ||||||||||
|
||||||||||
self.to_out = nn.ModuleList([]) | ||||||||||
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) | ||||||||||
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) | ||||||||||
self.to_out.append(nn.Dropout(dropout)) | ||||||||||
|
||||||||||
# set attention processor | ||||||||||
|
@@ -593,7 +597,7 @@ def get_attention_scores( | |||||||||
beta = 0 | ||||||||||
else: | ||||||||||
baddbmm_input = attention_mask | ||||||||||
beta = 1 | ||||||||||
beta = self.scale_mask_factor | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Added a new config for this (`scale_mask_factor`, used as `beta` in the `baddbmm` call here). Setting this to be a large negative number helps a lot with numerical stability. In Kandinsky they "fill" the empty tokens.
I set this config to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually, I think I should do this instead!
This
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Refactored, and now this script (one scenario where the attention_mask actually needs to be applied) gives not exactly the same, but similar, outputs on main and on this branch:
from diffusers import AutoPipelineForText2Image
import torch
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = ["A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background.",
"A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background. A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."]
generator = [torch.Generator(device="cpu").manual_seed(0),torch.Generator(device="cpu").manual_seed(1)]
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
|
||||||||||
|
||||||||||
attention_scores = torch.baddbmm( | ||||||||||
baddbmm_input, | ||||||||||
|
@@ -2219,44 +2223,6 @@ def __call__( | |||||||||
return hidden_states | ||||||||||
|
||||||||||
|
||||||||||
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||||||||||
# this way torch.compile and co. will work as well | ||||||||||
class Kandi3AttnProcessor:
    r"""
    Default Kandinsky 3 processor for performing attention-related computations.

    Runs multi-head dot-product attention of `x` over `context`, using the projection layers (`to_q`, `to_k`,
    `to_v`, `to_out[0]`), the head count (`num_heads`) and the scaling factor (`scale`) of the `attn` module
    that is passed in at call time.
    """

    @staticmethod
    def _reshape(hid_states, h):
        """Split the feature axis of a `(b, n, f)` tensor into `h` heads and return a `(b, h, n, f // h)` tensor."""
        b, n, f = hid_states.shape
        d = f // h
        return hid_states.reshape(b, n, h, d).permute(0, 2, 1, 3)

    def __call__(
        self,
        attn,
        x,
        context,
        context_mask=None,
    ):
        """
        Apply attention of `x` over `context`.

        Args:
            attn: module exposing `to_q`, `to_k`, `to_v`, `to_out`, `num_heads` and `scale`.
            x: rank-3 tensor `(batch, query_len, features)` that is projected to the queries.
            context: rank-3 tensor `(batch, context_len, features)` that is projected to keys and values.
            context_mask: optional mask; positions equal to 0 are excluded from the attention.
                Presumably shaped `(batch, context_len)` so that it broadcasts over heads and queries -
                TODO confirm against the caller.

        Returns:
            The attended states after `attn.to_out[0]`, with heads merged back into the last axis.
        """
        query = self._reshape(attn.to_q(x), h=attn.num_heads)
        key = self._reshape(attn.to_k(context), h=attn.num_heads)
        value = self._reshape(attn.to_v(context), h=attn.num_heads)

        # `torch.einsum` is used through the module so that this class does not depend on a separate
        # `from torch import einsum` line.
        attention_matrix = torch.einsum("b h i d, b h j d -> b h i j", query, key)

        if context_mask is not None:
            # Masked positions get the most negative finite value of the dtype so that they receive
            # (numerically) zero probability after the softmax.
            max_neg_value = -torch.finfo(attention_matrix.dtype).max
            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
            attention_matrix = attention_matrix.masked_fill(context_mask == 0, max_neg_value)
        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)

        out = torch.einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
        # (b, h, i, d) -> (b, i, h * d)
        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
        out = attn.to_out[0](out)
        return out
|
||||||||||
|
||||||||||
LORA_ATTENTION_PROCESSORS = ( | ||||||||||
LoRAAttnProcessor, | ||||||||||
LoRAAttnProcessor2_0, | ||||||||||
|
@@ -2282,7 +2248,6 @@ def __call__( | |||||||||
LoRAXFormersAttnProcessor, | ||||||||||
IPAdapterAttnProcessor, | ||||||||||
IPAdapterAttnProcessor2_0, | ||||||||||
Kandi3AttnProcessor, | ||||||||||
) | ||||||||||
|
||||||||||
AttentionProcessor = Union[ | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is `out_dim` different from `query_dim` here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrickvonplaten
The only difference is the `to_out` layer here - Kandinsky attention output does not change the dimension from `inner_dim`, while our attention class will project the output to `query_dim`. I added an `out_dim` for this purpose, but we can add a different config if it makes more sense!
diffusers/src/diffusers/models/unet_kandi3.py
Line 453 in d1b2a1a
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That works! Makes sense