* copy transformer
* copy vae
* copy pipeline
* make fix-copies
* refactor; make original code work with diffusers; test latents for comparison generated with this commit
* move rope into pipeline; remove flash attention; refactor
* begin conversion script
* make style
* refactor attention
* refactor
* refactor final layer
* their mlp -> our feedforward
* make style
* add docs
* refactor layer names
* refactor modulation
* cleanup
* refactor norms
* refactor activations
* refactor single blocks attention
* refactor attention processor
* make style
* cleanup a bit
* refactor double transformer block attention
* update mochi attn proc
* use diffusers attention implementation in all modules; checkpoint for all values matching original
* remove helper functions in vae
* refactor upsample
* refactor causal conv
* refactor resnet
* refactor
* refactor
* refactor
* grad checkpointing
* autoencoder test
* fix scaling factor
* refactor clip
* refactor llama text encoding
* add coauthor

  Co-Authored-By: "Gregory D. Hunkins" <[email protected]>

* refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device

  Note: The following line diverges from original behaviour. We create the grid on the device, whereas the original implementation creates it on CPU and then moves it to device. This results in numerical differences in layerwise debugging outputs, but visually it is the same.

* use diffusers timesteps embedding; diff: 0.10205078125
* rename
* convert
* update
* add tests for transformer
* add pipeline tests; text encoder 2 is not optional
* fix attention implementation for torch
* add example
* update docs
* update docs
* apply suggestions from review
* refactor vae
* update
* Apply suggestions from code review

  Co-authored-by: hlky <[email protected]>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

  Co-authored-by: hlky <[email protected]>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

  Co-authored-by: hlky <[email protected]>

* make fix-copies
* update

---------

Co-authored-by: "Gregory D. Hunkins" <[email protected]>
Co-authored-by: hlky <[email protected]>
1 parent 8957324 · commit aace1f4 · 24 changed files with 3,676 additions and 3 deletions.
@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanVideo

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
```
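
The loaded VAE decodes transformer latents back into video frames. A minimal sketch, continuing from the snippet above (the latent shape is illustrative, not taken from the model configuration):

```python
vae.to("cuda")

# Hypothetical latent tensor of shape (batch, latent_channels, frames, height, width).
latents = torch.randn(1, 16, 3, 30, 40, dtype=torch.float16, device="cuda")

# The pipeline divides latents by the VAE scaling factor before decoding.
with torch.no_grad():
    video = vae.decode(latents / vae.config.scaling_factor).sample
```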

## AutoencoderKLHunyuanVideo

[[autodoc]] AutoencoderKLHunyuanVideo
- decode
- all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideoTransformer3DModel

A Diffusion Transformer model for 3D video-like data, introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
```
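
A separately loaded (or freshly converted) transformer can also be passed into the pipeline at load time instead of being read from the checkpoint again. Continuing from the snippet above (the checkpoint id mirrors it and is only illustrative):

```python
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)
```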

## HunyuanVideoTransformer3DModel

[[autodoc]] HunyuanVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# HunyuanVideo

[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.

*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

Recommendations for inference (a full example following these settings is sketched below):
- Both text encoders should be in `torch.float16`.
- The transformer should be in `torch.bfloat16`.
- The VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
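
A minimal end-to-end sketch that follows these recommendations. The checkpoint id, subfolder layout, and prompt are assumptions for illustration, not part of this commit:

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

# Assumed checkpoint id with a diffusers-style subfolder layout.
model_id = "tencent/HunyuanVideo"

# Transformer in bfloat16; the remaining components (text encoders, VAE) in float16.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic style.",
    height=320,
    width=512,
    num_frames=61,  # 4 * 15 + 1
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```

For smaller resolutions, the scheduler `shift` can be lowered as noted above, for example by constructing `FlowMatchEulerDiscreteScheduler(shift=3.0)` and assigning it to `pipe.scheduler` before calling the pipeline.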

## HunyuanVideoPipeline

[[autodoc]] HunyuanVideoPipeline
- all
- __call__

## HunyuanVideoPipelineOutput

[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
@@ -0,0 +1,257 @@
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    HunyuanVideoPipeline,
    HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
    # The original final layer stores (shift, scale); diffusers' norm_out.linear expects (scale, shift).
    weight = state_dict.pop(key)
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


def remap_txt_in_(key, state_dict):
    # Rename the token-refiner ("txt_in") submodule keys and split its fused QKV projection.
    def rename_key(key):
        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
        new_key = new_key.replace("txt_in", "context_embedder")
        new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
        new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
        new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
        new_key = new_key.replace("mlp", "ff")
        return new_key

    if "self_attn_qkv" in key:
        weight = state_dict.pop(key)
        to_q, to_k, to_v = weight.chunk(3, dim=0)
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
    else:
        state_dict[rename_key(key)] = state_dict.pop(key)


def remap_img_attn_qkv_(key, state_dict):
    # Split the fused image-stream QKV projection into separate q/k/v projections.
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
    state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
    state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


def remap_txt_attn_qkv_(key, state_dict):
    # Split the fused text-stream QKV projection into the added q/k/v projections.
    weight = state_dict.pop(key)
    to_q, to_k, to_v = weight.chunk(3, dim=0)
    state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
    state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
    state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


def remap_single_transformer_blocks_(key, state_dict):
    hidden_size = 3072

    if "linear1.weight" in key:
        # linear1 fuses q, k, v and the MLP input projection; split it into the diffusers layout.
        linear1_weight = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
        q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
        state_dict[f"{new_key}.attn.to_q.weight"] = q
        state_dict[f"{new_key}.attn.to_k.weight"] = k
        state_dict[f"{new_key}.attn.to_v.weight"] = v
        state_dict[f"{new_key}.proj_mlp.weight"] = mlp

    elif "linear1.bias" in key:
        linear1_bias = state_dict.pop(key)
        split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
        new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
        state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
        state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
        state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
        state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

    else:
        new_key = key.replace("single_blocks", "single_transformer_blocks")
        new_key = new_key.replace("linear2", "proj_out")
        new_key = new_key.replace("q_norm", "attn.norm_q")
        new_key = new_key.replace("k_norm", "attn.norm_k")
        state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "img_in": "x_embedder",
    "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
    "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
    "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
    "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
    "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
    "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
    "double_blocks": "transformer_blocks",
    "img_attn_q_norm": "attn.norm_q",
    "img_attn_k_norm": "attn.norm_k",
    "img_attn_proj": "attn.to_out.0",
    "txt_attn_q_norm": "attn.norm_added_q",
    "txt_attn_k_norm": "attn.norm_added_k",
    "txt_attn_proj": "attn.to_add_out",
    "img_mod.linear": "norm1.linear",
    "img_norm1": "norm1.norm",
    "img_norm2": "norm2",
    "img_mlp": "ff",
    "txt_mod.linear": "norm1_context.linear",
    "txt_norm1": "norm1.norm",
    "txt_norm2": "norm2_context",
    "txt_mlp": "ff_context",
    "self_attn_proj": "attn.to_out.0",
    "modulation.linear": "norm.linear",
    "pre_norm": "norm.norm",
    "final_layer.norm_final": "norm_out.norm",
    "final_layer.linear": "proj_out",
    "fc1": "net.0.proj",
    "fc2": "net.2",
    "input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "txt_in": remap_txt_in_,
    "img_attn_qkv": remap_img_attn_qkv_,
    "txt_attn_qkv": remap_txt_attn_qkv_,
    "single_blocks": remap_single_transformer_blocks_,
    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Rename a key in-place.
    state_dict[new_key] = state_dict.pop(old_key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap common checkpoint containers ("model", "module", "state_dict") if present.
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def convert_transformer(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        transformer = HunyuanVideoTransformer3DModel()

    # First pass: simple substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: keys that need splitting or reordering go through dedicated remap handlers.
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def convert_vae(ckpt_path: str):
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        vae = AutoencoderKLHunyuanVideo()

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return parser.parse_args()
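
# Example invocation (the script name and all paths below are illustrative placeholders,
# not taken from this commit):
#   python convert_hunyuan_video_to_diffusers.py \
#       --transformer_ckpt_path /path/to/original/transformer.pt \
#       --vae_ckpt_path /path/to/original/vae.pt \
#       --text_encoder_path /path/to/llama/text_encoder \
#       --tokenizer_path /path/to/llama/tokenizer \
#       --text_encoder_2_path /path/to/clip/text_encoder \
#       --save_pipeline \
#       --output_path ./hunyuan-video-diffusers \
#       --dtype bf16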

DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
        assert args.tokenizer_path is not None
        assert args.text_encoder_2_path is not None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
        text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
        tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        pipe = HunyuanVideoPipeline(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
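
# After conversion with --save_pipeline, the result can be loaded back from --output_path
# like any diffusers pipeline (the path below is an illustrative placeholder):
#
#   import torch
#   from diffusers import HunyuanVideoPipeline
#
#   pipe = HunyuanVideoPipeline.from_pretrained("./hunyuan-video-diffusers", torch_dtype=torch.float16)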