From aace1f412bc41f521b699a3228f4ec3339160c98 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 16 Dec 2024 13:56:18 +0530 Subject: [PATCH] [core] Hunyuan Video (#10136) * copy transformer * copy vae * copy pipeline * make fix-copies * refactor; make original code work with diffusers; test latents for comparison generated with this commit * move rope into pipeline; remove flash attention; refactor * begin conversion script * make style * refactor attention * refactor * refactor final layer * their mlp -> our feedforward * make style * add docs * refactor layer names * refactor modulation * cleanup * refactor norms * refactor activations * refactor single blocks attention * refactor attention processor * make style * cleanup a bit * refactor double transformer block attention * update mochi attn proc * use diffusers attention implementation in all modules; checkpoint for all values matching original * remove helper functions in vae * refactor upsample * refactor causal conv * refactor resnet * refactor * refactor * refactor * grad checkpointing * autoencoder test * fix scaling factor * refactor clip * refactor llama text encoding * add coauthor Co-Authored-By: "Gregory D. Hunkins" * refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device Note: The following line diverges from original behaviour. We create the grid on the device, whereas original implementation creates it on CPU and then moves it to device. This results in numerical differences in layerwise debugging outputs, but visually it is the same. * use diffusers timesteps embedding; diff: 0.10205078125 * rename * convert * update * add tests for transformer * add pipeline tests; text encoder 2 is not optional * fix attention implementation for torch * add example * update docs * update docs * apply suggestions from review * refactor vae * update * Apply suggestions from code review Co-authored-by: hlky * Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky * Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py Co-authored-by: hlky * make fix-copies * update --------- Co-authored-by: "Gregory D. Hunkins" Co-authored-by: hlky --- docs/source/en/_toctree.yml | 6 + .../models/autoencoder_kl_hunyuan_video.md | 32 + .../models/hunyuan_video_transformer_3d.md | 30 + docs/source/en/api/pipelines/hunyuan_video.md | 43 + scripts/convert_hunyuan_video_to_diffusers.py | 257 ++++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/activations.py | 12 + src/diffusers/models/attention.py | 4 +- src/diffusers/models/attention_processor.py | 16 +- src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_hunyuan_video.py | 1175 +++++++++++++++++ src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_hunyuan_video.py | 723 ++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/hunyuan_video/__init__.py | 48 + .../hunyuan_video/pipeline_hunyuan_video.py | 675 ++++++++++ .../hunyuan_video/pipeline_output.py | 20 + src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_autoencoder_hunyuan_video.py | 159 +++ .../test_models_transformer_hunyuan_video.py | 89 ++ tests/pipelines/hunyuan_video/__init__.py | 0 .../hunyuan_video/test_hunyuan_video.py | 331 +++++ 24 files changed, 3676 insertions(+), 3 deletions(-) create mode 100644 docs/source/en/api/models/autoencoder_kl_hunyuan_video.md create mode 100644 docs/source/en/api/models/hunyuan_video_transformer_3d.md create mode 100644 docs/source/en/api/pipelines/hunyuan_video.md create mode 100644 scripts/convert_hunyuan_video_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video.py create mode 100644 src/diffusers/pipelines/hunyuan_video/__init__.py create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py create mode 100644 src/diffusers/pipelines/hunyuan_video/pipeline_output.py create mode 100644 tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py create mode 100644 tests/models/transformers/test_models_transformer_hunyuan_video.py create mode 100644 tests/pipelines/hunyuan_video/__init__.py create mode 100644 tests/pipelines/hunyuan_video/test_hunyuan_video.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f4eb32cf63a8..d1404a1d6ea6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -270,6 +270,8 @@ title: FluxTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel + - local: api/models/hunyuan_video_transformer_3d + title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d @@ -316,6 +318,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoder_kl_hunyuan_video + title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_mochi @@ -394,6 +398,8 @@ title: Flux - local: api/pipelines/hunyuandit title: Hunyuan-DiT + - local: api/pipelines/hunyuan_video + title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/pix2pix diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md new file mode 100644 index 000000000000..f69c14814d3d --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -0,0 +1,32 @@ + + +# AutoencoderKLHunyuanVideo + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLHunyuanVideo + +vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16) +``` + +## AutoencoderKLHunyuanVideo + +[[autodoc]] AutoencoderKLHunyuanVideo + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md new file mode 100644 index 000000000000..73aea9832fc0 --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -0,0 +1,30 @@ + + +# HunyuanVideoTransformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import HunyuanVideoTransformer3DModel + +transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16) +``` + +## HunyuanVideoTransformer3DModel + +[[autodoc]] HunyuanVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md new file mode 100644 index 000000000000..86ef816fcd4d --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -0,0 +1,43 @@ + + +# HunyuanVideo + +[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent. + +*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +Recommendations for inference: +- Both text encoders should be in `torch.float16`. +- Transformer should be in `torch.bfloat16`. +- VAE should be in `torch.float16`. +- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. +- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. +- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/). + +## HunyuanVideoPipeline + +[[autodoc]] HunyuanVideoPipeline + - all + - __call__ + +## HunyuanVideoPipelineOutput + +[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py new file mode 100644 index 000000000000..464c9e0fb954 --- /dev/null +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -0,0 +1,257 @@ +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) + + +def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + +def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + +def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + +def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + +def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, +} + +VAE_KEYS_RENAME_DICT = {} + +VAE_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def convert_transformer(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + transformer = HunyuanVideoTransformer3DModel() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def convert_vae(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") + parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") + parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None + assert args.text_encoder_2_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + pipe = HunyuanVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 20914442b84a..dfa7a4df2d08 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLHunyuanVideo", "AutoencoderKLLTXVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", @@ -102,6 +103,7 @@ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", + "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", @@ -287,6 +289,7 @@ "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", + "HunyuanVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -590,6 +593,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -608,6 +612,7 @@ HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, + HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, @@ -772,6 +777,7 @@ HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, + HunyuanVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 687c555e0ce2..01e67b01d91a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,6 +31,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] @@ -67,6 +68,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -97,6 +99,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -130,6 +133,7 @@ DualTransformer2DModel, FluxTransformer2DModel, HunyuanDiT2DModel, + HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index c1d4f0b46e15..c61baefa08f4 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -164,3 +164,15 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..6749c7f17254 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -1222,6 +1222,8 @@ def __init__( act_fn = ApproximateGELU(dim, inner_dim, bias=bias) elif activation_fn == "swiglu": act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 77e35364ab09..ee6b010519e2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -254,14 +254,22 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -782,7 +790,11 @@ def fuse_projections(self, fuse=True): self.to_kv.bias.copy_(concatenated_bias) # handle added projections for SD3 and others. - if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] ) @@ -3938,7 +3950,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # dropout hidden_states = attn.to_out[1](hidden_states) - if hasattr(attn, "to_add_out"): + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index d08e67c40975..bb750a4410f2 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -3,6 +3,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py new file mode 100644 index 000000000000..bded90a8bcff --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -0,0 +1,1175 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None +) -> torch.Tensor: + seq_len = num_frames * height_width + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // height_width + mask[i, : (i_frame + 1) * height_width] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: Tuple[float, float, float] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, **ckpt_kwargs + ) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class HunyuanVideoDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block + ) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + + up_block = HunyuanVideoUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, **ckpt_kwargs + ) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + scaling_factor: float = 0.476986, + spatial_compression_ratio: int = 8, + temporal_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ) -> None: + super().__init__() + + self.time_compression_ratio = temporal_compression_ratio + + self.encoder = HunyuanVideoEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + ) + + self.decoder = HunyuanVideoDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = True + self.use_framewise_decoding = True + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + self.tile_sample_min_num_frames = 64 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + self.tile_sample_stride_num_frames = 48 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6a13e80772e3..3a33c8070c08 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py new file mode 100644 index 000000000000..d8f9834ea61c --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -0,0 +1,723 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + # 5. Attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + # 3. Attention mask preparation + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N, N] + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6f1b842f92f2..e7fd7ec78bed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -214,6 +214,7 @@ "IFSuperResolutionPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -549,6 +550,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hunyuan_video import HunyuanVideoPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..978ed7f96110 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video import HunyuanVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 000000000000..bd3d3c1e8485 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,675 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "tencent/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer_2 (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = ( + self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None and pooled_prompt_embeds is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_latent_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 000000000000..c5cb853e3932 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0f2aad5c5000..4b6ac10385cf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8aefce9d624e..e148c025d191 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -572,6 +572,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py new file mode 100644 index 000000000000..826ac30d5f2f --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLHunyuanVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLHunyuanVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_hunyuan_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_hunyuan_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.5, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "HunyuanVideoDecoder3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoEncoder3D", + "HunyuanVideoMidBlock3D", + "HunyuanVideoUpBlock3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py new file mode 100644 index 000000000000..e8ea8cecbb9e --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -0,0 +1,89 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/hunyuan_video/__init__.py b/tests/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py new file mode 100644 index 000000000000..567002268106 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -0,0 +1,331 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideoPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + # there is no xformers processor for Flux + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=1, + num_single_layers=1, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4, + down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot test with dummy prompt because tokenizers are not configured correctly. + # TODO(aryan): create dummy tokenizers and using from hub + inputs = { + "prompt": "", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 4 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass