Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Mochi Quality Issues #10033

Merged
merged 55 commits into from
Dec 17, 2024
Merged

Fix Mochi Quality Issues #10033

merged 55 commits into from
Dec 17, 2024

Conversation

DN6
Copy link
Collaborator

@DN6 DN6 commented Nov 27, 2024

What does this PR do?

We're seeing some quality issues with Mochi due to missing upcasts and differences between how attention is handled in the original repo.

This PR:

  1. Matches the transformer implementation 1:1 so that norms are upcast and run in the same precision as the original repo. 2. Changes the MochiAttnProcessor to match the original approach of dropping padding tokens.
  2. Runs the CFG and Sampling step in FP32

I'll update the docs PR: #9934 with a guide on how to reproduce the original repo results exactly once this PR is merged.

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6 DN6 added roadmap Add to current release roadmap and removed roadmap Add to current release roadmap labels Dec 3, 2024
@a-r-r-o-w a-r-r-o-w mentioned this pull request Dec 14, 2024
@@ -198,7 +198,6 @@ def __init__(
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
)

# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mochi zeros out negative prompt embeds if they are an empty string.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for the fix!!!!
so much work went into this PR:)

)


class MochiAttnProcessor2_0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK! But cc @a-r-r-o-w here. He has been following the mochi-fix PR and added the attention processors to model files
I guess we keep them here for now until we refactor and move them all together?

Copy link
Member

@a-r-r-o-w a-r-r-o-w Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do the following going forward as design choice (just personal opinion so let's try to forumalate a plan for consistency):

  • Apart from transformer model code, the transformer files will also contain all the relevant attention processor implementation. This helps with readability because you don't have to jump between files and because attention_processor.py is now > 5k lines
  • If an alternate Attention class is required, let's keep it in the transformer file as well. These custom classes require some common methods that will probably not change between implementantions. For this, let's create a AttentionMixin class - for changing/getting attention processors, fusing, etc.
  • If an attention processor is required in both transformer and VAE (and possibly a different file) because of some common parts shared, let's keep the implementation in transformer file, and import it in the vae. If there's no common attention processor, let's keep the implementation respectively in transformer or vae.
  • If some specific layers are shared between transformer and vae (for example, GLUMB convolution in Sana), let's keep the implementation in transformer file too and import where required.
  • Let's create dedicated RoPE classes for each implementation. Any concerns about speed due to recreating the embeddings every inference step can be addressed by caching. Something as simple as functools.cache works here if we make the rope calculation forward dependant on just the num_frames, height, width. But a save hook would work as well. This - we can look into a bit later

WDYT?

src/diffusers/models/transformers/transformer_mochi.py Outdated Show resolved Hide resolved
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will automtically upcast, no? here in hidden_states = hidden_states * scale if the hidden_states is in float32 an scale is in bfloat16, the operation will be in float32 and we do not need to explicitly upcast scale, no?

this is non-merge blocking! just curious if I misunderstood something

    def forward(self, hidden_states, scale=None):
        ....
        if scale is not None:
            hidden_states = hidden_states * scale

        hidden_states = hidden_states.to(hidden_states_dtype)

        return hidden_states

@DN6
Copy link
Collaborator Author

DN6 commented Dec 17, 2024

Failing LoRA tests are unrelated. Seem to be timing out on SDXL tests. Merging.

@DN6 DN6 merged commit 128b96f into main Dec 17, 2024
14 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
close-to-merge roadmap Add to current release roadmap
Projects
Development

Successfully merging this pull request may close these issues.

6 participants