Use torch in get_2d_sincos_pos_embed and get_3d_sincos_pos_embed (#10156)

* Use torch in get_2d_sincos_pos_embed

* Use torch in get_3d_sincos_pos_embed

* get_1d_sincos_pos_embed_from_grid in LatteTransformer3DModel

* deprecate

* move deprecate, make private
hlky authored Dec 13, 2024
1 parent 6bd30ba commit 6324340
Showing 3 changed files with 247 additions and 18 deletions.
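
For orientation, here is a minimal usage sketch (not part of the commit) of the updated 2D entry point. It assumes the function is imported from `diffusers.models.embeddings`; the embedding dimension, grid size, and device are arbitrary placeholders.

```python
import torch

from diffusers.models.embeddings import get_2d_sincos_pos_embed

# New torch path introduced by this commit: returns a torch.Tensor directly,
# optionally already placed on the requested device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pos_embed_pt = get_2d_sincos_pos_embed(
    embed_dim=64,
    grid_size=16,
    output_type="pt",
    device=device,
)
print(pos_embed_pt.shape)  # torch.Size([256, 64])

# Legacy NumPy path: the default output_type="np" still returns a NumPy array,
# but now emits a deprecation warning (flagged for 0.33.0), and the
# torch.from_numpy round-trip is no longer needed.
pos_embed_np = get_2d_sincos_pos_embed(embed_dim=64, grid_size=16)
pos_embed_legacy = torch.from_numpy(pos_embed_np).float()
```
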
257 changes: 243 additions & 14 deletions src/diffusers/models/embeddings.py
@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
device: Optional[torch.device] = None,
output_type: str = "np",
) -> torch.Tensor:
r"""
Creates 3D sinusoidal positional embeddings.
Args:
embed_dim (`int`):
The embedding dimension of inputs. It must be divisible by 16.
spatial_size (`int` or `Tuple[int, int]`):
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
        The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
Scale factor for temporal grid interpolation.
Returns:
`torch.Tensor`:
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
embed_dim]`.
"""
if output_type == "np":
return _get_3d_sincos_pos_embed_np(
embed_dim=embed_dim,
spatial_size=spatial_size,
temporal_size=temporal_size,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
)
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)

embed_dim_spatial = 3 * embed_dim // 4
embed_dim_temporal = embed_dim // 4

# 1. Spatial
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)

grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")

# 2. Temporal
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")

# 3. Concat
pos_embed_spatial = pos_embed_spatial[None, :, :]
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]

pos_embed_temporal = pos_embed_temporal[:, None, :]
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
spatial_size[0] * spatial_size[1], dim=1
) # [T, H*W, D // 4]

pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
return pos_embed
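
A hedged sketch of how the new torch path above might be called (sizes are arbitrary; the output shape follows the docstring):

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed

# embed_dim is split 3:1 between spatial and temporal components (48 + 16 here),
# so it must be divisible by 4; the docstring recommends a multiple of 16.
pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=16,      # square latent grid -> 16 * 16 = 256 spatial positions
    temporal_size=8,      # number of latent frames
    device=torch.device("cpu"),
    output_type="pt",
)
print(pos_embed.shape)  # torch.Size([8, 256, 64]) -> [T, H*W, D]
```
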


def _get_3d_sincos_pos_embed_np(
embed_dim: int,
spatial_size: Union[int, Tuple[int, int]],
temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
r"""
Creates 3D sinusoidal positional embeddings.
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
embed_dim]`.
"""
deprecation_message = (
"`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
if embed_dim % 4 != 0:
raise ValueError("`embed_dim` must be divisible by 4")
if isinstance(spatial_size, int):
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(


def get_2d_sincos_pos_embed(
embed_dim,
grid_size,
cls_token=False,
extra_tokens=0,
interpolation_scale=1.0,
base_size=16,
device: Optional[torch.device] = None,
output_type: str = "np",
):
"""
Creates 2D sinusoidal positional embeddings.
Args:
embed_dim (`int`):
The embedding dimension.
grid_size (`int`):
The size of the grid height and width.
cls_token (`bool`, defaults to `False`):
Whether or not to add a classification token.
extra_tokens (`int`, defaults to `0`):
The number of extra tokens to add.
interpolation_scale (`float`, defaults to `1.0`):
The scale of the interpolation.
Returns:
pos_embed (`torch.Tensor`):
Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
embed_dim]` if using cls_token
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_2d_sincos_pos_embed_np(
embed_dim=embed_dim,
grid_size=grid_size,
cls_token=cls_token,
extra_tokens=extra_tokens,
interpolation_scale=interpolation_scale,
base_size=base_size,
)
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)

grid_h = (
torch.arange(grid_size[0], device=device, dtype=torch.float32)
/ (grid_size[0] / base_size)
/ interpolation_scale
)
grid_w = (
torch.arange(grid_size[1], device=device, dtype=torch.float32)
/ (grid_size[1] / base_size)
/ interpolation_scale
)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
grid = torch.stack(grid, dim=0)

grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
if cls_token and extra_tokens > 0:
pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension.
grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
Returns:
`torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
"""
if output_type == "np":
deprecation_message = (
"`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_2d_sincos_pos_embed_from_grid_np(
embed_dim=embed_dim,
grid=grid,
)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)

emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
return emb
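
For reference, a small sketch (not from the commit) of building the grid by hand and feeding it to the function above, mirroring the grid construction in `get_2d_sincos_pos_embed` with the base-size/interpolation scaling left out:

```python
import torch

from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid

h, w, embed_dim = 8, 8, 32

# Same grid layout as get_2d_sincos_pos_embed (scaling factors omitted).
grid_h = torch.arange(h, dtype=torch.float32)
grid_w = torch.arange(w, dtype=torch.float32)
grid = torch.stack(torch.meshgrid(grid_w, grid_h, indexing="xy"), dim=0)
grid = grid.reshape([2, 1, h, w])

emb = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="pt")
print(emb.shape)  # torch.Size([64, 32]) -> (H * W, D)
```
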


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
"""
if output_type == "np":
deprecation_message = (
"`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")

omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = torch.outer(pos, omega) # (M, D/2), outer product

emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)

emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
return emb
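
To illustrate that the torch rewrite tracks the original NumPy math, a hedged comparison sketch (not a test shipped with this commit; it assumes both helpers remain importable from the module):

```python
import torch

from diffusers.models.embeddings import (
    get_1d_sincos_pos_embed_from_grid,
    get_1d_sincos_pos_embed_from_grid_np,
)

pos = torch.arange(16, dtype=torch.float32)

emb_pt = get_1d_sincos_pos_embed_from_grid(64, pos, output_type="pt")
emb_np = get_1d_sincos_pos_embed_from_grid_np(64, pos.numpy())

# Both compute sin(pos / 10000**(2i/D)) for the first D/2 columns and the
# matching cos terms for the rest, so the results should agree closely.
torch.testing.assert_close(emb_pt, torch.from_numpy(emb_np))
print(emb_pt.shape)  # torch.Size([16, 64])
```
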


def get_2d_sincos_pos_embed_np(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
grid = np.stack(grid, axis=0)

grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
r"""
This function generates 2D sinusoidal positional embeddings from a grid.
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
raise ValueError("embed_dim must be divisible by 2")

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
"""
This function generates 1D positional embeddings from a grid.
@@ -288,10 +503,14 @@ def __init__(
self.pos_embed = None
elif pos_embed_type == "sincos":
pos_embed = get_2d_sincos_pos_embed(
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
embed_dim,
grid_size,
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
output_type="pt",
)
persistent = True if pos_embed_max_size else False
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
else:
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

@@ -341,8 +560,10 @@ def forward(self, latent):
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
device=latent.device,
output_type="pt",
)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
pos_embed = pos_embed.float().unsqueeze(0)
else:
pos_embed = self.pos_embed

@@ -453,7 +674,9 @@ def __init__(
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
def _get_positional_embeddings(
self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=device,
output_type="pt",
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
@@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
pos_embedding = self._get_positional_embeddings(
height, width, pre_time_compression_frames, device=embeds.device
)
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding

@@ -552,9 +779,11 @@ def __init__(
# Linear projection for text embeddings
self.text_proj = nn.Linear(text_hidden_size, hidden_size)

pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
pos_embed = get_2d_sincos_pos_embed(
hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
)
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
self.register_buffer("pos_embed", pos_embed.float(), persistent=False)

def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -156,9 +156,9 @@ def __init__(

# define temporal positional embedding
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
inner_dim, torch.arange(0, video_length).unsqueeze(1)
inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
) # 1152 hidden size
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)

self.gradient_checkpointing = False

4 changes: 2 additions & 2 deletions src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -104,8 +104,8 @@ def __init__(

self.use_pos_embed = use_pos_embed
if self.use_pos_embed:
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt")
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False)

def forward(self, latent):
latent = self.proj(latent)
