Fix Mochi Quality Issues #10033

Merged
merged 55 commits from mochi-quality into main on Dec 17, 2024

Commits (55)
27f81bd  update  (DN6, Nov 18, 2024)
30dd9f6  update  (DN6, Nov 18, 2024)
10275fe  update  (DN6, Nov 20, 2024)
79380ca  update  (DN6, Nov 20, 2024)
21b0997  update  (DN6, Nov 22, 2024)
fcc59d0  update  (DN6, Nov 23, 2024)
1782d02  update  (DN6, Nov 25, 2024)
66a5f59  update  (DN6, Nov 25, 2024)
3ffa711  update  (DN6, Nov 25, 2024)
dded243  update  (DN6, Nov 25, 2024)
d99234f  update  (DN6, Nov 25, 2024)
8b9d5b6  update  (DN6, Nov 25, 2024)
2cfca5e  update  (DN6, Nov 26, 2024)
900fead  update  (DN6, Nov 26, 2024)
0b09231  update  (DN6, Nov 26, 2024)
883f5c8  update  (DN6, Nov 26, 2024)
59c9f5d  update  (DN6, Nov 26, 2024)
f3fefae  update  (DN6, Nov 26, 2024)
8a5d03b  update  (DN6, Nov 26, 2024)
b7464e5  update  (DN6, Nov 27, 2024)
fb4e175  update  (DN6, Nov 27, 2024)
61001c8  update  (DN6, Nov 27, 2024)
0fdef41  update  (DN6, Nov 27, 2024)
e6fe9f1  update  (DN6, Nov 27, 2024)
c17cef7  update  (DN6, Nov 27, 2024)
0e8f20d  update  (DN6, Nov 27, 2024)
6e2011a  update  (DN6, Nov 27, 2024)
9c5eb36  update  (DN6, Nov 27, 2024)
d759516  update  (DN6, Nov 27, 2024)
7854bde  update  (DN6, Nov 27, 2024)
2881f2f  update  (DN6, Nov 27, 2024)
7854061  update  (DN6, Nov 27, 2024)
b904325  Merge branch 'main' into mochi-quality  (sayakpaul, Nov 28, 2024)
ba9c185  update  (DN6, Nov 29, 2024)
53dbc37  update  (DN6, Nov 29, 2024)
77f9d19  update  (DN6, Nov 29, 2024)
a298915  Merge branch 'mochi-quality' of https://github.com/huggingface/diffus…  (DN6, Nov 29, 2024)
dc96890  update  (DN6, Nov 29, 2024)
ae57913  update  (DN6, Nov 29, 2024)
7626a34  update  (DN6, Nov 30, 2024)
c39886a  update  (DN6, Nov 30, 2024)
bbc5892  update  (DN6, Dec 7, 2024)
3c70b54  update  (DN6, Dec 7, 2024)
11ce6b8  update  (DN6, Dec 7, 2024)
cc7b91d  update  (DN6, Dec 7, 2024)
09fe7ec  Merge branch 'main' into mochi-quality  (DN6, Dec 8, 2024)
ccabe5e  Merge branch 'main' into mochi-quality  (a-r-r-o-w, Dec 14, 2024)
4c800e3  Merge branch 'main' into mochi-quality  (DN6, Dec 16, 2024)
2a6b82d  update  (DN6, Dec 16, 2024)
1421691  Merge branch 'mochi-quality' of https://github.com/huggingface/diffus…  (DN6, Dec 16, 2024)
cbbc54b  update  (DN6, Dec 16, 2024)
952f6e9  Merge branch 'main' into mochi-quality  (DN6, Dec 16, 2024)
b75db11  update  (DN6, Dec 16, 2024)
50c5607  Update src/diffusers/models/transformers/transformer_mochi.py  (DN6, Dec 17, 2024)
d80f477  Merge branch 'main' into mochi-quality  (DN6, Dec 17, 2024)
261 changes: 172 additions & 89 deletions src/diffusers/models/attention_processor.py
@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.processor(self, hidden_states)


class MochiAttention(nn.Module):
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
processor: "MochiAttnProcessor2_0",
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_proj_bias: bool = True,
out_dim: Optional[int] = None,
out_context_dim: Optional[int] = None,
out_bias: bool = True,
context_pre_only: bool = False,
eps: float = 1e-5,
):
super().__init__()
from .normalization import MochiRMSNorm

self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim else query_dim
self.context_pre_only = context_pre_only

self.heads = out_dim // dim_head if out_dim is not None else heads

self.norm_q = MochiRMSNorm(dim_head, eps, True)
self.norm_k = MochiRMSNorm(dim_head, eps, True)
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)

self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

if not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)

self.processor = processor

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**kwargs,
)


class MochiAttnProcessor2_0:
Collaborator comment:

OK! But cc @a-r-r-o-w here. He has been following the mochi-fix PR and added the attention processors to the model files.
I guess we keep them here for now until we refactor and move them all together?

@a-r-r-o-w (Member) commented on Dec 16, 2024:

I think we should do the following going forward as a design choice (just a personal opinion, so let's try to formulate a plan for consistency):

  • Apart from the transformer model code, the transformer files will also contain all the relevant attention processor implementations. This helps with readability because you don't have to jump between files, and because attention_processor.py is now > 5k lines.
  • If an alternate Attention class is required, let's keep it in the transformer file as well. These custom classes require some common methods that will probably not change between implementations. For this, let's create an AttentionMixin class - for changing/getting attention processors, fusing, etc.
  • If an attention processor is required in both the transformer and the VAE (and possibly a different file) because some parts are shared, let's keep the implementation in the transformer file and import it in the VAE. If there's no common attention processor, let's keep the implementation in the transformer or the VAE, respectively.
  • If some specific layers are shared between the transformer and the VAE (for example, the GLUMB convolution in Sana), let's keep the implementation in the transformer file too and import it where required.
  • Let's create dedicated RoPE classes for each implementation. Any concerns about speed from recreating the embeddings at every inference step can be addressed by caching. Something as simple as functools.cache works here if we make the RoPE calculation in the forward pass depend on just num_frames, height, and width (see the sketch after this comment). A save hook would work as well. This we can look into a bit later.

WDYT?
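
A minimal sketch of the caching idea in the last bullet, assuming a module-level helper; the function name, frequency formula, and cache size are illustrative and are not part of this PR:

import functools

import torch


@functools.lru_cache(maxsize=8)
def _cached_rope_freqs(num_frames: int, height: int, width: int, dim: int, theta: float = 10000.0):
    # Hypothetical helper: recomputed only when the (num_frames, height, width) grid or head dim
    # changes, so repeated denoising steps at a fixed resolution reuse the cached frequencies.
    positions = torch.arange(num_frames * height * width, dtype=torch.float32)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(positions, inv_freq)  # (seq_len, dim // 2)
    return angles.cos(), angles.sin()


# Illustrative call from a transformer forward pass (the shape values are made up):
freqs_cos, freqs_sin = _cached_rope_freqs(num_frames=19, height=30, width=45, dim=128)

Because functools caches on the argument tuple, every sampling step at the same resolution returns the precomputed tensors; the save-hook approach mentioned above would be the alternative if per-shape caching is not enough.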

"""Attention processor used in Mochi."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: "MochiAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)

encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)

if image_rotary_emb is not None:

def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()

cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

return torch.stack([cos, sin], dim=-1).flatten(-2)

query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)

query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)

sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
total_length = sequence_length + encoder_sequence_length

batch_size, heads, _, dim = query.shape
attn_outputs = []
for idx in range(batch_size):
mask = attention_mask[idx][None, :]
valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]

valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)

attn_output = F.scaled_dot_product_attention(
valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
)
valid_sequence_length = attn_output.size(2)
attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
attn_outputs.append(attn_output)

hidden_states = torch.cat(attn_outputs, dim=0)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)

hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states


class AttnProcessor:
r"""
Default processor for performing attention-related computations.
@@ -3868,94 +4039,6 @@ def __call__(
return hidden_states


class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)

encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)

if image_rotary_emb is not None:

def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()

cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

return torch.stack([cos, sin], dim=-1).flatten(-2)

query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)

query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)

sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)

query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states


class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -5668,13 +5751,13 @@ def __call__(
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
MochiAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
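
For orientation, here is a hedged, standalone sketch of how the MochiAttention module and MochiAttnProcessor2_0 added above could be exercised. The dimensions are made-up illustrative values (not necessarily Mochi's real config), and the import path reflects where this PR places the classes:

import torch

from diffusers.models.attention_processor import MochiAttention, MochiAttnProcessor2_0

attn = MochiAttention(
    query_dim=3072,            # visual token width (illustrative)
    added_kv_proj_dim=1536,    # text token width (illustrative)
    heads=24,
    dim_head=128,
    out_dim=3072,
    out_context_dim=1536,
    context_pre_only=False,
    processor=MochiAttnProcessor2_0(),
)

hidden_states = torch.randn(1, 1024, 3072)              # latent/video tokens
encoder_hidden_states = torch.randn(1, 256, 1536)       # prompt tokens
attention_mask = torch.ones(1, 256, dtype=torch.bool)   # marks which prompt tokens are valid

hidden_states, encoder_hidden_states = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
    image_rotary_emb=None,  # rotary embeddings skipped in this sketch
)

Note how the processor gathers only the valid prompt tokens per sample before calling scaled_dot_product_attention instead of attending over padded tokens; that per-sample masking is the behavioral change relative to the processor removed earlier in this diff.
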
1 change: 0 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -542,7 +542,6 @@ def forward(self, latent):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
57 changes: 30 additions & 27 deletions src/diffusers/models/normalization.py
@@ -234,33 +234,6 @@ def forward(
return x, gate_msa, scale_mlp, gate_mlp


class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""

def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])

return hidden_states, gate_msa, scale_mlp, gate_mlp


class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -549,6 +522,36 @@ def forward(self, hidden_states):
return hidden_states


# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
class MochiRMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()

self.eps = eps

if isinstance(dim, numbers.Integral):
dim = (dim,)

self.dim = torch.Size(dim)

if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

if self.weight is not None:
hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)

return hidden_states


class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):