Fix Mochi Quality Issues #10033

Merged: 55 commits merged on Dec 17, 2024
Changes from 33 commits

Commits (55)
27f81bd
update
DN6 Nov 18, 2024
30dd9f6
update
DN6 Nov 18, 2024
10275fe
update
DN6 Nov 20, 2024
79380ca
update
DN6 Nov 20, 2024
21b0997
update
DN6 Nov 22, 2024
fcc59d0
update
DN6 Nov 23, 2024
1782d02
update
DN6 Nov 25, 2024
66a5f59
update
DN6 Nov 25, 2024
3ffa711
update
DN6 Nov 25, 2024
dded243
update
DN6 Nov 25, 2024
d99234f
update
DN6 Nov 25, 2024
8b9d5b6
update
DN6 Nov 25, 2024
2cfca5e
update
DN6 Nov 26, 2024
900fead
update
DN6 Nov 26, 2024
0b09231
update
DN6 Nov 26, 2024
883f5c8
update
DN6 Nov 26, 2024
59c9f5d
update
DN6 Nov 26, 2024
f3fefae
update
DN6 Nov 26, 2024
8a5d03b
update
DN6 Nov 26, 2024
b7464e5
update
DN6 Nov 27, 2024
fb4e175
update
DN6 Nov 27, 2024
61001c8
update
DN6 Nov 27, 2024
0fdef41
update
DN6 Nov 27, 2024
e6fe9f1
update
DN6 Nov 27, 2024
c17cef7
update
DN6 Nov 27, 2024
0e8f20d
update
DN6 Nov 27, 2024
6e2011a
update
DN6 Nov 27, 2024
9c5eb36
update
DN6 Nov 27, 2024
d759516
update
DN6 Nov 27, 2024
7854bde
update
DN6 Nov 27, 2024
2881f2f
update
DN6 Nov 27, 2024
7854061
update
DN6 Nov 27, 2024
b904325
Merge branch 'main' into mochi-quality
sayakpaul Nov 28, 2024
ba9c185
update
DN6 Nov 29, 2024
53dbc37
update
DN6 Nov 29, 2024
77f9d19
update
DN6 Nov 29, 2024
a298915
Merge branch 'mochi-quality' of https://github.com/huggingface/diffus…
DN6 Nov 29, 2024
dc96890
update
DN6 Nov 29, 2024
ae57913
update
DN6 Nov 29, 2024
7626a34
update
DN6 Nov 30, 2024
c39886a
update
DN6 Nov 30, 2024
bbc5892
update
DN6 Dec 7, 2024
3c70b54
update
DN6 Dec 7, 2024
11ce6b8
update
DN6 Dec 7, 2024
cc7b91d
update
DN6 Dec 7, 2024
09fe7ec
Merge branch 'main' into mochi-quality
DN6 Dec 8, 2024
ccabe5e
Merge branch 'main' into mochi-quality
a-r-r-o-w Dec 14, 2024
4c800e3
Merge branch 'main' into mochi-quality
DN6 Dec 16, 2024
2a6b82d
update
DN6 Dec 16, 2024
1421691
Merge branch 'mochi-quality' of https://github.com/huggingface/diffus…
DN6 Dec 16, 2024
cbbc54b
update
DN6 Dec 16, 2024
952f6e9
Merge branch 'main' into mochi-quality
DN6 Dec 16, 2024
b75db11
update
DN6 Dec 16, 2024
50c5607
Update src/diffusers/models/transformers/transformer_mochi.py
DN6 Dec 17, 2024
d80f477
Merge branch 'main' into mochi-quality
DN6 Dec 17, 2024
88 changes: 0 additions & 88 deletions src/diffusers/models/attention_processor.py
@@ -3510,94 +3510,6 @@ def __call__(
return hidden_states


class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)

encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)

if image_rotary_emb is not None:

def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()

cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

return torch.stack([cos, sin], dim=-1).flatten(-2)

query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)

query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)

sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)

query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)

hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states


class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
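For reference, the deleted MochiAttnProcessor2_0 applied real-valued rotary embeddings to interleaved even/odd channels before concatenating visual and text tokens for joint attention. Below is a minimal, self-contained sketch of that rotary step, lifted from the removed code; the tensor shapes in the usage are hypothetical and only for illustration.

import torch

def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # x: [batch, seq, heads, head_dim]; freqs_cos/freqs_sin must broadcast
    # against the even/odd halves of x, i.e. trailing dim of head_dim // 2.
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    # re-interleave the rotated pairs back into the original channel order
    return torch.stack([cos, sin], dim=-1).flatten(-2)

# hypothetical shapes, for illustration only
query = torch.randn(1, 256, 24, 128)   # [batch, seq, heads, head_dim]
freqs_cos = torch.randn(256, 1, 64)    # [seq, 1, head_dim // 2], broadcast over heads
freqs_sin = torch.randn(256, 1, 64)
rotated = apply_rotary_emb(query, freqs_cos, freqs_sin)
assert rotated.shape == query.shape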
1 change: 0 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -262,7 +262,6 @@ def forward(self, latent):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
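As context for this hunk, PatchEmbed.forward projects the latent and then flattens the spatial grid into a token sequence. A tiny sketch of that BCHW -> BNC reshape follows; the shapes are illustrative, not taken from the PR.

import torch

# hypothetical latent after self.proj: batch 1, 4 channels, 32x32 patch grid
latent = torch.randn(1, 4, 32, 32)
tokens = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
assert tokens.shape == (1, 32 * 32, 4)      # one token per spatial location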
30 changes: 1 addition & 29 deletions src/diffusers/models/normalization.py
@@ -234,33 +234,6 @@ def forward(
return x, gate_msa, scale_mlp, gate_mlp


class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""

def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()

self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])

return hidden_states, gate_msa, scale_mlp, gate_mlp


class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -537,8 +510,7 @@ def forward(self, hidden_states):
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
else:
hidden_states = hidden_states.to(input_dtype)
hidden_states = hidden_states.to(input_dtype)

return hidden_states

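The last hunk makes the cast back to the caller's dtype unconditional in RMSNorm.forward, so the output dtype matches the input even when a half-precision weight is applied. A minimal sketch of the post-fix dtype flow, using a simplified stand-in module rather than the diffusers class:

import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Simplified RMSNorm mirroring the post-fix dtype flow: normalize in fp32,
    apply the (optional) affine weight, then always cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        # post-fix behavior: the cast back to the caller's dtype is no longer inside an `else`
        hidden_states = hidden_states.to(input_dtype)
        return hidden_states

x = torch.randn(2, 8, dtype=torch.bfloat16)
assert SimpleRMSNorm(8)(x).dtype == torch.bfloat16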