Skip to content

Commit

Permalink
#15969: ttnn implementation of sd3_5 combined_time_steps_text_project…
Browse files Browse the repository at this point in the history
…ions sub_module
  • Loading branch information
vguduruTT committed Dec 23, 2024
1 parent 91e61c0 commit fdf6421
Show file tree
Hide file tree
Showing 12 changed files with 526 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn
from models.experimental.functional_stable_diffusion3_5.reference.pix_art_alpha_text_projection import (
PixArtAlphaTextProjection,
)
from models.experimental.functional_stable_diffusion3_5.reference.time_step_embeddings import TimestepEmbedding
from models.experimental.functional_stable_diffusion3_5.reference.time_steps import Timesteps


class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Sum of a timestep embedding and a projected pooled-text embedding.

    The timestep goes through ``Timesteps`` (256 channels, ``flip_sin_to_cos=True``,
    ``downscale_freq_shift=0``) followed by ``TimestepEmbedding``; the pooled text
    projection goes through ``PixArtAlphaTextProjection`` built with ``act_fn="silu"``.
    Both branches produce ``embedding_dim`` features and are added element-wise.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        """Build the three sub-modules.

        Args:
            embedding_dim: Output width of both the timestep and the text branch.
            pooled_projection_dim: Input width of the text branch.
        """
        super().__init__()

        # Registration order is kept stable so state_dict key order does not change.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        """Return the combined conditioning tensor.

        Args:
            timestep: Timestep tensor forwarded to ``self.time_proj``.
            pooled_projection: Pooled text tensor; its dtype is also used for the
                timestep branch.

        Returns:
            The element-wise sum of the timestep embedding and the text embedding.
        """
        sinusoid = self.time_proj(timestep)
        # Cast the sinusoidal features to the dtype of the text branch before the MLP.
        time_emb = self.timestep_embedder(sinusoid.to(dtype=pooled_projection.dtype))
        text_emb = self.text_embedder(pooled_projection)
        return time_emb + text_emb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn
import torch


class PixArtAlphaTextProjection(nn.Module):
    """Two-layer MLP (linear -> activation -> linear) that projects text embeddings.

    Args:
        in_features: Width of the incoming embedding.
        hidden_size: Width of the hidden layer.
        out_features: Width of the output; falls back to ``hidden_size`` when ``None``.
        act_fn: Activation between the two linear layers. ``"gelu_tanh"`` selects
            tanh-approximated GELU and ``"silu"`` selects SiLU.

    Raises:
        ValueError: If ``act_fn`` is not one of the supported names.
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        # Honour the requested activation instead of silently always using SiLU.
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        """Apply ``linear_1``, the activation and ``linear_2`` to ``caption``."""
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn
import torch
from typing import Optional


