From 76b2ea4814b22441dee11a86fa09e093c3f7b144 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Dec 2024 17:57:42 +0100 Subject: [PATCH] update --- src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 109 ++++++++++++++++++ .../models/transformers/transformer_mochi.py | 3 +- 3 files changed, 116 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 9641435fa5a6..d102282025c7 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -32,6 +32,7 @@ convert_ldm_vae_checkpoint, convert_ltx_transformer_checkpoint_to_diffusers, convert_ltx_vae_checkpoint_to_diffusers, + convert_mochi_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -96,6 +97,10 @@ "default_subfolder": "vae", }, "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, + "MochiTransformer3DModel": { + "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 4e288737fe88..d0a90726717b 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -106,6 +106,7 @@ ], "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", + "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"], } DIFFUSERS_DEFAULT_PIPELINE_PATHS = { @@ -157,6 +158,7 @@ "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, "autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"}, + "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"}, } # Use to configure model sample size when original config is provided @@ -610,6 +612,9 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "autoencoder-dc-f128c512" + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]): + model_type = "mochi-1-preview" + else: model_type = "v1" @@ -1750,6 +1755,12 @@ def swap_scale_shift(weight, dim): return new_weight +def swap_proj_gate(weight): + proj, gate = weight.chunk(2, dim=0) + new_weight = torch.cat([gate, proj], dim=0) + return new_weight + + def get_attn2_layers(state_dict): attn2_layers = [] for key in state_dict.keys(): @@ -2406,3 +2417,101 @@ def remap_proj_conv_(key: str, state_dict): handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + new_state_dict = {} + + # Comfy checkpoints add this prefix + keys = list(checkpoint.keys()) + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + # Convert patch_embed + new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight") + new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias") + + # Convert time_embed + new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight") + new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias") + new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight") + new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias") + new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight") + new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias") + new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight") + new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias") + new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight") + new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias") + new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight") + new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias") + + # Convert transformer blocks + num_layers = 48 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"blocks.{i}." + + # norm1 + new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias") + if i < num_layers - 1: + new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight") + new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + else: + new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop( + old_prefix + "mod_y.weight" + ) + new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias") + + # Visual attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight") + new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight") + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias") + + # Context attention + qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight") + q, k, v = qkv_weight.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q + new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k + new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v + new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop( + old_prefix + "attn.q_norm_y.weight" + ) + new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop( + old_prefix + "attn.k_norm_y.weight" + ) + if i < num_layers - 1: + new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop( + old_prefix + "attn.proj_y.weight" + ) + new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias") + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_x.w1.weight") + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight") + if i < num_layers - 1: + new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate( + checkpoint.pop(old_prefix + "mlp_y.w1.weight") + ) + new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight") + + # Output layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0) + new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies") + + return new_state_dict diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index fe72dc56883e..41e5289f2d57 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward @@ -304,7 +305,7 @@ def forward( @maybe_allow_in_graph -class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).