#15969: ttnn implementation of sd3_5 attention sub_module
vguduruTT committed on Dec 24, 2024
1 parent 88b981d · commit e3aea7e
Showing 6 changed files with 868 additions and 0 deletions.
models/experimental/functional_stable_diffusion3_5/reference/attention.py (202 additions, 0 deletions)
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

# `AttnProcessor` / `AttnProcessor2_0` are the default-processor fallback used in
# `Attention.__init__`, and `logger` is used by `set_processor` and `__call__`;
# all of these come from diffusers.
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import logging

logger = logging.get_logger(__name__)


class Attention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        kv_heads: Optional[int] = None,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
        elementwise_affine: bool = True,
    ):
        super().__init__()
        from diffusers.models.normalization import RMSNorm

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self._from_deprecated_attn_block = _from_deprecated_attn_block
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention
        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "rms_norm":
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        if cross_attention_norm is None:
            self.norm_cross = None
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        if not self.only_cross_attention:
            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None
        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
        if qk_norm is not None and added_kv_proj_dim is not None:
            if qk_norm == "rms_norm":
                self.norm_added_q = RMSNorm(dim_head, eps=eps)
                self.norm_added_k = RMSNorm(dim_head, eps=eps)
        else:
            self.norm_added_q = None
            self.norm_added_k = None
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor") -> None:
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")
        self.processor = processor

    def __call__(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        # Forward only the kwargs that the active processor actually accepts.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )


class JointAttnProcessor2_0:
    """Joint attention processor for SD3-style transformer blocks (requires PyTorch 2.0 SDPA)."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states
        batch_size = hidden_states.shape[0]

        # Project the latent (sample) stream.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            # Project the context (text) stream and concatenate it with the latent stream.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the joint sequence back into latent and context parts.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : residual.shape[1]],
                hidden_states[:, residual.shape[1] :],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
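
As a quick sanity check of the reference path, the sketch below wires Attention together with JointAttnProcessor2_0 on SD3.5-style joint inputs (a latent token stream plus a prompt token stream) and prints the output shapes. The dimensions, sequence lengths, and the import path (repo root on PYTHONPATH) are illustrative assumptions, not values taken from this commit.

import torch

from models.experimental.functional_stable_diffusion3_5.reference.attention import (
    Attention,
    JointAttnProcessor2_0,
)

# Illustrative sizes only; the actual SD3.5 configs may differ.
dim, heads = 1536, 24
attn = Attention(
    query_dim=dim,
    heads=heads,
    dim_head=dim // heads,
    qk_norm="rms_norm",
    added_kv_proj_dim=dim,
    out_dim=dim,
    context_pre_only=False,
    bias=True,
    processor=JointAttnProcessor2_0(),
)

latent = torch.randn(1, 1024, dim)  # image-token stream
prompt = torch.randn(1, 154, dim)   # text-token stream
hidden, encoder_hidden = attn(latent, encoder_hidden_states=prompt)
print(hidden.shape, encoder_hidden.shape)  # torch.Size([1, 1024, 1536]) torch.Size([1, 154, 1536])

Passing JointAttnProcessor2_0 explicitly keeps the joint (latent + context) path exercised; the default fallback processor only covers plain self-attention.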
models/experimental/functional_stable_diffusion3_5/reference/rms_norm.py (33 additions, 0 deletions)
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn
import numbers
import torch


class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        val = torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states * val
        if self.weight is not None:
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        else:
            hidden_states = hidden_states.to(input_dtype)
        return hidden_states
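
For reference, the snippet below runs a bfloat16 tensor through this RMSNorm and checks it against the usual formula x * rsqrt(mean(x^2) + eps) * weight computed in float32. The seed, shapes, and import path are illustrative assumptions.

import torch

from models.experimental.functional_stable_diffusion3_5.reference.rms_norm import RMSNorm

torch.manual_seed(0)
dim = 64  # matches the per-head size used by the attention module above
norm = RMSNorm(dim, eps=1e-6)

x = torch.randn(2, 128, dim, dtype=torch.bfloat16)

# Hand-computed RMS norm in float32 for comparison.
x_fp32 = x.to(torch.float32)
expected = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + norm.eps)
expected = expected * norm.weight  # default weight is float32 ones

out = norm(x)
print(out.dtype)                    # float32 here, since the weight is not half precision
print(torch.allclose(out, expected))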