diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index f6eb329e0e..89f861c3f0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -12,10 +12,10 @@
 import transformers
 from einops import rearrange
 from packaging import version
+from torch import nn
 
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import (NORM_CLASS_REGISTRY, LPLayerNorm,
-                                           low_precision_groupnorm)
+from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -498,47 +498,6 @@ def triton_flash_attn_fn(
     return output, None, past_key_value
 
 
-def _expand_params(heads: int, param: Optional[torch.Tensor] = None):
-    if param is None:
-        return None
-    return param.repeat(heads)
-
-
-def _apply_qk_gn(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    n_heads: int,
-    kv_n_heads: int,
-    q_ln: nn.Module,
-    k_ln: nn.Module,
-):
-    dtype = query.dtype
-
-    w = _expand_params(n_heads, q_ln.weight)
-    b = _expand_params(n_heads, q_ln.bias)
-    if isinstance(q_ln, LPLayerNorm):
-        query = low_precision_groupnorm(query, n_heads, w, b,
-                                        eps=q_ln.eps).to(dtype)
-    elif isinstance(q_ln, nn.LayerNorm):
-        query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
-    else:
-        raise ValueError(
-            f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
-
-    w = _expand_params(kv_n_heads, k_ln.weight)
-    b = _expand_params(kv_n_heads, k_ln.bias)
-    if isinstance(k_ln, LPLayerNorm):
-        key = low_precision_groupnorm(key, kv_n_heads, w, b,
-                                      eps=k_ln.eps).to(dtype)
-    elif isinstance(k_ln, nn.LayerNorm):
-        key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
-    else:
-        raise ValueError(
-            f'qk_gn not applicable for given k_ln type ({type(k_ln)=}).')
-
-    return query, key
-
-
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
 
@@ -670,22 +629,16 @@ def forward(
 
         key_padding_mask = attention_mask
 
-        if self.qk_gn:
-            # Applying groupnorm to qk
-            query, key = _apply_qk_gn(
-                query,
-                key,
-                self.n_heads,
-                self.kv_n_heads,
-                self.q_ln,
-                self.k_ln,
-            )
-
-        if self.qk_ln:
+        if self.qk_ln or self.qk_gn:
             # Applying layernorm to qk
+            q_shape, k_shape = query.shape, key.shape
+            if self.qk_gn:
+                b, s = query.shape[:2]
+                query = query.view(b, s, self.n_heads, -1)
+                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype)
-            key = self.k_ln(key).to(dtype)
+            query = self.q_ln(query).to(dtype).view(q_shape)
+            key = self.k_ln(key).to(dtype).view(k_shape)
 
         if rotary_emb_w_meta_info is not None:
             rotary_emb = rotary_emb_w_meta_info['rotary_emb']
diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py
index 313c547b67..2ff4eaed0c 100644
--- a/llmfoundry/models/layers/norm.py
+++ b/llmfoundry/models/layers/norm.py
@@ -53,29 +53,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
 
 
-def low_precision_groupnorm(
-    x: torch.Tensor,
-    groups: int,
-    weight: Optional[torch.Tensor] = None,
-    bias: Optional[torch.Tensor] = None,
-    eps: float = 1e-05,
-):
-    device = x.device
-    downcast_x = _cast_if_autocast_enabled(x)
-    downcast_weight = _cast_if_autocast_enabled(
-        weight) if weight is not None else weight
-    downcast_bias = _cast_if_autocast_enabled(
-        bias) if bias is not None else bias
-    with torch.autocast(enabled=False, device_type=device.type):
-        return torch.nn.functional.group_norm(
-            downcast_x,
-            groups,
-            downcast_weight,
-            downcast_bias,
-            eps,
-        )
-
-
 def rms_norm(x: torch.Tensor,
              weight: Optional[torch.Tensor] = None,
              eps: float = 1e-5) -> torch.Tensor:
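
Why the reshape is sufficient: viewing the hidden dimension as `(n_heads, head_dim)` and applying a plain LayerNorm over `head_dim` normalizes each head independently, which is exactly group normalization with `n_heads` groups and the per-`head_dim` affine parameters tiled across heads (the `param.repeat(heads)` that `_expand_params` performed). A minimal sketch of the equivalence, assuming `q_ln` is constructed with `normalized_shape=head_dim` under `qk_gn` (the constructor change is not shown in this diff) and using hypothetical sizes:

```python
import torch
from torch import nn
from torch.nn import functional as F

# Hypothetical sizes, for illustration only.
b, s, n_heads, head_dim = 2, 4, 8, 64
d_model = n_heads * head_dim
x = torch.randn(b, s, d_model)

# Give the norm a non-trivial affine so the check below is meaningful.
ln = nn.LayerNorm(head_dim)
with torch.no_grad():
    ln.weight.copy_(torch.randn(head_dim))
    ln.bias.copy_(torch.randn(head_dim))

# New path: per-head LayerNorm via a view, as in the patched forward().
out_ln = ln(x.view(b, s, n_heads, head_dim)).view(b, s, d_model)

# Old path: group_norm with n_heads groups and per-head-tiled params,
# which is what _apply_qk_gn assembled via _expand_params.
out_gn = F.group_norm(
    x.view(b * s, d_model),  # group_norm expects (N, C, *)
    num_groups=n_heads,
    weight=ln.weight.repeat(n_heads),
    bias=ln.bias.repeat(n_heads),
    eps=ln.eps,
).view(b, s, d_model)

torch.testing.assert_close(out_ln, out_gn)  # agree up to float tolerance
```

Because the norm module itself is now applied, every class registered in `NORM_CLASS_REGISTRY` works unmodified, and the `isinstance` dispatch in the deleted `_apply_qk_gn` disappears.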
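
The `norm.py` deletion follows from the same change: with `qk_gn` routed through `self.q_ln`/`self.k_ln`, `low_precision_groupnorm` has no remaining callers, and `LPLayerNorm` already provides the downcast-under-autocast behavior on its own. A rough sketch of that surviving pattern (paraphrasing `LPLayerNorm.forward` and the `_cast_if_autocast_enabled` helper; consult `norm.py` for the exact code):

```python
import torch
from torch import nn


def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
    # Under autocast, downcast to the active autocast dtype;
    # otherwise pass the tensor through unchanged.
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LPLayerNorm(nn.LayerNorm):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = _cast_if_autocast_enabled(
            self.weight) if self.weight is not None else self.weight
        downcast_bias = _cast_if_autocast_enabled(
            self.bias) if self.bias is not None else self.bias
        # Disable autocast and run layer_norm directly in the downcast
        # dtype; the removed low_precision_groupnorm did the same with
        # group_norm, which the qk_gn reshape makes unnecessary.
        with torch.autocast(enabled=False, device_type=module_device.type):
            return torch.nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )
```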