From 3b2830618ddff967a1f3a1307a15e24a75c7ae6e Mon Sep 17 00:00:00 2001 From: "Yuxuan.Zhang" <2448370773@qq.com> Date: Tue, 19 Nov 2024 03:26:34 +0800 Subject: [PATCH] CogVideoX 1.5 (#9877) * CogVideoX1_1PatchEmbed test * 1360 * 768 * refactor * make style * update docs * add modeling tests for cogvideox 1.5 * update * make fix-copies * add ofs embed(for convert) * add ofs embed(for convert) * more resolution for cogvideox1.5-5b-i2v * use even number of latent frames only * update pipeline implementations * make style * set patch_size_t as None by default * #skip frames 0 * refactor * make style * update docs * fix ofs_embed * update docs * invert_scale_latents * update * fix * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/cogvideox.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update src/diffusers/models/transformers/cogvideox_transformer_3d.py * update conversion script * remove copied from * fix test * Update docs/source/en/api/pipelines/cogvideox.md * Update docs/source/en/api/pipelines/cogvideox.md * Update docs/source/en/api/pipelines/cogvideox.md * Update docs/source/en/api/pipelines/cogvideox.md --------- Co-authored-by: Aryan Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/cogvideox.md | 33 ++++--- scripts/convert_cogvideox_to_diffusers.py | 76 +++++++++++++--- .../autoencoders/autoencoder_kl_cogvideox.py | 1 + src/diffusers/models/embeddings.py | 76 ++++++++++++---- .../transformers/cogvideox_transformer_3d.py | 55 +++++++++--- .../pipelines/cogvideo/pipeline_cogvideox.py | 39 +++++--- .../pipeline_cogvideox_fun_control.py | 38 +++++--- .../pipeline_cogvideox_image2video.py | 89 ++++++++++++++----- .../pipeline_cogvideox_video2video.py | 29 ++++-- .../test_models_transformer_cogvideox.py | 61 +++++++++++++ 10 files changed, 405 insertions(+), 92 deletions(-) diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index f0f4fd37e6d5..40320896881c 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -29,16 +29,29 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM). -There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines: -- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`. -- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`. - -There is one model available that can be used with the image-to-video CogVideoX pipeline: -- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`. 
- -There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): -- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`. -- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`. +There are three official CogVideoX checkpoints for text-to-video and video-to-video. +| checkpoints | recommended inference dtype | +|---|---| +| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 | +| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 | +| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 | + +There are two official CogVideoX checkpoints available for image-to-video. +| checkpoints | recommended inference dtype | +|---|---| +| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 | +| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 | + +For the CogVideoX 1.5 series: +- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution. +- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16. +- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended. + +There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team). +| checkpoints | recommended inference dtype | +|---|---| +| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 | +| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 | ## Inference diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py index 4343eaf34038..7eeed240c4de 100644 --- a/scripts/convert_cogvideox_to_diffusers.py +++ b/scripts/convert_cogvideox_to_diffusers.py @@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]): "post_attn1_layernorm": "norm2.norm", "time_embed.0": "time_embedding.linear_1", "time_embed.2": "time_embedding.linear_2", + "ofs_embed.0": "ofs_embedding.linear_1", + "ofs_embed.2": "ofs_embedding.linear_2", "mixins.patch_embed": "patch_embed", "mixins.final_layer.norm_final": "norm_out.norm", "mixins.final_layer.linear": "proj_out", @@ -140,6 +142,7 @@ def convert_transformer( use_rotary_positional_embeddings: bool, i2v: bool, dtype: torch.dtype, + init_kwargs: Dict[str, Any], ): PREFIX_KEY = "model.diffusion_model." 
@@ -149,7 +152,9 @@ def convert_transformer( num_layers=num_layers, num_attention_heads=num_attention_heads, use_rotary_positional_embeddings=use_rotary_positional_embeddings, - use_learned_positional_embeddings=i2v, + ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V + use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V + **init_kwargs, ).to(dtype=dtype) for key in list(original_state_dict.keys()): @@ -163,13 +168,18 @@ def convert_transformer( if special_key not in key: continue handler_fn_inplace(key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True) return transformer -def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): +def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype): + init_kwargs = {"scaling_factor": scaling_factor} + if version == "1.5": + init_kwargs.update({"invert_scale_latents": True}) + original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True)) - vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype) + vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype) for key in list(original_state_dict.keys()): new_key = key[:] @@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype): return vae +def get_transformer_init_kwargs(version: str): + if version == "1.0": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": None, + "patch_bias": True, + "sample_height": 480 // vae_scale_factor_spatial, + "sample_width": 720 // vae_scale_factor_spatial, + "sample_frames": 49, + } + + elif version == "1.5": + vae_scale_factor_spatial = 8 + init_kwargs = { + "patch_size": 2, + "patch_size_t": 2, + "patch_bias": False, + "sample_height": 300, + "sample_width": 300, + "sample_frames": 81, + } + else: + raise ValueError("Unsupported version of CogVideoX.") + + return init_kwargs + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -202,6 +240,12 @@ def get_args(): parser.add_argument( "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" ) + parser.add_argument( + "--typecast_text_encoder", + action="store_true", + default=False, + help="Whether or not to apply fp16/bf16 precision to text_encoder", + ) # For CogVideoX-2B, num_layers is 30. For 5B, it is 42 parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks") # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48 @@ -214,7 +258,18 @@ def get_args(): parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE") # For CogVideoX-2B, snr_shift_scale is 3.0. 
For 5B, it is 1.0 parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE") - parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16") + parser.add_argument( + "--i2v", + action="store_true", + default=False, + help="Whether the model to be converted is the Image-to-Video version of CogVideoX.", + ) + parser.add_argument( + "--version", + choices=["1.0", "1.5"], + default="1.0", + help="Which version of CogVideoX to use for initializing default modeling parameters.", + ) return parser.parse_args() @@ -230,6 +285,7 @@ def get_args(): dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32 if args.transformer_ckpt_path is not None: + init_kwargs = get_transformer_init_kwargs(args.version) transformer = convert_transformer( args.transformer_ckpt_path, args.num_layers, @@ -237,14 +293,19 @@ def get_args(): args.use_rotary_positional_embeddings, args.i2v, dtype, + init_kwargs, ) if args.vae_ckpt_path is not None: - vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype) + # Keep VAE in float32 for better quality + vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32) text_encoder_id = "google/t5-v1_1-xxl" tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + if args.typecast_text_encoder: + text_encoder = text_encoder.to(dtype=dtype) + # Apparently, the conversion does not work anymore without this :shrug: for param in text_encoder.parameters(): param.data = param.data.contiguous() @@ -276,11 +337,6 @@ def get_args(): scheduler=scheduler, ) - if args.fp16: - pipe = pipe.to(dtype=torch.float16) - if args.bf16: - pipe = pipe.to(dtype=torch.bfloat16) - # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird # for users to specify variant when the default is not fp32 and they want to run with the correct default (which # is either fp16/bf16 here). 
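The conversion-script changes above introduce version-dependent init kwargs (`patch_size_t`, `patch_bias`, and the new `sample_*` defaults) plus an `ofs_embed_dim` for the 1.5 I2V checkpoint. The following sketch is not part of the patch; it only illustrates how the `"1.5"` kwargs from `get_transformer_init_kwargs` map onto the transformer constructor. The 42-layer / 48-head values are taken from the argparse comments in the script and apply to the 5B model, so running this allocates a large, randomly initialized model.

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# The "1.5" kwargs returned by get_transformer_init_kwargs above.
init_kwargs = {
    "patch_size": 2,
    "patch_size_t": 2,      # temporal patching is new in CogVideoX 1.5
    "patch_bias": False,
    "sample_height": 300,
    "sample_width": 300,
    "sample_frames": 81,
}

transformer = CogVideoXTransformer3DModel(
    num_layers=42,                          # 5B model, per the argparse comments
    num_attention_heads=48,                 # 5B model, per the argparse comments
    use_rotary_positional_embeddings=True,
    # The 1.5 I2V checkpoint would additionally pass ofs_embed_dim=512, as
    # convert_transformer does above; the T2V variant leaves it as None.
    use_learned_positional_embeddings=False,
    **init_kwargs,
).to(dtype=torch.bfloat16)
```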
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index d9ee15062daf..fbcb964392f9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1057,6 +1057,7 @@ def __init__( force_upcast: float = True, use_quant_conv: bool = False, use_post_quant_conv: bool = False, + invert_scale_latents: bool = False, ): super().__init__() diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7cbd958e1d6e..80775d477c0d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module): def __init__( self, patch_size: int = 2, + patch_size_t: Optional[int] = None, in_channels: int = 16, embed_dim: int = 1920, text_embed_dim: int = 4096, @@ -355,6 +356,7 @@ def __init__( super().__init__() self.patch_size = patch_size + self.patch_size_t = patch_size_t self.embed_dim = embed_dim self.sample_height = sample_height self.sample_width = sample_width @@ -366,9 +368,15 @@ def __init__( self.use_positional_embeddings = use_positional_embeddings self.use_learned_positional_embeddings = use_learned_positional_embeddings - self.proj = nn.Conv2d( - in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias - ) + if patch_size_t is None: + # CogVideoX 1.0 checkpoints + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + else: + # CogVideoX 1.5 checkpoints + self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim) + self.text_proj = nn.Linear(text_embed_dim, embed_dim) if use_positional_embeddings or use_learned_positional_embeddings: @@ -407,12 +415,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): """ text_embeds = self.text_proj(text_embeds) - batch, num_frames, channels, height, width = image_embeds.shape - image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) - image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) - image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] - image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + batch_size, num_frames, channels, height, width = image_embeds.shape + + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + else: + p = self.patch_size + p_t = self.patch_size_t + + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) embeds = torch.cat( [text_embeds, image_embeds], dim=1 @@ -497,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens def get_3d_rotary_pos_embed( - embed_dim, crops_coords, grid_size, temporal_size, theta: int 
= 10000, use_real: bool = True + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ RoPE for video tokens with 3D structure. @@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed( The size of the temporal dimension. theta (`float`): Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. Returns: `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. """ if use_real is not True: raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") - start, stop = crops_coords - grid_size_h, grid_size_w = grid_size - grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) - grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) - grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") # Compute dimensions for each axis dim_t = embed_dim // 4 @@ -559,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + cos = combine_time_height_width(t_cos, h_cos, w_cos) sin = combine_time_height_width(t_sin, h_sin, w_sin) return cos, sin diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 01c54ef090bd..b47d439774cc 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Whether to flip the sin to cos in the time embedding. time_embed_dim (`int`, defaults to `512`): Output dimension of timestep embeddings. + ofs_embed_dim (`int`, defaults to `512`): + Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5 text_embed_dim (`int`, defaults to `4096`): Input dimension of text embeddings from the text encoder. num_layers (`int`, defaults to `30`): @@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): dropout (`float`, defaults to `0.0`): The dropout probability to use. 
attention_bias (`bool`, defaults to `True`): - Whether or not to use bias in the attention projection layers. + Whether to use bias in the attention projection layers. sample_width (`int`, defaults to `90`): The width of the input latents. sample_height (`int`, defaults to `60`): @@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): timestep_activation_fn (`str`, defaults to `"silu"`): Activation function to use when generating the timestep embeddings. norm_elementwise_affine (`bool`, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. + Whether to use elementwise affine in normalization layers. norm_eps (`float`, defaults to `1e-5`): The epsilon value to use in normalization layers. spatial_interpolation_scale (`float`, defaults to `1.875`): @@ -219,6 +221,7 @@ def __init__( flip_sin_to_cos: bool = True, freq_shift: int = 0, time_embed_dim: int = 512, + ofs_embed_dim: Optional[int] = None, text_embed_dim: int = 4096, num_layers: int = 30, dropout: float = 0.0, @@ -227,6 +230,7 @@ def __init__( sample_height: int = 60, sample_frames: int = 49, patch_size: int = 2, + patch_size_t: Optional[int] = None, temporal_compression_ratio: int = 4, max_text_seq_length: int = 226, activation_fn: str = "gelu-approximate", @@ -237,6 +241,7 @@ def __init__( temporal_interpolation_scale: float = 1.0, use_rotary_positional_embeddings: bool = False, use_learned_positional_embeddings: bool = False, + patch_bias: bool = True, ): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -251,10 +256,11 @@ def __init__( # 1. Patch embedding self.patch_embed = CogVideoXPatchEmbed( patch_size=patch_size, + patch_size_t=patch_size_t, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim, - bias=True, + bias=patch_bias, sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames, @@ -267,10 +273,19 @@ def __init__( ) self.embedding_dropout = nn.Dropout(dropout) - # 2. Time embeddings + # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have) + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn) + self.ofs_proj = None + self.ofs_embedding = None + if ofs_embed_dim: + self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift) + self.ofs_embedding = TimestepEmbedding( + ofs_embed_dim, ofs_embed_dim, timestep_activation_fn + ) # same as time embeddings, for ofs + # 3. 
Define spatio-temporal transformers blocks self.transformer_blocks = nn.ModuleList( [ @@ -298,7 +313,15 @@ def __init__( norm_eps=norm_eps, chunk_dim=1, ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) + + if patch_size_t is None: + # For CogVideox 1.0 + output_dim = patch_size * patch_size * out_channels + else: + # For CogVideoX 1.5 + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = nn.Linear(inner_dim, output_dim) self.gradient_checkpointing = False @@ -411,6 +434,7 @@ def forward( encoder_hidden_states: torch.Tensor, timestep: Union[int, float, torch.LongTensor], timestep_cond: Optional[torch.Tensor] = None, + ofs: Optional[Union[int, float, torch.LongTensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, @@ -442,6 +466,12 @@ def forward( t_emb = t_emb.to(dtype=hidden_states.dtype) emb = self.time_embedding(t_emb, timestep_cond) + if self.ofs_embedding is not None: + ofs_emb = self.ofs_proj(ofs) + ofs_emb = ofs_emb.to(dtype=hidden_states.dtype) + ofs_emb = self.ofs_embedding(ofs_emb) + emb = emb + ofs_emb + # 2. Patch embedding hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) hidden_states = self.embedding_dropout(hidden_states) @@ -491,12 +521,17 @@ def custom_forward(*inputs): hidden_states = self.proj_out(hidden_states) # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + p_t = self.config.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 9cb042c9e80c..313b753443bb 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -442,8 +442,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = 
get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -452,7 +457,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -481,9 +486,9 @@ def __call__( self, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, - num_frames: int = 49, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -583,14 +588,13 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." - ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -640,7 +644,16 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) - # 5. Prepare latents. + # 5. 
Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + latent_channels = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -730,6 +743,8 @@ def __call__( progress_bar.update() if not output_type == "latent": + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 3655075bd519..4838335dc856 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -488,8 +488,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -498,7 +503,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -528,8 +533,8 @@ def __call__( prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, control_video: Optional[List[Image.Image]] = None, - height: int = 480, - width: int = 720, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -634,6 +639,13 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if control_video is not None and isinstance(control_video[0], Image.Image): + control_video = [control_video] + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) + num_videos_per_prompt = 1 # 1. Check inputs. 
Raise error if not correct @@ -660,9 +672,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - if control_video is not None and isinstance(control_video[0], Image.Image): - control_video = [control_video] - device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -688,9 +697,18 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) - # 5. Prepare latents. + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + latent_channels = self.transformer.config.in_channels // 2 - num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 783dae569bec..6fa8731dc99e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -367,6 +367,10 @@ def prepare_latents( width // self.vae_scale_factor_spatial, ) + # For CogVideoX1.5, the latent should add 1 for padding (Not use) + if self.transformer.config.patch_size_t is not None: + shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:] + image = image.unsqueeze(2) # [B, C, F, H, W] if isinstance(generator, list): @@ -377,7 +381,13 @@ def prepare_latents( image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] - image_latents = self.vae_scaling_factor_image * image_latents + + if not self.vae.config.invert_scale_latents: + image_latents = self.vae_scaling_factor_image * image_latents + else: + # This is awkward but required because the CogVideoX team forgot to multiply the + # scaling factor during training :) + image_latents = 1 / self.vae_scaling_factor_image * image_latents padding_shape = ( batch_size, @@ -386,9 +396,15 @@ def prepare_latents( height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) image_latents = torch.cat([image_latents, latent_padding], dim=1) + # Select the first frame along the second dimension + if self.transformer.config.patch_size_t is not None: + first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...] 
+ image_latents = torch.cat([first_frame, image_latents], dim=1) + if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -512,7 +528,6 @@ def unfuse_qkv_projections(self) -> None: self.transformer.unfuse_qkv_projections() self.fusing_transformer = False - # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings def _prepare_rotary_positional_embeddings( self, height: int, @@ -522,18 +537,38 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=num_frames, - ) + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + if p_t is None: + # CogVideoX 1.0 I2V + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 I2V + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device) @@ -562,8 +597,8 @@ def __call__( image: PipelineImageInput, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, + height: Optional[int] = None, + width: Optional[int] = None, num_frames: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, @@ -666,14 +701,13 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ - if num_frames > 49: - raise ValueError( - "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." 
- ) - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -726,6 +760,15 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=prompt_embeds.dtype ) @@ -754,6 +797,9 @@ def __call__( else None ) + # 8. Create ofs embeds if required + ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0) + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -778,6 +824,7 @@ def __call__( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, + ofs=ofs_emb, image_rotary_emb=image_rotary_emb, attention_kwargs=attention_kwargs, return_dict=False, @@ -823,6 +870,8 @@ def __call__( progress_bar.update() if not output_type == "latent": + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index e1e816eca16d..6af0ab4e115b 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -518,8 +518,13 @@ def _prepare_rotary_positional_embeddings( ) -> Tuple[torch.Tensor, torch.Tensor]: grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) - base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t or 1 + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + base_num_frames = (num_frames + p_t - 1) // p_t grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height @@ -528,7 +533,7 @@ def _prepare_rotary_positional_embeddings( embed_dim=self.transformer.config.attention_head_dim, crops_coords=grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=num_frames, + temporal_size=base_num_frames, ) freqs_cos = freqs_cos.to(device=device) @@ -558,8 +563,8 @@ def __call__( video: List[Image.Image] = None, 
prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, - height: int = 480, - width: int = 720, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, strength: float = 0.8, @@ -662,6 +667,10 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(video) if latents is None else latents.size(1) + num_videos_per_prompt = 1 # 1. Check inputs. Raise error if not correct @@ -717,6 +726,16 @@ def __call__( self._num_timesteps = len(timesteps) # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + if latents is None: video = self.video_processor.preprocess_video(video, height=height, width=width) video = video.to(device=device, dtype=prompt_embeds.dtype) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 1342577f0114..4c13b54e0620 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -76,6 +76,7 @@ def prepare_init_args_and_inputs_for_common(self): "sample_height": 8, "sample_frames": 8, "patch_size": 2, + "patch_size_t": None, "temporal_compression_ratio": 4, "max_text_seq_length": 8, } @@ -85,3 +86,63 @@ def prepare_init_args_and_inputs_for_common(self): def test_gradient_checkpointing_is_applied(self): expected_set = {"CogVideoXTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = CogVideoXTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (1, 4, 8, 8) + + @property + def output_shape(self): + return (1, 4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. 
+ "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "time_embed_dim": 2, + "text_embed_dim": 8, + "num_layers": 1, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "patch_size": 2, + "patch_size_t": 2, + "temporal_compression_ratio": 4, + "max_text_seq_length": 8, + "use_rotary_positional_embeddings": True, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogVideoXTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set)