Skip to content

Commit

Permalink
Make CogVideoX RoPE implementation consistent (#9963)
Browse files Browse the repository at this point in the history
* update cogvideox rope implementation

* apply suggestions from review
  • Loading branch information
a-r-r-o-w authored Nov 19, 2024
1 parent 7d0b9c4 commit 0583a8d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 40 deletions.
35 changes: 24 additions & 11 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,21 +444,34 @@ def _prepare_rotary_positional_embeddings(
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
p_t = self.transformer.config.patch_size_t

base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t

grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t

freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)

freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
Expand Down
35 changes: 24 additions & 11 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,21 +490,34 @@ def _prepare_rotary_positional_embeddings(
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
p_t = self.transformer.config.patch_size_t

base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t

grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t

freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)

freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
Expand Down
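13 changes: 6 additions & 7 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
Original file line number Diff line number Diff line change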
Expand Up @@ -528,6 +528,7 @@ def unfuse_qkv_projections(self) -> None:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False

# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
def _prepare_rotary_positional_embeddings(
self,
height: int,
Expand All @@ -541,11 +542,11 @@ def _prepare_rotary_positional_embeddings(
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t

if p_t is None:
# CogVideoX 1.0 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p

if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
Expand All @@ -556,9 +557,7 @@ def _prepare_rotary_positional_embeddings(
temporal_size=num_frames,
)
else:
# CogVideoX 1.5 I2V
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t

freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
Expand Down
35 changes: 24 additions & 11 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,21 +520,34 @@ def _prepare_rotary_positional_embeddings(
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)

p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t or 1
p_t = self.transformer.config.patch_size_t

base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
base_num_frames = (num_frames + p_t - 1) // p_t

grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t

freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
)

freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
Expand Down

0 comments on commit 0583a8d

Please sign in to comment.