From 0583a8d12ae903092f1e638cbc78cb03aab6ee9d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 19 Nov 2024 17:40:38 +0530 Subject: [PATCH] Make CogVideoX RoPE implementation consistent (#9963) * update cogvideox rope implementation * apply suggestions from review --- .../pipelines/cogvideo/pipeline_cogvideox.py | 35 +++++++++++++------ .../pipeline_cogvideox_fun_control.py | 35 +++++++++++++------ .../pipeline_cogvideox_image2video.py | 13 ++++--- .../pipeline_cogvideox_video2video.py | 35 +++++++++++++------ 4 files changed, 78 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 313b753443bb..27c2de384cb8 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -444,21 +444,34 @@ def _prepare_rotary_positional_embeddings( grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) p = self.transformer.config.patch_size - p_t = self.transformer.config.patch_size_t or 1 + p_t = self.transformer.config.patch_size_t base_size_width = self.transformer.config.sample_width // p base_size_height = self.transformer.config.sample_height // p - base_num_frames = (num_frames + p_t - 1) // p_t - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index 4838335dc856..1c93f360362d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -490,21 +490,34 @@ def _prepare_rotary_positional_embeddings( grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) p = self.transformer.config.patch_size - p_t = self.transformer.config.patch_size_t or 1 + p_t = self.transformer.config.patch_size_t base_size_width = self.transformer.config.sample_width // p base_size_height = self.transformer.config.sample_height // p - base_num_frames = (num_frames + p_t - 1) // p_t - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 6fa8731dc99e..b227f3b0565a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -528,6 +528,7 @@ def unfuse_qkv_projections(self) -> None: self.transformer.unfuse_qkv_projections() self.fusing_transformer = False + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings def _prepare_rotary_positional_embeddings( self, height: int, @@ -541,11 +542,11 @@ def _prepare_rotary_positional_embeddings( p = self.transformer.config.patch_size p_t = self.transformer.config.patch_size_t - if p_t is None: - # CogVideoX 1.0 I2V - base_size_width = self.transformer.config.sample_width // p - base_size_height = self.transformer.config.sample_height // p + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + if p_t is None: + # CogVideoX 1.0 grid_crops_coords = get_resize_crop_region_for_grid( (grid_height, grid_width), base_size_width, base_size_height ) @@ -556,9 +557,7 @@ def _prepare_rotary_positional_embeddings( temporal_size=num_frames, ) else: - # CogVideoX 1.5 I2V - base_size_width = self.transformer.config.sample_width // p - base_size_height = self.transformer.config.sample_height // p + # CogVideoX 1.5 base_num_frames = (num_frames + p_t - 1) // p_t freqs_cos, freqs_sin = get_3d_rotary_pos_embed( diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 6af0ab4e115b..315e03553500 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -520,21 +520,34 @@ def _prepare_rotary_positional_embeddings( grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) p = self.transformer.config.patch_size - p_t = self.transformer.config.patch_size_t or 1 + p_t = self.transformer.config.patch_size_t base_size_width = self.transformer.config.sample_width // p base_size_height = self.transformer.config.sample_height // p - base_num_frames = (num_frames + p_t - 1) // p_t - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - embed_dim=self.transformer.config.attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=base_num_frames, - ) + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) freqs_cos = freqs_cos.to(device=device) freqs_sin = freqs_sin.to(device=device)