class TimestepEmbedding(nn.Module):
    """Two-layer MLP with a SiLU in between, applied to timestep features.

    Args:
        in_channels: Width of the incoming features.
        time_embed_dim: Width of the hidden layer and of the output.
        act_fn: Accepted for signature compatibility; not used here (SiLU is always used).
        out_dim: Accepted for signature compatibility; not used here.
        post_act_fn: Accepted for signature compatibility; not used here.
        cond_proj_dim: Accepted for signature compatibility; not used here.
        sample_proj_bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()
        use_bias = sample_proj_bias
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=use_bias)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=use_bias)
        self.act = nn.SiLU()

    def forward(self, sample, condition=None):
        """Return ``linear_2(silu(linear_1(sample)))``; ``condition`` is ignored."""
        return self.linear_2(self.act(self.linear_1(sample)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn
import torch
import math


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Args:
        timesteps: 1-D tensor with one entry per embedding row.
        embedding_dim: Number of output columns.
        flip_sin_to_cos: When true the cosine half comes first, otherwise the sine half.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the exponent
            denominator.
        scale: Multiplier applied to the angles before sin/cos.
        max_period: Its negated natural log scales the frequency exponents.

    Returns:
        A float32 tensor of shape ``(len(timesteps), embedding_dim)``; for odd
        ``embedding_dim`` the last column is zero padding.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    positions = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)

    # Outer product of timesteps and frequencies: one row per timestep.
    angles = scale * (timesteps[:, None].float() * torch.exp(log_freqs)[None, :])
    sin_half, cos_half = torch.sin(angles), torch.cos(angles)

    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = torch.cat(halves, dim=-1)

    if embedding_dim % 2 == 1:
        # Odd widths get a single trailing zero column.
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class Timesteps(nn.Module):
    """Parameter-free module wrapper around ``get_timestep_embedding``.

    Args:
        num_channels: Embedding width passed as ``embedding_dim``.
        flip_sin_to_cos: Forwarded unchanged to ``get_timestep_embedding``.
        downscale_freq_shift: Forwarded unchanged to ``get_timestep_embedding``.
        scale: Forwarded unchanged to ``get_timestep_embedding``.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()

        self.num_channels, self.flip_sin_to_cos = num_channels, flip_sin_to_cos
        self.downscale_freq_shift, self.scale = downscale_freq_shift, scale

    def forward(self, timesteps):
        """Return the sinusoidal embedding of ``timesteps`` using the stored settings."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_pix_art_alpha_text_projection import (
ttnn_PixArtAlphaTextProjection,
)
from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_step_embeddings import ttnn_TimestepEmbedding
from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_steps import ttnn_Timesteps


class ttnn_CombinedTimestepTextProjEmbeddings:
    """ttnn counterpart of the combined timestep + pooled-text embedding block.

    Composes ``ttnn_Timesteps`` (256 channels, ``flip_sin_to_cos=True``,
    ``downscale_freq_shift=0``), ``ttnn_TimestepEmbedding`` and
    ``ttnn_PixArtAlphaTextProjection`` and adds the two branch outputs with ``ttnn.add``.
    """

    def __init__(self, embedding_dim, pooled_projection_dim, parameters):
        """Store the sub-modules.

        Args:
            embedding_dim: Not used by this constructor.
            pooled_projection_dim: Not used by this constructor.
            parameters: Object exposing ``timestep_embedder`` and ``text_embedder``
                attributes, which are handed to the respective sub-modules.
        """
        self.time_proj = ttnn_Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = ttnn_TimestepEmbedding(parameters.timestep_embedder)
        self.text_embedder = ttnn_PixArtAlphaTextProjection(parameters.text_embedder)

    def __call__(self, timestep, pooled_projection, device):
        """Return ``ttnn.add`` of the timestep embedding and the text embedding.

        Args:
            timestep: Timestep tensor forwarded to ``self.time_proj``.
            pooled_projection: Pooled text tensor forwarded to ``self.text_embedder``.
            device: Device handle passed through to every sub-module call.
        """
        sinusoid = self.time_proj(timestep, device)
        time_emb = self.timestep_embedder(sinusoid, device)
        text_emb = self.text_embedder(pooled_projection, device)
        return ttnn.add(time_emb, text_emb)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


class ttnn_PixArtAlphaTextProjection:
    """ttnn text projection: ``ttnn.linear`` -> ``ttnn.silu`` -> ``ttnn.linear``."""

    def __init__(self, parameters):
        """Pick up weights and biases from ``parameters.linear_1`` and ``parameters.linear_2``."""
        first, second = parameters.linear_1, parameters.linear_2
        self.linear_1_w, self.linear_1_b = first.weight, first.bias
        self.linear_2_w, self.linear_2_b = second.weight, second.bias

    def __call__(self, caption, device):
        """Project ``caption``; both linear ops request ``ttnn.L1_MEMORY_CONFIG`` outputs.

        ``device`` is accepted but not used by this method.
        """
        l1 = ttnn.L1_MEMORY_CONFIG
        # A single name is rebound so intermediates are not kept alive longer than before.
        out = ttnn.linear(caption, self.linear_1_w, bias=self.linear_1_b, memory_config=l1)
        out = ttnn.silu(out)
        out = ttnn.linear(out, self.linear_2_w, bias=self.linear_2_b, memory_config=l1)
        return out
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


class ttnn_TimestepEmbedding:
    """ttnn timestep MLP: ``ttnn.linear`` -> ``ttnn.silu`` -> ``ttnn.linear``."""

    def __init__(self, parameters):
        """Pick up weights and biases from ``parameters.linear_1`` and ``parameters.linear_2``."""
        first, second = parameters.linear_1, parameters.linear_2
        self.linear_1_w, self.linear_1_b = first.weight, first.bias
        self.linear_2_w, self.linear_2_b = second.weight, second.bias

    def __call__(self, sample, device):
        """Embed ``sample``; both linear ops request ``ttnn.L1_MEMORY_CONFIG`` outputs.

        ``device`` is accepted but not used by this method.
        """
        l1 = ttnn.L1_MEMORY_CONFIG
        # A single name is rebound so intermediates are not kept alive longer than before.
        out = ttnn.linear(sample, self.linear_1_w, bias=self.linear_1_b, memory_config=l1)
        out = ttnn.silu(out)
        out = ttnn.linear(out, self.linear_2_w, bias=self.linear_2_b, memory_config=l1)
        return out
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import math
import torch


def get_timestep_embedding_tt(
    timesteps, embedding_dim, device, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
    """Build sinusoidal timestep embeddings with ttnn ops.

    Args:
        timesteps: ttnn tensor that is permuted with ``(1, 0)`` before the matmul,
            so it must be rank 2 at that point (presumably ``(1, batch)`` -- TODO confirm).
        embedding_dim: Output width; only ``embedding_dim // 2`` frequencies are built
            and no padding is added for odd values.
        device: Device on which ``ttnn.arange`` allocates the frequency positions.
        flip_sin_to_cos: When true the two halves of the concatenated result are swapped.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the exponent
            denominator.
        scale: Multiplier applied to the angles before sin/cos.
        max_period: Its negated natural log scales the frequency exponents.

    Returns:
        The concatenated sin/cos (or cos/sin) ttnn tensor.
    """
    l1 = ttnn.L1_MEMORY_CONFIG
    half_dim = embedding_dim // 2
    neg_log_period = -math.log(max_period)

    positions = ttnn.arange(start=0, end=half_dim, dtype=ttnn.float32, device=device, memory_config=l1)
    positions = ttnn.to_layout(positions, layout=ttnn.TILE_LAYOUT)
    scaled_positions = ttnn.multiply(positions, neg_log_period)
    log_freqs = ttnn.div(scaled_positions, (half_dim - downscale_freq_shift), round_mode=None)

    emb = ttnn.exp(log_freqs)
    # Two leading dims are dropped; presumably arange yields a 4-D tensor here -- TODO confirm.
    for _ in range(2):
        emb = ttnn.squeeze(emb, dim=0)

    timesteps_col = ttnn.permute(timesteps, (1, 0))
    # NOTE(review): looks like an outer product of timesteps and frequencies -- confirm shapes.
    emb = ttnn.matmul(timesteps_col, emb, memory_config=l1)
    emb = ttnn.multiply(emb, scale)

    sin_half = ttnn.sin(emb)
    cos_half = ttnn.cos(emb)
    emb = ttnn.concat([sin_half, cos_half], dim=-1, memory_config=l1)
    if flip_sin_to_cos:
        emb = ttnn.concat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1, memory_config=l1)

    return emb


class ttnn_Timesteps:
    """Callable wrapper that stores settings for ``get_timestep_embedding_tt``.

    Args:
        num_channels: Embedding width passed as ``embedding_dim``.
        flip_sin_to_cos: Forwarded unchanged to ``get_timestep_embedding_tt``.
        downscale_freq_shift: Forwarded unchanged to ``get_timestep_embedding_tt``.
        scale: Forwarded unchanged to ``get_timestep_embedding_tt``.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()

        self.num_channels, self.flip_sin_to_cos = num_channels, flip_sin_to_cos
        self.downscale_freq_shift, self.scale = downscale_freq_shift, scale

    def __call__(self, timestamps, device):
        """Return the embedding of ``timestamps`` computed on ``device``."""
        return get_timestep_embedding_tt(
            timestamps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            device=device,
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import torch.nn as nn
import ttnn
from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_combined_time_step_text_proj_embeddings import (
ttnn_CombinedTimestepTextProjEmbeddings as tt_module,
)
from models.experimental.functional_stable_diffusion3_5.reference.combined_time_step_text_proj_embeddings import (
CombinedTimestepTextProjEmbeddings,
)
from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn.model_preprocessing import (
preprocess_model_parameters,
preprocess_linear_weight,
preprocess_linear_bias,
)
from models.utility_functions import skip_for_grayskull


def create_custom_preprocessor(device):
    """Return a preprocessor callback for ``CombinedTimestepTextProjEmbeddings``.

    ``device`` is accepted but not used when building the callback.

    The callback returns a nested dict keyed by embedder name
    (``text_embedder``, ``timestep_embedder``), then layer name
    (``linear_1``, ``linear_2``), then ``weight`` / ``bias``; every entry is produced
    by ``preprocess_linear_weight`` / ``preprocess_linear_bias`` with ``ttnn.bfloat16``.
    For any other model type the callback returns an empty dict.
    """
    embedder_names = ("text_embedder", "timestep_embedder")
    layer_names = ("linear_1", "linear_2")

    def custom_preprocessor(model, name):
        if not isinstance(model, CombinedTimestepTextProjEmbeddings):
            return {}

        parameters = {}
        for embedder_name in embedder_names:
            embedder = getattr(model, embedder_name)
            parameters[embedder_name] = {}
            for layer_name in layer_names:
                layer = getattr(embedder, layer_name)
                # Weight is converted before bias, matching the key order of the result.
                parameters[embedder_name][layer_name] = {
                    "weight": preprocess_linear_weight(layer.weight, dtype=ttnn.bfloat16),
                    "bias": preprocess_linear_bias(layer.bias, dtype=ttnn.bfloat16),
                }
        return parameters

    return custom_preprocessor


@pytest.mark.parametrize(
    "init_inputs,fwd_inputs",
    [
        ((1536, 2048), (2, 2048)),
    ],
)
@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
def test_ttnn_combined_time_step_text_proj_embeddings(init_inputs, fwd_inputs, device, reset_seeds):
    """Compare the ttnn combined embedding module against the torch reference.

    ``init_inputs`` is ``(embedding_dim, pooled_projection_dim)`` and ``fwd_inputs`` is
    the shape of the random pooled projection. Both modules receive the same inputs and
    the outputs are checked with ``assert_with_pcc`` using 0.99.
    """
    embedding_dim, pooled_projection_dim = init_inputs

    reference = CombinedTimestepTextProjEmbeddings(
        embedding_dim=embedding_dim, pooled_projection_dim=pooled_projection_dim
    ).to(dtype=torch.bfloat16)
    parameters = preprocess_model_parameters(
        initialize_model=lambda: reference, device=device, custom_preprocessor=create_custom_preprocessor(device)
    )

    timesteps = torch.tensor([100, 100], dtype=torch.int32)
    pooled_projection = torch.randn(fwd_inputs, dtype=torch.bfloat16)

    # Both inputs are placed on the device with identical conversion settings.
    on_device = dict(
        dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
    )
    tt_timesteps = ttnn.from_torch(timesteps, **on_device)
    tt_pooled_projection = ttnn.from_torch(pooled_projection, **on_device)

    tt_sub_module = tt_module(
        embedding_dim=embedding_dim, pooled_projection_dim=pooled_projection_dim, parameters=parameters
    )
    tt_out = tt_sub_module(timestep=tt_timesteps, pooled_projection=tt_pooled_projection, device=device)
    torch_out = reference(timesteps, pooled_projection)

    assert_with_pcc(torch_out, ttnn.to_torch(tt_out), 0.99)
Loading

0 comments on commit fdf6421

Please sign in to comment.