Skip to content

Commit

Permalink
[core] Hunyuan Video (#10136)
Browse files Browse the repository at this point in the history
* copy transformer

* copy vae

* copy pipeline

* make fix-copies

* refactor; make original code work with diffusers; test latents for comparison generated with this commit

* move rope into pipeline; remove flash attention; refactor

* begin conversion script

* make style

* refactor attention

* refactor

* refactor final layer

* their mlp -> our feedforward

* make style

* add docs

* refactor layer names

* refactor modulation

* cleanup

* refactor norms

* refactor activations

* refactor single blocks attention

* refactor attention processor

* make style

* cleanup a bit

* refactor double transformer block attention

* update mochi attn proc

* use diffusers attention implementation in all modules; checkpoint for all values matching original

* remove helper functions in vae

* refactor upsample

* refactor causal conv

* refactor resnet

* refactor

* refactor

* refactor

* grad checkpointing

* autoencoder test

* fix scaling factor

* refactor clip

* refactor llama text encoding

* add coauthor

Co-Authored-By: "Gregory D. Hunkins" <[email protected]>

* refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device

Note: The following line diverges from original behaviour. We create the grid on the device, whereas
original implementation creates it on CPU and then moves it to device. This results in numerical
differences in layerwise debugging outputs, but visually it is the same.

* use diffusers timesteps embedding; diff: 0.10205078125

* rename

* convert

* update

* add tests for transformer

* add pipeline tests; text encoder 2 is not optional

* fix attention implementation for torch

* add example

* update docs

* update docs

* apply suggestions from review

* refactor vae

* update

* Apply suggestions from code review

Co-authored-by: hlky <[email protected]>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <[email protected]>

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Co-authored-by: hlky <[email protected]>

* make fix-copies

* update

---------

Co-authored-by: "Gregory D. Hunkins" <[email protected]>
Co-authored-by: hlky <[email protected]>
  • Loading branch information
3 people authored Dec 16, 2024
1 parent 8957324 commit aace1f4
Show file tree
Hide file tree
Showing 24 changed files with 3,676 additions and 3 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/hunyuan_video_transformer_3d
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
Expand Down Expand Up @@ -316,6 +318,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
Expand Down Expand Up @@ -394,6 +398,8 @@
title: Flux
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/pix2pix
Expand Down
32 changes: 32 additions & 0 deletions docs/source/en/api/models/autoencoder_kl_hunyuan_video.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLHunyuanVideo

The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
from diffusers import AutoencoderKLHunyuanVideo

vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
```

## AutoencoderKLHunyuanVideo

[[autodoc]] AutoencoderKLHunyuanVideo
- decode
- all

## DecoderOutput

[[autodoc]] models.autoencoders.vae.DecoderOutput
30 changes: 30 additions & 0 deletions docs/source/en/api/models/hunyuan_video_transformer_3d.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# HunyuanVideoTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.

The model can be loaded with the following code snippet.

```python
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
```

## HunyuanVideoTransformer3DModel

[[autodoc]] HunyuanVideoTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
43 changes: 43 additions & 0 deletions docs/source/en/api/pipelines/hunyuan_video.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# HunyuanVideo

[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.

*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

Recommendations for inference:
- Both text encoders should be in `torch.float16`.
- Transformer should be in `torch.bfloat16`.
- VAE should be in `torch.float16`.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).

## HunyuanVideoPipeline

[[autodoc]] HunyuanVideoPipeline
- all
- __call__

## HunyuanVideoPipelineOutput

[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
257 changes: 257 additions & 0 deletions scripts/convert_hunyuan_video_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer

from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)


def remap_norm_scale_shift_(key, state_dict):
weight = state_dict.pop(key)
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight


def remap_txt_in_(key, state_dict):
def rename_key(key):
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
new_key = new_key.replace("txt_in", "context_embedder")
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
new_key = new_key.replace("mlp", "ff")
return new_key

if "self_attn_qkv" in key:
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
else:
state_dict[rename_key(key)] = state_dict.pop(key)


def remap_img_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v


def remap_txt_attn_qkv_(key, state_dict):
weight = state_dict.pop(key)
to_q, to_k, to_v = weight.chunk(3, dim=0)
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v


def remap_single_transformer_blocks_(key, state_dict):
hidden_size = 3072

if "linear1.weight" in key:
linear1_weight = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
state_dict[f"{new_key}.attn.to_q.weight"] = q
state_dict[f"{new_key}.attn.to_k.weight"] = k
state_dict[f"{new_key}.attn.to_v.weight"] = v
state_dict[f"{new_key}.proj_mlp.weight"] = mlp

elif "linear1.bias" in key:
linear1_bias = state_dict.pop(key)
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias

else:
new_key = key.replace("single_blocks", "single_transformer_blocks")
new_key = new_key.replace("linear2", "proj_out")
new_key = new_key.replace("q_norm", "attn.norm_q")
new_key = new_key.replace("k_norm", "attn.norm_k")
state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
"img_in": "x_embedder",
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
"double_blocks": "transformer_blocks",
"img_attn_q_norm": "attn.norm_q",
"img_attn_k_norm": "attn.norm_k",
"img_attn_proj": "attn.to_out.0",
"txt_attn_q_norm": "attn.norm_added_q",
"txt_attn_k_norm": "attn.norm_added_k",
"txt_attn_proj": "attn.to_add_out",
"img_mod.linear": "norm1.linear",
"img_norm1": "norm1.norm",
"img_norm2": "norm2",
"img_mlp": "ff",
"txt_mod.linear": "norm1_context.linear",
"txt_norm1": "norm1.norm",
"txt_norm2": "norm2_context",
"txt_mlp": "ff_context",
"self_attn_proj": "attn.to_out.0",
"modulation.linear": "norm.linear",
"pre_norm": "norm.norm",
"final_layer.norm_final": "norm_out.norm",
"final_layer.linear": "proj_out",
"fc1": "net.0.proj",
"fc2": "net.2",
"input_embedder": "proj_in",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"txt_in": remap_txt_in_,
"img_attn_qkv": remap_img_attn_qkv_,
"txt_attn_qkv": remap_txt_attn_qkv_,
"single_blocks": remap_single_transformer_blocks_,
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
}

VAE_KEYS_RENAME_DICT = {}

VAE_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict


def convert_transformer(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

with init_empty_weights():
transformer = HunyuanVideoTransformer3DModel()

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer


def convert_vae(ckpt_path: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

with init_empty_weights():
vae = AutoencoderKLHunyuanVideo()

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
return parser.parse_args()


DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}


if __name__ == "__main__":
args = get_args()

transformer = None
dtype = DTYPE_MAPPING[args.dtype]

if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
assert args.text_encoder_2_path is not None

if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
Loading

0 comments on commit aace1f4

Please sign in to comment.