diff --git a/models/experimental/functional_stable_diffusion3_5/reference/combined_time_step_text_proj_embeddings.py b/models/experimental/functional_stable_diffusion3_5/reference/combined_time_step_text_proj_embeddings.py
new file mode 100644
index 00000000000..8881e4710ca
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/reference/combined_time_step_text_proj_embeddings.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch.nn as nn
+from models.experimental.functional_stable_diffusion3_5.reference.pix_art_alpha_text_projection import (
+    PixArtAlphaTextProjection,
+)
+from models.experimental.functional_stable_diffusion3_5.reference.time_step_embeddings import TimestepEmbedding
+from models.experimental.functional_stable_diffusion3_5.reference.time_steps import Timesteps
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
diff --git a/models/experimental/functional_stable_diffusion3_5/reference/pix_art_alpha_text_projection.py b/models/experimental/functional_stable_diffusion3_5/reference/pix_art_alpha_text_projection.py
new file mode 100644
index 00000000000..14dcf0abd46
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/reference/pix_art_alpha_text_projection.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch.nn as nn
+
+
+class PixArtAlphaTextProjection(nn.Module):
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+        super().__init__()
+        if out_features is None:
+            out_features = hidden_size
+        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+        self.act_1 = nn.SiLU()  # act_fn is accepted for parity with diffusers, but this reference always applies SiLU
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+    def forward(self, caption):
+        hidden_states = self.linear_1(caption)
+        hidden_states = self.act_1(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
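Note: a minimal usage sketch of the reference module (shapes mirror the test below: embedding_dim=1536, pooled_projection_dim=2048; the variable names are illustrative, not part of the diff):

```python
import torch

from models.experimental.functional_stable_diffusion3_5.reference.combined_time_step_text_proj_embeddings import (
    CombinedTimestepTextProjEmbeddings,
)

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)
timestep = torch.tensor([100, 100], dtype=torch.int32)  # one timestep per batch entry
pooled_projection = torch.randn(2, 2048)  # pooled text-encoder output
conditioning = embedder(timestep, pooled_projection)  # sinusoidal MLP + projected text, shape (2, 1536)
```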
diff --git a/models/experimental/functional_stable_diffusion3_5/reference/time_step_embeddings.py b/models/experimental/functional_stable_diffusion3_5/reference/time_step_embeddings.py
new file mode 100644
index 00000000000..895b373b01c
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/reference/time_step_embeddings.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch.nn as nn
+from typing import Optional
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: Optional[int] = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+        sample_proj_bias=True,
+    ):
+        super().__init__()
+        # Simplified reference: act_fn, out_dim, post_act_fn and cond_proj_dim are kept
+        # for signature parity with diffusers; only linear -> silu -> linear is implemented.
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)
+        self.act = nn.SiLU()
+
+    def forward(self, sample, condition=None):
+        sample = self.linear_1(sample)
+        sample = self.act(sample)
+        sample = self.linear_2(sample)
+        return sample
diff --git a/models/experimental/functional_stable_diffusion3_5/reference/time_steps.py b/models/experimental/functional_stable_diffusion3_5/reference/time_steps.py
new file mode 100644
index 00000000000..c4f933a4369
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/reference/time_steps.py
@@ -0,0 +1,55 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import torch.nn as nn
+import torch
+import math
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+):
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    emb = scale * emb
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+        super().__init__()
+
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
+        )
+        return t_emb
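For frequency index j in [0, half_dim), get_timestep_embedding builds freq_j = max_period^(-j / (half_dim - downscale_freq_shift)) and concatenates sin(t * freq) with cos(t * freq), swapping the two halves when flip_sin_to_cos=True (as SD3.5 configures it). A small numeric sanity check, with dimensions chosen purely for illustration:

```python
import torch

from models.experimental.functional_stable_diffusion3_5.reference.time_steps import get_timestep_embedding

t = torch.tensor([100.0])
emb = get_timestep_embedding(t, embedding_dim=4, flip_sin_to_cos=True, downscale_freq_shift=0)
# half_dim=2, so the frequencies are 10000**0 = 1.0 and 10000**-0.5 = 0.01;
# after flipping, the layout is [cos(100), cos(1), sin(100), sin(1)].
vals = torch.tensor([100.0, 1.0])
expected = torch.cat([torch.cos(vals), torch.sin(vals)]).unsqueeze(0)
assert torch.allclose(emb, expected, atol=1e-5)
```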
diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_combined_time_step_text_proj_embeddings.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_combined_time_step_text_proj_embeddings.py
new file mode 100644
index 00000000000..adfb235e5b1
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_combined_time_step_text_proj_embeddings.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_pix_art_alpha_text_projection import (
+    ttnn_PixArtAlphaTextProjection,
+)
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_step_embeddings import ttnn_TimestepEmbedding
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_steps import ttnn_Timesteps
+
+
+class ttnn_CombinedTimestepTextProjEmbeddings:
+    def __init__(self, embedding_dim, pooled_projection_dim, parameters):
+        self.time_proj = ttnn_Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = ttnn_TimestepEmbedding(parameters.timestep_embedder)
+        self.text_embedder = ttnn_PixArtAlphaTextProjection(parameters.text_embedder)
+
+    def __call__(self, timestep, pooled_projection, device):
+        timesteps_proj = self.time_proj(timestep, device)
+        timesteps_emb = self.timestep_embedder(timesteps_proj, device)
+        pooled_projections = self.text_embedder(pooled_projection, device)
+        conditioning = ttnn.add(timesteps_emb, pooled_projections)
+        return conditioning
diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_pix_art_alpha_text_projection.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_pix_art_alpha_text_projection.py
new file mode 100644
index 00000000000..8f4c55d61ff
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_pix_art_alpha_text_projection.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+
+
+class ttnn_PixArtAlphaTextProjection:
+    def __init__(self, parameters):
+        self.linear_1_w = parameters.linear_1.weight
+        self.linear_1_b = parameters.linear_1.bias
+        self.linear_2_w = parameters.linear_2.weight
+        self.linear_2_b = parameters.linear_2.bias
+
+    def __call__(self, caption, device):
+        hidden_states = ttnn.linear(caption, self.linear_1_w, bias=self.linear_1_b, memory_config=ttnn.L1_MEMORY_CONFIG)
+        hidden_states = ttnn.silu(hidden_states)
+        hidden_states = ttnn.linear(
+            hidden_states, self.linear_2_w, bias=self.linear_2_b, memory_config=ttnn.L1_MEMORY_CONFIG
+        )
+        return hidden_states
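The ttnn wrappers are plain callables that read preprocessed weights off a parameters container with attribute access (parameters.linear_1.weight and so on), produced by preprocess_model_parameters in the tests below. A rough stand-in for the expected layout, using SimpleNamespace purely for illustration:

```python
from types import SimpleNamespace

# Hypothetical stand-in: the real container comes from preprocess_model_parameters,
# and the leaves are device-ready ttnn tensors from preprocess_linear_weight/bias.
def fake_linear_params(w1, b1, w2, b2):
    return SimpleNamespace(
        linear_1=SimpleNamespace(weight=w1, bias=b1),
        linear_2=SimpleNamespace(weight=w2, bias=b2),
    )
```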
diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_step_embeddings.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_step_embeddings.py
new file mode 100644
index 00000000000..0fd7747d80b
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_step_embeddings.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+
+
+class ttnn_TimestepEmbedding:
+    def __init__(self, parameters):
+        self.linear_1_w = parameters.linear_1.weight
+        self.linear_1_b = parameters.linear_1.bias
+        self.linear_2_w = parameters.linear_2.weight
+        self.linear_2_b = parameters.linear_2.bias
+
+    def __call__(self, sample, device):
+        sample = ttnn.linear(sample, self.linear_1_w, bias=self.linear_1_b, memory_config=ttnn.L1_MEMORY_CONFIG)
+        sample = ttnn.silu(sample)
+        sample = ttnn.linear(sample, self.linear_2_w, bias=self.linear_2_b, memory_config=ttnn.L1_MEMORY_CONFIG)
+        return sample
diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_steps.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_steps.py
new file mode 100644
index 00000000000..4b9285f7a7b
--- /dev/null
+++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_time_steps.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import math
+
+
+def get_timestep_embedding_tt(
+    timesteps, embedding_dim, device, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+):
+    half_dim = embedding_dim // 2
+    val1 = -math.log(max_period)
+    val2 = ttnn.arange(start=0, end=half_dim, dtype=ttnn.float32, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
+    val2 = ttnn.to_layout(val2, layout=ttnn.TILE_LAYOUT)
+    expon = ttnn.multiply(val2, val1)
+    exponent = ttnn.div(expon, (half_dim - downscale_freq_shift), round_mode=None)
+    emb = ttnn.exp(exponent)
+    emb = ttnn.squeeze(emb, dim=0)
+    emb = ttnn.squeeze(emb, dim=0)
+    # Outer product timesteps[:, None] * emb[None, :] expressed as a matmul.
+    timesteps_p = ttnn.permute(timesteps, (1, 0))
+    emb = ttnn.matmul(timesteps_p, emb, memory_config=ttnn.L1_MEMORY_CONFIG)
+    emb = ttnn.multiply(emb, scale)
+    emb_sin = ttnn.sin(emb)
+    emb_cos = ttnn.cos(emb)
+    emb = ttnn.concat([emb_sin, emb_cos], dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
+    if flip_sin_to_cos:
+        emb = ttnn.concat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1, memory_config=ttnn.L1_MEMORY_CONFIG)
+
+    return emb
+
+
+class ttnn_Timesteps:
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
+
+    def __call__(self, timesteps, device):
+        t_emb = get_timestep_embedding_tt(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
+            device=device,
+        )
+        return t_emb
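In get_timestep_embedding_tt, the reference's broadcasted product timesteps[:, None] * emb[None, :] is rewritten as a matmul of an (N, 1) column with a (1, half_dim) row, which maps more naturally onto tiled ttnn ops. The same identity in plain PyTorch (a sketch for clarity, not part of the diff):

```python
import torch

timesteps = torch.tensor([100.0, 100.0])  # (N,)
freqs = torch.rand(3)  # (half_dim,)

broadcasted = timesteps[:, None] * freqs[None, :]  # (N, half_dim)
as_matmul = timesteps.reshape(-1, 1) @ freqs.reshape(1, -1)  # same values via matmul
assert torch.allclose(broadcasted, as_matmul)
```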
diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_combined_time_step_text_proj_embeddings.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_combined_time_step_text_proj_embeddings.py
new file mode 100644
index 00000000000..6edd7fdf1ef
--- /dev/null
+++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_combined_time_step_text_proj_embeddings.py
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+import ttnn
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_combined_time_step_text_proj_embeddings import (
+    ttnn_CombinedTimestepTextProjEmbeddings as tt_module,
+)
+from models.experimental.functional_stable_diffusion3_5.reference.combined_time_step_text_proj_embeddings import (
+    CombinedTimestepTextProjEmbeddings,
+)
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from ttnn.model_preprocessing import (
+    preprocess_model_parameters,
+    preprocess_linear_weight,
+    preprocess_linear_bias,
+)
+from models.utility_functions import skip_for_grayskull
+
+
+def create_custom_preprocessor(device):
+    def custom_preprocessor(model, name):
+        parameters = {}
+        if isinstance(model, CombinedTimestepTextProjEmbeddings):
+            parameters["text_embedder"] = {}
+            parameters["text_embedder"]["linear_1"] = {}
+            parameters["text_embedder"]["linear_1"]["weight"] = preprocess_linear_weight(
+                model.text_embedder.linear_1.weight, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_1"]["bias"] = preprocess_linear_bias(
+                model.text_embedder.linear_1.bias, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_2"] = {}
+            parameters["text_embedder"]["linear_2"]["weight"] = preprocess_linear_weight(
+                model.text_embedder.linear_2.weight, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_2"]["bias"] = preprocess_linear_bias(
+                model.text_embedder.linear_2.bias, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"] = {}
+            parameters["timestep_embedder"]["linear_1"] = {}
+            parameters["timestep_embedder"]["linear_1"]["weight"] = preprocess_linear_weight(
+                model.timestep_embedder.linear_1.weight, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_1"]["bias"] = preprocess_linear_bias(
+                model.timestep_embedder.linear_1.bias, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_2"] = {}
+            parameters["timestep_embedder"]["linear_2"]["weight"] = preprocess_linear_weight(
+                model.timestep_embedder.linear_2.weight, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_2"]["bias"] = preprocess_linear_bias(
+                model.timestep_embedder.linear_2.bias, dtype=ttnn.bfloat16
+            )
+
+        return parameters
+
+    return custom_preprocessor
+
+
+@pytest.mark.parametrize(
+    "init_inputs,fwd_inputs",
+    [
+        ((1536, 2048), (2, 2048)),
+    ],
+)
+@skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+def test_ttnn_combined_time_step_text_proj_embeddings(init_inputs, fwd_inputs, device, reset_seeds):
+    torch_sub_module = CombinedTimestepTextProjEmbeddings(
+        embedding_dim=init_inputs[0], pooled_projection_dim=init_inputs[1]
+    ).to(dtype=torch.bfloat16)
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: torch_sub_module, device=device, custom_preprocessor=create_custom_preprocessor(device)
+    )
+    timesteps = torch.tensor([100, 100], dtype=torch.int32)
+    pooled_projection = torch.randn(fwd_inputs, dtype=torch.bfloat16)
+    tt_input_timesteps = ttnn.from_torch(
+        timesteps, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+    tt_input_pool_proj = ttnn.from_torch(
+        pooled_projection,
+        dtype=ttnn.bfloat16,
+        device=device,
+        layout=ttnn.TILE_LAYOUT,
+        memory_config=ttnn.L1_MEMORY_CONFIG,
+    )
+    tt_sub_module = tt_module(embedding_dim=init_inputs[0], pooled_projection_dim=init_inputs[1], parameters=parameters)
+    tt_out = tt_sub_module(timestep=tt_input_timesteps, pooled_projection=tt_input_pool_proj, device=device)
+    torch_out = torch_sub_module(timesteps, pooled_projection)
+    tt_out_in_torch = ttnn.to_torch(tt_out)
+    assert_with_pcc(torch_out, tt_out_in_torch, 0.99)
diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_pix_art_alpha_text_projection.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_pix_art_alpha_text_projection.py
new file mode 100644
index 00000000000..feaff15f5d4
--- /dev/null
+++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_pix_art_alpha_text_projection.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+import ttnn
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_pix_art_alpha_text_projection import (
+    ttnn_PixArtAlphaTextProjection as tt_module,
+)
+from models.experimental.functional_stable_diffusion3_5.reference.pix_art_alpha_text_projection import (
+    PixArtAlphaTextProjection,
+)
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from ttnn.model_preprocessing import (
+    preprocess_model_parameters,
+    preprocess_linear_weight,
+    preprocess_linear_bias,
+)
+from models.utility_functions import skip_for_grayskull
+
+
+def create_custom_preprocessor(device):
+    def custom_preprocessor(model, name):
+        parameters = {}
+        if isinstance(model, PixArtAlphaTextProjection):
+            parameters["text_embedder"] = {}
+            parameters["text_embedder"]["linear_1"] = {}
+            parameters["text_embedder"]["linear_1"]["weight"] = preprocess_linear_weight(
+                model.linear_1.weight, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_1"]["bias"] = preprocess_linear_bias(
+                model.linear_1.bias, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_2"] = {}
+            parameters["text_embedder"]["linear_2"]["weight"] = preprocess_linear_weight(
+                model.linear_2.weight, dtype=ttnn.bfloat16
+            )
+            parameters["text_embedder"]["linear_2"]["bias"] = preprocess_linear_bias(
+                model.linear_2.bias, dtype=ttnn.bfloat16
+            )
+        return parameters
+
+    return custom_preprocessor
+
+
+@pytest.mark.parametrize(
+    "in_features,hidden_size,out_features,fwd_input",
+    [
+        (2048, 1536, None, (2, 2048)),
+    ],
+)
+@skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+def test_ttnn_pix_art_alpha_text_projection(in_features, hidden_size, out_features, fwd_input, device, reset_seeds):
+    torch_sub_module = PixArtAlphaTextProjection(
+        in_features=in_features, hidden_size=hidden_size, out_features=out_features
+    ).to(dtype=torch.bfloat16)
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: torch_sub_module, device=device, custom_preprocessor=create_custom_preprocessor(device)
+    )
+    torch_input = torch.randn(fwd_input, dtype=torch.bfloat16)
+    tt_input = ttnn.from_torch(
+        torch_input, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+    tt_sub_module = tt_module(parameters.text_embedder)
+    tt_out = tt_sub_module(tt_input, device=device)
+    torch_out = torch_sub_module(torch_input)
+    tt_out_in_torch = ttnn.to_torch(tt_out)
+    assert_with_pcc(torch_out, tt_out_in_torch, 0.99)
diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_step_embeddings.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_step_embeddings.py
new file mode 100644
index 00000000000..9fb5a42367b
--- /dev/null
+++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_step_embeddings.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+import ttnn
+from models.experimental.functional_stable_diffusion3_5.reference.time_step_embeddings import TimestepEmbedding
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_step_embeddings import (
+    ttnn_TimestepEmbedding as tt_module,
+)
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from ttnn.model_preprocessing import (
+    preprocess_model_parameters,
+    preprocess_linear_weight,
+    preprocess_linear_bias,
+)
+from models.utility_functions import skip_for_grayskull
+
+
+def create_custom_preprocessor(device):
+    def custom_preprocessor(model, name):
+        parameters = {}
+        if isinstance(model, TimestepEmbedding):
+            parameters["timestep_embedder"] = {}
+            parameters["timestep_embedder"]["linear_1"] = {}
+            parameters["timestep_embedder"]["linear_1"]["weight"] = preprocess_linear_weight(
+                model.linear_1.weight, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_1"]["bias"] = preprocess_linear_bias(
+                model.linear_1.bias, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_2"] = {}
+            parameters["timestep_embedder"]["linear_2"]["weight"] = preprocess_linear_weight(
+                model.linear_2.weight, dtype=ttnn.bfloat16
+            )
+            parameters["timestep_embedder"]["linear_2"]["bias"] = preprocess_linear_bias(
+                model.linear_2.bias, dtype=ttnn.bfloat16
+            )
+        return parameters
+
+    return custom_preprocessor
+
+
+@pytest.mark.parametrize(
+    "init_inputs,fwd_input",
+    [
+        ((256, 1536, "silu", None, None, None, True), (2, 256)),
+    ],
+)
+@skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+def test_ttnn_time_step_embeddings(init_inputs, fwd_input, device, reset_seeds):
+    torch_sub_module = TimestepEmbedding(
+        init_inputs[0], init_inputs[1], init_inputs[2], init_inputs[3], init_inputs[4], init_inputs[5], init_inputs[6]
+    ).to(dtype=torch.bfloat16)
+    parameters = preprocess_model_parameters(
+        initialize_model=lambda: torch_sub_module, device=device, custom_preprocessor=create_custom_preprocessor(device)
+    )
+    torch_input = torch.randn(fwd_input, dtype=torch.bfloat16)
+    tt_input = ttnn.from_torch(
+        torch_input, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+    tt_sub_module = tt_module(parameters.timestep_embedder)
+    tt_out = tt_sub_module(tt_input, device)
+    torch_out = torch_sub_module(torch_input)
+    tt_out_in_torch = ttnn.to_torch(tt_out)
+    assert_with_pcc(torch_out, tt_out_in_torch, 0.99)
diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_steps.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_steps.py
new file mode 100644
index 00000000000..59fdcf747de
--- /dev/null
+++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_time_steps.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+
+# SPDX-License-Identifier: Apache-2.0
+
+import ttnn
+import pytest
+import torch
+from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_time_steps import ttnn_Timesteps as tt_module
+from models.experimental.functional_stable_diffusion3_5.reference.time_steps import Timesteps
+from tests.ttnn.utils_for_testing import assert_with_pcc
+from models.utility_functions import skip_for_grayskull
+
+
+@pytest.mark.parametrize(
+    "init_inputs",
+    [
+        (256, True, 0, 1),
+    ],
+)
+@skip_for_grayskull()
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
+def test_ttnn_time_steps(init_inputs, device, reset_seeds):
+    torch_sub_module = Timesteps(init_inputs[0], init_inputs[1], init_inputs[2], init_inputs[3])
+    # Timesteps has no learnable parameters, so no parameter preprocessing is needed.
+    timesteps = torch.tensor([100, 100], dtype=torch.int32)
+    tt_input = ttnn.from_torch(
+        timesteps, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+    tt_input = ttnn.squeeze(tt_input, dim=0)
+    tt_sub_module = tt_module(init_inputs[0], init_inputs[1], init_inputs[2], init_inputs[3])
+    tt_out = tt_sub_module(tt_input, device)
+    torch_out = torch_sub_module(timesteps)
+    tt_out_in_torch = ttnn.to_torch(tt_out)
+    assert_with_pcc(torch_out, tt_out_in_torch, 0.96)
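These tests compare against the bfloat16 PyTorch reference using assert_with_pcc (Pearson correlation coefficient), with 0.99 for the MLP paths and a relaxed 0.96 for the pure-sinusoid output. Roughly, the metric flattens both tensors and checks their linear correlation; a minimal sketch of the idea (the real helper lives in tests/ttnn/utils_for_testing.py):

```python
import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Pearson correlation of the flattened tensors: 1.0 means perfectly linearly
    # related, so a PCC threshold tolerates small elementwise bfloat16 noise.
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()
```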