diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f4eb32cf63a8..d1404a1d6ea6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -270,6 +270,8 @@ title: FluxTransformer2DModel - local: api/models/hunyuan_transformer2d title: HunyuanDiT2DModel + - local: api/models/hunyuan_video_transformer_3d + title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d @@ -316,6 +318,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoder_kl_hunyuan_video + title: AutoencoderKLHunyuanVideo - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_mochi @@ -394,6 +398,8 @@ title: Flux - local: api/pipelines/hunyuandit title: Hunyuan-DiT + - local: api/pipelines/hunyuan_video + title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/pix2pix diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md new file mode 100644 index 000000000000..f69c14814d3d --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video.md @@ -0,0 +1,32 @@ + + +# AutoencoderKLHunyuanVideo + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. 
+ +```python +import torch +from diffusers import AutoencoderKLHunyuanVideo + +vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16) +``` + +## AutoencoderKLHunyuanVideo + +[[autodoc]] AutoencoderKLHunyuanVideo + - decode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md new file mode 100644 index 000000000000..73aea9832fc0 --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md @@ -0,0 +1,30 @@ + + +# HunyuanVideoTransformer3DModel + +A Diffusion Transformer model for 3D video-like data that was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent. + +The model can be loaded with the following code snippet. + +```python +import torch +from diffusers import HunyuanVideoTransformer3DModel + +transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16) +``` + +## HunyuanVideoTransformer3DModel + +[[autodoc]] HunyuanVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md new file mode 100644 index 000000000000..86ef816fcd4d --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video.md @@ -0,0 +1,43 @@ + + +# HunyuanVideo + +[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent. + +*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. 
In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +Recommendations for inference: +- Both text encoders should be in `torch.float16`. +- Transformer should be in `torch.bfloat16`. +- VAE should be in `torch.float16`. 
+- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`. +- For smaller resolution images, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo. +- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/). + +## HunyuanVideoPipeline + +[[autodoc]] HunyuanVideoPipeline + - all + - __call__ + +## HunyuanVideoPipelineOutput + +[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput diff --git a/scripts/convert_hunyuan_video_to_diffusers.py b/scripts/convert_hunyuan_video_to_diffusers.py new file mode 100644 index 000000000000..464c9e0fb954 --- /dev/null +++ b/scripts/convert_hunyuan_video_to_diffusers.py @@ -0,0 +1,257 @@ +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) + + +def remap_norm_scale_shift_(key, state_dict): + weight = state_dict.pop(key) + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight + + +def remap_txt_in_(key, state_dict): + def rename_key(key): + new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks") + new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear") + new_key = new_key.replace("txt_in", "context_embedder") + new_key = new_key.replace("t_embedder.mlp.0", 
"time_text_embed.timestep_embedder.linear_1") + new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2") + new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder") + new_key = new_key.replace("mlp", "ff") + return new_key + + if "self_attn_qkv" in key: + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k + state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v + else: + state_dict[rename_key(key)] = state_dict.pop(key) + + +def remap_img_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q + state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k + state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v + + +def remap_txt_attn_qkv_(key, state_dict): + weight = state_dict.pop(key) + to_q, to_k, to_v = weight.chunk(3, dim=0) + state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q + state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k + state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v + + +def remap_single_transformer_blocks_(key, state_dict): + hidden_size = 3072 + + if "linear1.weight" in key: + linear1_weight = state_dict.pop(key) + split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size) + q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight") + state_dict[f"{new_key}.attn.to_q.weight"] = q + state_dict[f"{new_key}.attn.to_k.weight"] = k + state_dict[f"{new_key}.attn.to_v.weight"] = v + state_dict[f"{new_key}.proj_mlp.weight"] = mlp + + elif "linear1.bias" in key: + linear1_bias = state_dict.pop(key) + split_size = 
(hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size) + q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0) + new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias") + state_dict[f"{new_key}.attn.to_q.bias"] = q_bias + state_dict[f"{new_key}.attn.to_k.bias"] = k_bias + state_dict[f"{new_key}.attn.to_v.bias"] = v_bias + state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias + + else: + new_key = key.replace("single_blocks", "single_transformer_blocks") + new_key = new_key.replace("linear2", "proj_out") + new_key = new_key.replace("q_norm", "attn.norm_q") + new_key = new_key.replace("k_norm", "attn.norm_k") + state_dict[new_key] = state_dict.pop(key) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1", + "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2", + "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1", + "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2", + "vector_in.in_layer": "time_text_embed.text_embedder.linear_1", + "vector_in.out_layer": "time_text_embed.text_embedder.linear_2", + "double_blocks": "transformer_blocks", + "img_attn_q_norm": "attn.norm_q", + "img_attn_k_norm": "attn.norm_k", + "img_attn_proj": "attn.to_out.0", + "txt_attn_q_norm": "attn.norm_added_q", + "txt_attn_k_norm": "attn.norm_added_k", + "txt_attn_proj": "attn.to_add_out", + "img_mod.linear": "norm1.linear", + "img_norm1": "norm1.norm", + "img_norm2": "norm2", + "img_mlp": "ff", + "txt_mod.linear": "norm1_context.linear", + "txt_norm1": "norm1.norm", + "txt_norm2": "norm2_context", + "txt_mlp": "ff_context", + "self_attn_proj": "attn.to_out.0", + "modulation.linear": "norm.linear", + "pre_norm": "norm.norm", + "final_layer.norm_final": "norm_out.norm", + "final_layer.linear": "proj_out", + "fc1": "net.0.proj", + "fc2": "net.2", + "input_embedder": "proj_in", +} + 
+TRANSFORMER_SPECIAL_KEYS_REMAP = { + "txt_in": remap_txt_in_, + "img_attn_qkv": remap_img_attn_qkv_, + "txt_attn_qkv": remap_txt_attn_qkv_, + "single_blocks": remap_single_transformer_blocks_, + "final_layer.adaLN_modulation.1": remap_norm_scale_shift_, +} + +VAE_KEYS_RENAME_DICT = {} + +VAE_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in state_dict.keys(): + state_dict = state_dict["model"] + if "module" in state_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def convert_transformer(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + transformer = HunyuanVideoTransformer3DModel() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def convert_vae(ckpt_path: str): + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True)) + + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo() + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + 
update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint") + parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint") + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer") + parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint") + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None + assert args.text_encoder_2_path is not None + + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, 
max_shard_size="5GB") + + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + if args.save_pipeline: + text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right") + text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16) + tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + pipe = HunyuanVideoPipeline( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 20914442b84a..dfa7a4df2d08 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLHunyuanVideo", "AutoencoderKLLTXVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", @@ -102,6 +103,7 @@ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", + "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", @@ -287,6 +289,7 @@ "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", "HunyuanDiTPipeline", + "HunyuanVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -590,6 +593,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -608,6 +612,7 @@ HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, 
HunyuanDiT2DMultiControlNetModel, + HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, @@ -772,6 +777,7 @@ HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, HunyuanDiTPipeline, + HunyuanVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 687c555e0ce2..01e67b01d91a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,6 +31,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] @@ -67,6 +68,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -97,6 +99,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLHunyuanVideo, AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, @@ -130,6 +133,7 @@ DualTransformer2DModel, FluxTransformer2DModel, 
HunyuanDiT2DModel, + HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index c1d4f0b46e15..c61baefa08f4 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -164,3 +164,15 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x * torch.sigmoid(1.702 * x) + + +class LinearActivation(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"): + super().__init__() + + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.activation = get_activation(activation) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + return self.activation(hidden_states) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02ed1f965abf..6749c7f17254 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX @@ -1222,6 +1222,8 @@ def __init__( act_fn = ApproximateGELU(dim, inner_dim, bias=bias) elif activation_fn == "swiglu": act_fn = SwiGLU(dim, inner_dim, bias=bias) + elif activation_fn == "linear-silu": + act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu") self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/attention_processor.py 
b/src/diffusers/models/attention_processor.py index 77e35364ab09..ee6b010519e2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -254,14 +254,22 @@ def __init__( self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "fp32_layer_norm": @@ -782,7 +790,11 @@ def fuse_projections(self, fuse=True): self.to_kv.bias.copy_(concatenated_bias) # handle added projections for SD3 and others. 
- if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"): + if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] ) @@ -3938,7 +3950,7 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): # dropout hidden_states = attn.to_out[1](hidden_states) - if hasattr(attn, "to_add_out"): + if getattr(attn, "to_add_out", None) is not None: encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index d08e67c40975..bb750a4410f2 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -3,6 +3,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py new file mode 100644 index 000000000000..bded90a8bcff --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -0,0 +1,1175 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_causal_attention_mask( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: Optional[int] = None +) -> torch.Tensor: + seq_len = num_frames * height_width + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // height_width + mask[i, : (i_frame + 1) * height_width] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = 
(kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: Tuple[float, float, float] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest" + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. 
If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest") + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = 
self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def 
custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.resnets[0]), hidden_states, **ckpt_kwargs + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + 
add_downsample: bool = True, + downsample_stride: int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in 
range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for resnet in self.resnets: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, **ckpt_kwargs + ) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: Tuple[int, ...] 
= (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}") + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, + downsample_padding=0, + ) + + 
self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), hidden_states, **ckpt_kwargs + ) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class HunyuanVideoDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). 
+ """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block + ) + else: + raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}") + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = 
tuple(upsample_scale_factor_T + upsample_scale_factor_HW) + + up_block = HunyuanVideoUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, **ckpt_kwargs + ) + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), hidden_states, **ckpt_kwargs + ) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents 
and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 16, + down_block_types: Tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + scaling_factor: float = 0.476986, + spatial_compression_ratio: int = 8, + temporal_compression_ratio: int = 4, + mid_block_add_attention: bool = True, + ) -> None: + super().__init__() + + self.time_compression_ratio = temporal_compression_ratio + + self.encoder = HunyuanVideoEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + ) + + self.decoder = HunyuanVideoDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + time_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, +
mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1) + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # in fixed-size temporal tiles (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
+ self.use_framewise_encoding = True + self.use_framewise_decoding = True + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + self.tile_sample_min_num_frames = 64 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + self.tile_sample_stride_num_frames = 48 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HunyuanVideoEncoder3D, HunyuanVideoDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[int] = None, + tile_sample_stride_width: Optional[int] = None, + tile_sample_stride_num_frames: Optional[int] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The stride between two consecutive vertical tiles. This is to ensure that there are no tiling + artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles.
This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): Input batch of videos. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned.
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + tile = self.quant_conv(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], 
tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, return_dict=True).sample + else: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: 
Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 6a13e80772e3..3a33c8070c08 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py new file mode 100644 index 000000000000..d8f9834ea61c --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -0,0 +1,723 @@ +# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version +from ..attention import FeedForward +from ..attention_processor import Attention, AttentionProcessor +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +class HunyuanVideoAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. 
QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + query = torch.cat( + [ + apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + query[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + key = torch.cat( + [ + apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb), + key[:, :, -encoder_hidden_states.shape[1] :], + ], + dim=2, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + # 4. Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=2) + key = torch.cat([key, encoder_key], dim=2) + value = torch.cat([value, encoder_value], dim=2) + + # 5. 
Attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class 
HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, 
+ hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = 
self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. 
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [T, H, W] + grid = torch.stack(grid, dim=0) # [3, T, H, W] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (T * H * W, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (T * H * W, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. 
Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + 
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. 
Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. 
This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + pooled_projections: torch.Tensor, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config.patch_size, self.config.patch_size_t + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_text_embed(timestep, guidance, pooled_projections) + hidden_states = self.x_embedder(hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + # 3. 
Attention mask preparation + latent_sequence_length = hidden_states.shape[1] + condition_sequence_length = encoder_hidden_states.shape[1] + sequence_length = latent_sequence_length + condition_sequence_length + attention_mask = torch.zeros( + batch_size, 1, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, 1, N, N], broadcasts over attention heads for any batch size + + effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] + effective_sequence_length = latent_sequence_length + effective_condition_sequence_length + + for i in range(batch_size): + attention_mask[i, :, : effective_sequence_length[i], : effective_sequence_length[i]] = True + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + for block in self.single_transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb + ) + + # 5. 
Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 6f1b842f92f2..e7fd7ec78bed 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -214,6 +214,7 @@ "IFSuperResolutionPipeline", ] _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"] + _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", @@ -549,6 +550,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) + from .hunyuan_video import HunyuanVideoPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..978ed7f96110 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + 
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video import HunyuanVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 000000000000..bd3d3c1e8485 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,675 @@ +# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "tencent/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. 
background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. 
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = ( + self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, 
prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, 
prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt_2, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int 
= 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` + is used instead. + height (`int`, defaults to `720`): + The height in pixels of the generated video. + width (`int`, defaults to `1280`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. A higher guidance scale encourages the model to generate videos that are closely linked to the text + `prompt`, usually at the expense of lower video quality. Note that the only available HunyuanVideo model is + CFG-distilled, which means that traditional guidance between unconditional and conditional latent is + not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during inference, with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument.
You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + whose only element is the generated video. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + if pooled_prompt_embeds is not None: + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + # 4.
Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_latent_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if 
provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 000000000000..c5cb853e3932 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0f2aad5c5000..4b6ac10385cf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class I2VGenXLUNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8aefce9d624e..e148c025d191 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -572,6 +572,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, 
["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class I2VGenXLPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py new file mode 100644 index 000000000000..826ac30d5f2f --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import AutoencoderKLHunyuanVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLHunyuanVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_hunyuan_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + "up_block_types": ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + "block_out_channels": (8, 8, 8, 8), + "layers_per_block": 1, + "act_fn": "silu", + "norm_num_groups": 4, + "scaling_factor": 0.476986, + "spatial_compression_ratio": 8, + "temporal_compression_ratio": 4, + "mid_block_add_attention": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_hunyuan_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_enable_disable_tiling(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_tiling() + output_with_tiling = 
model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(), + 0.5, + "VAE tiling should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_tiling() + output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_tiling.detach().cpu().numpy().all(), + output_without_tiling_2.detach().cpu().numpy().all(), + "Without tiling outputs should match with the outputs when tiling is manually disabled.", + ) + + def test_enable_disable_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + torch.manual_seed(0) + model = self.model_class(**init_dict).to(torch_device) + + inputs_dict.update({"return_dict": False}) + + torch.manual_seed(0) + output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + torch.manual_seed(0) + model.enable_slicing() + output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertLess( + (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(), + 0.5, + "VAE slicing should not affect the inference results", + ) + + torch.manual_seed(0) + model.disable_slicing() + output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] + + self.assertEqual( + output_without_slicing.detach().cpu().numpy().all(), + output_without_slicing_2.detach().cpu().numpy().all(), + "Without slicing outputs should match with the outputs when slicing is manually disabled.", + ) + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "HunyuanVideoDecoder3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoEncoder3D", + "HunyuanVideoMidBlock3D", + "HunyuanVideoUpBlock3D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def 
test_outputs_equivalence(self): + pass diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py new file mode 100644 index 000000000000..e8ea8cecbb9e --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -0,0 +1,89 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + pooled_projection_dim = 8 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device) + 
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device) + guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "encoder_attention_mask": encoder_attention_mask, + "guidance": guidance, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 10, + "num_layers": 1, + "num_single_layers": 1, + "num_refiner_layers": 1, + "patch_size": 1, + "patch_size_t": 1, + "guidance_embeds": True, + "text_embed_dim": 16, + "pooled_projection_dim": 8, + "rope_axes_dim": (2, 4, 4), + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/hunyuan_video/__init__.py b/tests/pipelines/hunyuan_video/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py new file mode 100644 index 000000000000..567002268106 --- /dev/null +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -0,0 +1,331 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideoPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) + batch_params = frozenset(["prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + # there is no xformers processor for HunyuanVideo + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = HunyuanVideoTransformer3DModel( + in_channels=4, + out_channels=4, + num_attention_heads=2, + attention_head_dim=10, + num_layers=1, + num_single_layers=1, + num_refiner_layers=1, + patch_size=1, + patch_size_t=1, + guidance_embeds=True, + text_embed_dim=16, + pooled_projection_dim=8, + rope_axes_dim=(2, 4, 4), + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=3, + out_channels=3, + latent_channels=4,
down_block_types=( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + up_block_types=( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + layers_per_block=1, + act_fn="silu", + norm_num_groups=4, + scaling_factor=0.476986, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + mid_block_add_attention=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + llama_text_encoder_config = LlamaConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=16, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=8, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=2, + pad_token_id=1, + vocab_size=1000, + hidden_act="gelu", + projection_dim=32, + ) + + torch.manual_seed(0) + text_encoder = LlamaModel(llama_text_encoder_config) + tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + + torch.manual_seed(0) + text_encoder_2 = CLIPTextModel(clip_text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot test with dummy prompt because tokenizers are not configured correctly. 
+ # TODO(aryan): create dummy tokenizers and use them from the hub + inputs = { + "prompt": "", + "prompt_template": { + "template": "{}", + "crop_start": 0, + }, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.5, + "height": 16, + "width": 16, + # 4 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + 
assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + 
output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + # Seems to require higher tolerance than the other tests + expected_diff_max = 0.6 + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + # TODO(aryan): Create a dummy gemma model with smol vocab size + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_consistent(self): + pass + + @unittest.skip( + "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error." + ) + def test_inference_batch_single_identical(self): + pass
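The `4 * k + 1` recommendation for `num_frames` in `get_dummy_inputs` appears to follow from the VAE's `temporal_compression_ratio=4`. The sketch below works through that arithmetic. It assumes the first frame maps to its own latent frame and each further block of four frames maps to one more. The helper names and the formula are illustrative assumptions and are not part of this diff.

```python
def latent_frame_count(num_frames: int, temporal_compression_ratio: int = 4) -> int:
    """Hypothetical helper: number of latent frames for a given video length.

    Assumes the first frame gets its own latent frame and every further
    block of `temporal_compression_ratio` frames adds one more.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    return (num_frames - 1) // temporal_compression_ratio + 1


def is_recommended_frame_count(num_frames: int, temporal_compression_ratio: int = 4) -> bool:
    """True when num_frames has the form ratio * k + 1 for some k >= 0."""
    return num_frames >= 1 and (num_frames - 1) % temporal_compression_ratio == 0


# The test above uses num_frames=9; 10 is shown as a non-conforming value.
for n in (9, 10, 13):
    print(n, latent_frame_count(n), is_recommended_frame_count(n))
```

Under this assumption, the test's `num_frames=9` fills its temporal blocks exactly, whereas a value such as 10 would leave a partial block after the first frame.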