diff --git a/models/experimental/functional_stable_diffusion3_5/reference/attention.py b/models/experimental/functional_stable_diffusion3_5/reference/attention.py new file mode 100644 index 000000000000..3ba10f2a17e3 --- /dev/null +++ b/models/experimental/functional_stable_diffusion3_5/reference/attention.py @@ -0,0 +1,202 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Callable, List, Optional +import torch +import torch.nn.functional as F +from torch import nn + + +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + ): + super().__init__() + from diffusers.models.normalization import RMSNorm + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self._from_deprecated_attn_block = _from_deprecated_attn_block + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + if cross_attention_norm is None: + self.norm_cross = None + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + if not self.only_cross_attention: + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + self.processor = processor + + def __call__( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class JointAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + batch_size = hidden_states.shape[0] + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + hidden_states = attn.to_out[0](hidden_states) + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states diff --git a/models/experimental/functional_stable_diffusion3_5/reference/rms_norm.py b/models/experimental/functional_stable_diffusion3_5/reference/rms_norm.py new file mode 100644 index 000000000000..b58e7de16f18 --- /dev/null +++ b/models/experimental/functional_stable_diffusion3_5/reference/rms_norm.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch.nn as nn +import numbers +import torch + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + self.eps = eps + if isinstance(dim, numbers.Integral): + dim = (dim,) + self.dim = torch.Size(dim) + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + val = torch.rsqrt(variance + self.eps) + hidden_states = hidden_states * val + if self.weight is not None: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + return hidden_states diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_attention.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_attention.py new file mode 100644 index 000000000000..7e7c205bdd7c --- /dev/null +++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_attention.py @@ -0,0 +1,262 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +from typing import Optional +from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_rms_norm import ttnn_RMSNorm +import torch.nn.functional as F + +SDPAProgramConfig = ttnn._ttnn.operations.transformer.SDPAProgramConfig + + +class ttnn_Attention: + def __init__( + self, + parameters, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["ttnn_JointAttnProcessor2_0"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + elementwise_affine: bool = True, + ): + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + if qk_norm is None: + self.norm_q = None + self.norm_k = None + else: + self.norm_q = ttnn_RMSNorm(dim=dim_head, eps=eps, elementwise_affine=True, parameters=parameters.norm_q) + self.norm_k = ttnn_RMSNorm(dim=dim_head, eps=eps, elementwise_affine=True, parameters=parameters.norm_k) + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "rms_norm": + self.norm_added_q = ttnn_RMSNorm( + dim=dim_head, eps=eps, elementwise_affine=True, parameters=parameters.norm_added_q + ) + self.norm_added_k = ttnn_RMSNorm( + dim=dim_head, eps=eps, elementwise_affine=True, parameters=parameters.norm_added_k + ) + else: + self.norm_added_q = None + self.norm_added_k = None + self.to_q_weight = parameters.to_q.weight + self.to_q_bias = parameters.to_q.bias + self.to_k_weight = parameters.to_k.weight + self.to_k_bias = parameters.to_k.bias + self.to_v_weight = parameters.to_v.weight + self.to_v_bias = parameters.to_v.bias + if self.added_kv_proj_dim is not None: + self.add_k_proj_weight = parameters.add_k_proj.weight + if added_proj_bias: + self.add_k_proj_bias = parameters.add_k_proj.bias + self.add_v_proj_weight = parameters.add_v_proj.weight + if added_proj_bias: + self.add_v_proj_bias = parameters.add_v_proj.bias + if self.context_pre_only is not None: + self.add_q_proj_weight = parameters.add_q_proj.weight + if added_proj_bias: + self.add_q_proj_bias = parameters.add_q_proj.bias + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out_weight = parameters.to_add_out.weight + self.to_add_out_bias = parameters.to_add_out.bias + if not self.pre_only: + self.to_out_weight = parameters.to_out[0].weight + self.to_out_bias = parameters.to_out[0].bias + self.processor = processor + + def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, device=None): + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, device + ) + return hidden_states, encoder_hidden_states + else: + hidden_states = self.processor(self, hidden_states, encoder_hidden_states, attention_mask, device) + return hidden_states + + +class ttnn_JointAttnProcessor2_0: + def __init__(self): + pass + + def __call__(self, ttnn_Attention, hidden_states, encoder_hidden_states, attention_mask, device): + print(hidden_states.memory_config()) + batch_size = hidden_states.shape[0] + residual = hidden_states + query = ttnn.linear( + hidden_states, + ttnn_Attention.to_q_weight, + bias=ttnn_Attention.to_q_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(query.memory_config()) + key = ttnn.linear( + hidden_states, + ttnn_Attention.to_k_weight, + bias=ttnn_Attention.to_k_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(key.memory_config()) + value = ttnn.linear( + hidden_states, + ttnn_Attention.to_v_weight, + bias=ttnn_Attention.to_v_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(value.memory_config()) + inner_dim = key.shape[-1] + head_dim = inner_dim // ttnn_Attention.heads + query = ttnn.reshape(query, (batch_size, query.shape[1], ttnn_Attention.heads, head_dim)) + print(query.memory_config()) + query = ttnn.permute(query, (0, 2, 1, 3)) + print(query.memory_config()) + key = ttnn.reshape(key, (batch_size, key.shape[1], ttnn_Attention.heads, head_dim)) + print(key.memory_config()) + key = ttnn.permute(key, (0, 2, 1, 3)) + value = ttnn.reshape(value, (batch_size, value.shape[1], ttnn_Attention.heads, head_dim)) + value = ttnn.permute(value, (0, 2, 1, 3)) + print(value.memory_config()) + if ttnn_Attention.norm_q is not None: + query = ttnn_Attention.norm_q(query, device) + print(query.memory_config()) + if ttnn_Attention.norm_k is not None: + key = ttnn_Attention.norm_k(key, device) + print(key.memory_config()) + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = ttnn.linear( + encoder_hidden_states, + ttnn_Attention.add_q_proj_weight, + bias=ttnn_Attention.add_q_proj_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(encoder_hidden_states_query_proj.memory_config()) + encoder_hidden_states_key_proj = ttnn.linear( + encoder_hidden_states, + ttnn_Attention.add_k_proj_weight, + bias=ttnn_Attention.add_k_proj_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(encoder_hidden_states_key_proj.memory_config()) + encoder_hidden_states_value_proj = ttnn.linear( + encoder_hidden_states, + ttnn_Attention.add_v_proj_weight, + bias=ttnn_Attention.add_v_proj_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(encoder_hidden_states_value_proj.memory_config()) + encoder_hidden_states_query_proj = ttnn.reshape( + encoder_hidden_states_query_proj, + (batch_size, encoder_hidden_states_query_proj.shape[1], ttnn_Attention.heads, head_dim), + ) + print(encoder_hidden_states_query_proj.memory_config()) + encoder_hidden_states_query_proj = ttnn.permute(encoder_hidden_states_query_proj, (0, 2, 1, 3)) + encoder_hidden_states_key_proj = ttnn.reshape( + encoder_hidden_states_key_proj, + (batch_size, encoder_hidden_states_key_proj.shape[1], ttnn_Attention.heads, head_dim), + ) + print(encoder_hidden_states_key_proj.memory_config()) + encoder_hidden_states_key_proj = ttnn.permute(encoder_hidden_states_key_proj, (0, 2, 1, 3)) + encoder_hidden_states_value_proj = ttnn.reshape( + encoder_hidden_states_value_proj, + (batch_size, encoder_hidden_states_value_proj.shape[1], ttnn_Attention.heads, head_dim), + ) + print(encoder_hidden_states_key_proj.memory_config()) + encoder_hidden_states_value_proj = ttnn.permute(encoder_hidden_states_value_proj, (0, 2, 1, 3)) + if ttnn_Attention.norm_added_q is not None: + encoder_hidden_states_query_proj = ttnn_Attention.norm_added_q( + encoder_hidden_states_query_proj, device=device + ) + print(encoder_hidden_states_query_proj.memory_config()) + if ttnn_Attention.norm_added_k is not None: + encoder_hidden_states_key_proj = ttnn_Attention.norm_added_k( + encoder_hidden_states_key_proj, device=device + ) + print(encoder_hidden_states_key_proj.memory_config()) + query = ttnn.concat([query, encoder_hidden_states_query_proj], dim=2, memory_config=ttnn.DRAM_MEMORY_CONFIG) + key = ttnn.concat([key, encoder_hidden_states_key_proj], dim=2, memory_config=ttnn.DRAM_MEMORY_CONFIG) + value = ttnn.concat([value, encoder_hidden_states_value_proj], dim=2, memory_config=ttnn.DRAM_MEMORY_CONFIG) + else: + query = ttnn.to_memory_config(query, memory_config=ttnn.DRAM_MEMORY_CONFIG) + key = ttnn.to_memory_config(key, memory_config=ttnn.DRAM_MEMORY_CONFIG) + value = ttnn.to_memory_config(value, memory_config=ttnn.DRAM_MEMORY_CONFIG) + if encoder_hidden_states is None: + q_size = 32 + else: + q_size = 1184 + program_config = ttnn.SDPAProgramConfig( + compute_with_storage_grid_size=device.compute_with_storage_grid_size(), + q_chunk_size=q_size, + k_chunk_size=q_size, + exp_approx_mode=False, + ) + compute_kernel_config = ttnn.WormholeComputeKernelConfig( + math_fidelity=ttnn.MathFidelity.HiFi4, + math_approx_mode=False, + fp32_dest_acc_en=False, + packer_l1_acc=False, + ) + hidden_states = ttnn.transformer.scaled_dot_product_attention( + query, + key, + value, + is_causal=False, + compute_kernel_config=compute_kernel_config, + program_config=program_config, + ) + hidden_states = ttnn.to_memory_config(hidden_states, memory_config=ttnn.L1_MEMORY_CONFIG) + print(hidden_states.memory_config()) + hidden_states = ttnn.permute(hidden_states, (0, 2, 1, 3)) + hidden_states = ttnn.reshape(hidden_states, (batch_size, -1, ttnn_Attention.heads * head_dim)) + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1], :], + hidden_states[:, residual.shape[1] :, :], + ) + print(hidden_states.memory_config(), encoder_hidden_states.memory_config()) + if not ttnn_Attention.context_pre_only: + encoder_hidden_states = ttnn.linear( + encoder_hidden_states, + ttnn_Attention.to_add_out_weight, + bias=ttnn_Attention.to_add_out_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(encoder_hidden_states.memory_config()) + hidden_states = ttnn.linear( + hidden_states, + ttnn_Attention.to_out_weight, + bias=ttnn_Attention.to_out_bias, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + print(hidden_states.memory_config()) + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states diff --git a/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_rms_norm.py b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_rms_norm.py new file mode 100644 index 000000000000..b9085ff8fb86 --- /dev/null +++ b/models/experimental/functional_stable_diffusion3_5/ttnn/ttnn_rms_norm.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import ttnn +import torch + + +class ttnn_RMSNorm: + def __init__(self, dim, eps, elementwise_affine, parameters): + self.eps = eps + if elementwise_affine: + self.weight = parameters.weight + else: + self.weight = None + + def __call__(self, hidden_states, device): + print("hidden_states config before pow", hidden_states.memory_config()) + variance = ttnn.pow(hidden_states, 2) + print("hidden_states config after pow", hidden_states.memory_config()) + variance = ttnn.mean(variance, dim=-1) + print("variance config after mean", variance.memory_config()) + variance = ttnn.add(variance, self.eps) + print("variance config after add", variance.memory_config()) + variance = ttnn.rsqrt(variance) + print("variance config after rsqrt", variance.memory_config()) + hidden_states = ttnn.multiply(hidden_states, variance) + print("hidden_states config after mul", hidden_states.memory_config()) + if self.weight is not None: + hidden_states = ttnn.multiply(hidden_states, self.weight) + print("hidden_states config after mean", hidden_states.memory_config()) + return hidden_states diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_attention.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_attention.py new file mode 100644 index 000000000000..e977c4232a45 --- /dev/null +++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_attention.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import ttnn +from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_attention import ( + ttnn_Attention as tt_module, + ttnn_JointAttnProcessor2_0, +) +from models.experimental.functional_stable_diffusion3_5.reference.attention import Attention, JointAttnProcessor2_0 +from models.experimental.functional_stable_diffusion3_5.reference.rms_norm import RMSNorm +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import ( + preprocess_model_parameters, + preprocess_linear_weight, + preprocess_linear_bias, +) +from models.utility_functions import skip_for_grayskull + + +def create_custom_preprocessor(device): + def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, Attention): + parameters["norm_q"] = {} + parameters["norm_q"]["weight"] = ttnn.from_torch( + model.norm_q.weight.unsqueeze(0).unsqueeze(0).unsqueeze(0), + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + parameters["norm_k"] = {} + parameters["norm_k"]["weight"] = ttnn.from_torch( + model.norm_k.weight.unsqueeze(0).unsqueeze(0).unsqueeze(0), + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + parameters["to_q"] = {} + parameters["to_q"]["weight"] = preprocess_linear_weight(model.to_q.weight, dtype=ttnn.bfloat16) + parameters["to_q"]["bias"] = preprocess_linear_bias(model.to_q.bias, dtype=ttnn.bfloat16) + parameters["to_k"] = {} + parameters["to_k"]["weight"] = preprocess_linear_weight(model.to_k.weight, dtype=ttnn.bfloat16) + parameters["to_k"]["bias"] = preprocess_linear_bias(model.to_k.bias, dtype=ttnn.bfloat16) + parameters["to_v"] = {} + parameters["to_v"]["weight"] = preprocess_linear_weight(model.to_v.weight, dtype=ttnn.bfloat16) + parameters["to_v"]["bias"] = preprocess_linear_bias(model.to_v.bias, dtype=ttnn.bfloat16) + if hasattr(model, "add_k_proj"): + parameters["add_k_proj"] = {} + parameters["add_k_proj"]["weight"] = preprocess_linear_weight( + model.add_k_proj.weight, dtype=ttnn.bfloat16 + ) + parameters["add_k_proj"]["bias"] = preprocess_linear_bias(model.add_k_proj.bias, dtype=ttnn.bfloat16) + if hasattr(model, "add_v_proj"): + parameters["add_v_proj"] = {} + parameters["add_v_proj"]["weight"] = preprocess_linear_weight( + model.add_v_proj.weight, dtype=ttnn.bfloat16 + ) + parameters["add_v_proj"]["bias"] = preprocess_linear_bias(model.add_v_proj.bias, dtype=ttnn.bfloat16) + if hasattr(model, "add_q_proj"): + parameters["add_q_proj"] = {} + parameters["add_q_proj"]["weight"] = preprocess_linear_weight( + model.add_q_proj.weight, dtype=ttnn.bfloat16 + ) + parameters["add_q_proj"]["bias"] = preprocess_linear_bias(model.add_q_proj.bias, dtype=ttnn.bfloat16) + parameters["to_out"] = {} + parameters["to_out"][0] = {} + parameters["to_out"][0]["weight"] = preprocess_linear_weight(model.to_out[0].weight, dtype=ttnn.bfloat16) + parameters["to_out"][0]["bias"] = preprocess_linear_bias(model.to_out[0].bias, dtype=ttnn.bfloat16) + if hasattr(model, "to_add_out"): + parameters["to_add_out"] = {} + parameters["to_add_out"]["weight"] = preprocess_linear_weight( + model.to_add_out.weight, dtype=ttnn.bfloat16 + ) + parameters["to_add_out"]["bias"] = preprocess_linear_bias(model.to_add_out.bias, dtype=ttnn.bfloat16) + if model.norm_added_q != None: + parameters["norm_added_q"] = {} + parameters["norm_added_q"]["weight"] = ttnn.from_torch( + model.norm_added_q.weight.unsqueeze(0).unsqueeze(0).unsqueeze(0), + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + if model.norm_added_k != None: + parameters["norm_added_k"] = {} + parameters["norm_added_k"]["weight"] = ttnn.from_torch( + model.norm_added_k.weight.unsqueeze(0).unsqueeze(0).unsqueeze(0), + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "attn_inputs, hidden_states, attention_mask, encoder_hidden_states", + [ + # ( + # { #512x512 + # "query_dim": 1536, + # "cross_attention_dim": None, + # "heads": 24, + # "kv_heads": None, + # "dim_head": 64, + # "dropout": 0.0, + # "bias": True, + # "upcast_attention": False, + # "upcast_softmax": False, + # "cross_attention_norm": None, + # "cross_attention_norm_num_groups": 32, + # "qk_norm": "rms_norm", + # "added_kv_proj_dim": 1536, + # "added_proj_bias": True, + # "norm_num_groups": None, + # "spatial_norm_dim": None, + # "out_bias": True, + # "scale_qk": True, + # "only_cross_attention": False, + # "eps": 1e-06, + # "rescale_output_factor": 1.0, + # "residual_connection": False, + # "_from_deprecated_attn_block": False, + # "out_dim": 1536, + # "context_pre_only": False, + # "pre_only": False, + # "elementwise_affine": True, + # }, + # torch.randn([2, 1024, 1536], dtype=torch.bfloat16), + # None, + # torch.randn([2, 154, 1536], dtype=torch.bfloat16), + # ), + # ( + # { #512x512 + # "query_dim": 1536, + # "cross_attention_dim": None, + # "heads": 24, + # "kv_heads": None, + # "dim_head": 64, + # "dropout": 0.0, + # "bias": True, + # "upcast_attention": False, + # "upcast_softmax": False, + # "cross_attention_norm": None, + # "cross_attention_norm_num_groups": 32, + # "qk_norm": "rms_norm", + # "added_kv_proj_dim": None, + # "added_proj_bias": True, + # "norm_num_groups": None, + # "spatial_norm_dim": None, + # "out_bias": True, + # "scale_qk": True, + # "only_cross_attention": False, + # "eps": 1e-06, + # "rescale_output_factor": 1.0, + # "residual_connection": False, + # "_from_deprecated_attn_block": False, + # "out_dim": 1536, + # "context_pre_only": None, + # "pre_only": False, + # "elementwise_affine": True, + # }, + # torch.randn([2, 1024, 1536], dtype=torch.bfloat16), + # None, + # None, + # ), + ( + { # 1024x1024 + "query_dim": 1536, + "cross_attention_dim": None, + "heads": 24, + "kv_heads": None, + "dim_head": 64, + "dropout": 0.0, + "bias": True, + "upcast_attention": False, + "upcast_softmax": False, + "cross_attention_norm": None, + "cross_attention_norm_num_groups": 32, + "qk_norm": "rms_norm", + "added_kv_proj_dim": 1536, + "added_proj_bias": True, + "norm_num_groups": None, + "spatial_norm_dim": None, + "out_bias": True, + "scale_qk": True, + "only_cross_attention": False, + "eps": 1e-06, + "rescale_output_factor": 1.0, + "residual_connection": False, + "_from_deprecated_attn_block": False, + "out_dim": 1536, + "context_pre_only": False, + "pre_only": False, + "elementwise_affine": True, + }, + torch.randn([2, 4096, 1536], dtype=torch.bfloat16), + None, + torch.randn([2, 333, 1536], dtype=torch.bfloat16), + ), + ( + { # 1024x1024 + "query_dim": 1536, + "cross_attention_dim": None, + "heads": 24, + "kv_heads": None, + "dim_head": 64, + "dropout": 0.0, + "bias": True, + "upcast_attention": False, + "upcast_softmax": False, + "cross_attention_norm": None, + "cross_attention_norm_num_groups": 32, + "qk_norm": "rms_norm", + "added_kv_proj_dim": None, + "added_proj_bias": True, + "norm_num_groups": None, + "spatial_norm_dim": None, + "out_bias": True, + "scale_qk": True, + "only_cross_attention": False, + "eps": 1e-06, + "rescale_output_factor": 1.0, + "residual_connection": False, + "_from_deprecated_attn_block": False, + "out_dim": 1536, + "context_pre_only": None, + "pre_only": False, + "elementwise_affine": True, + }, + torch.randn([2, 4096, 1536], dtype=torch.bfloat16), + None, + None, + ), + ], +) +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_ttnn_attention(attn_inputs, device, hidden_states, attention_mask, encoder_hidden_states, reset_seeds): + torch_sub_module = Attention(**attn_inputs, processor=JointAttnProcessor2_0()).to(dtype=torch.bfloat16) + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_sub_module, device=device, custom_preprocessor=create_custom_preprocessor(device) + ) + tt_input_hidden_states = ttnn.from_torch( + hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG + ) + if encoder_hidden_states is not None: + tt_input_encoder_hidden_states = ttnn.from_torch( + encoder_hidden_states, + dtype=ttnn.bfloat16, + device=device, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) + else: + tt_input_encoder_hidden_states = None + tt_sub_module = tt_module(**attn_inputs, processor=ttnn_JointAttnProcessor2_0(), parameters=parameters) + if encoder_hidden_states is not None: + torch_out_1, torch_out_2 = torch_sub_module(hidden_states, encoder_hidden_states) + tt_out_1, tt_out_2 = tt_sub_module( + tt_input_hidden_states, tt_input_encoder_hidden_states, attention_mask, device + ) + else: + torch_out_1 = torch_sub_module(hidden_states, encoder_hidden_states) + tt_out_1 = tt_sub_module(tt_input_hidden_states, tt_input_encoder_hidden_states, attention_mask, device) + tt_out_in_torch_1 = ttnn.to_torch(tt_out_1) + if encoder_hidden_states is not None: + tt_out_in_torch_2 = ttnn.to_torch(tt_out_2) + assert_with_pcc(tt_out_in_torch_2, torch_out_2, 0.99) + assert_with_pcc(tt_out_in_torch_1, torch_out_1, 0.99) diff --git a/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_rms_norm.py b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_rms_norm.py new file mode 100644 index 000000000000..7624765c0f5f --- /dev/null +++ b/tests/ttnn/integration_tests/stable_diffusion3_5/test_ttnn_rms_norm.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import pytest +import torch +import torch.nn as nn +import ttnn +from models.experimental.functional_stable_diffusion3_5.ttnn.ttnn_rms_norm import ( + ttnn_RMSNorm as tt_module, +) +from models.experimental.functional_stable_diffusion3_5.reference.rms_norm import RMSNorm +from tests.ttnn.utils_for_testing import assert_with_pcc +from ttnn.model_preprocessing import ( + preprocess_model_parameters, +) +from models.utility_functions import skip_for_grayskull + + +def create_custom_preprocessor(device): + def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, RMSNorm): + parameters["rms_norm"] = {} + parameters["rms_norm"]["weight"] = ttnn.from_torch( + model.weight.unsqueeze(0).unsqueeze(0).unsqueeze(0), + device=device, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + ) + return parameters + + return custom_preprocessor + + +@pytest.mark.parametrize( + "init_inputs,fwd_inputs", + [ + ((64, 1e-06, True), (2, 24, 333, 64)), + ], +) +@skip_for_grayskull() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +def test_ttnn_rms_norm(init_inputs, fwd_inputs, device, reset_seeds): + torch_sub_module = RMSNorm(dim=init_inputs[0], eps=init_inputs[1], elementwise_affine=init_inputs[2]) + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_sub_module, device=device, custom_preprocessor=create_custom_preprocessor(device) + ) + hidden_states = torch.randn(fwd_inputs, dtype=torch.bfloat16) + tt_input_hidden_states = ttnn.from_torch( + hidden_states, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG + ) + tt_sub_module = tt_module( + dim=init_inputs[0], eps=init_inputs[1], elementwise_affine=init_inputs[2], parameters=parameters.rms_norm + ) + tt_out = tt_sub_module(hidden_states=tt_input_hidden_states, device=device) + torch_out = torch_sub_module(hidden_states) + tt_out_in_torch = ttnn.to_torch(tt_out) + assert_with_pcc(torch_out, tt_out_in_torch, 0.99)