From bca1c3383f5d2ea3009d4ee297ccc26db146cf20 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 19 Jan 2024 23:39:10 +0000
Subject: [PATCH] updt to include low precision groupnorm;

---
 llmfoundry/models/layers/attention.py | 67 +++++++++++++++++++++++----
 llmfoundry/models/layers/norm.py      | 23 +++++++++
 2 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 89f861c3f0..f6eb329e0e 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -12,10 +12,10 @@
 import transformers
 from einops import rearrange
 from packaging import version
-from torch import nn
 
 from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+from llmfoundry.models.layers.norm import (NORM_CLASS_REGISTRY, LPLayerNorm,
+                                           low_precision_groupnorm)
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -498,6 +498,47 @@ def triton_flash_attn_fn(
     return output, None, past_key_value
 
 
+def _expand_params(heads: int, param: Optional[torch.Tensor] = None):
+    if param is None:
+        return None
+    return param.repeat(heads)
+
+
+def _apply_qk_gn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    n_heads: int,
+    kv_n_heads: int,
+    q_ln: nn.Module,
+    k_ln: nn.Module,
+):
+    dtype = query.dtype
+
+    w = _expand_params(n_heads, q_ln.weight)
+    b = _expand_params(n_heads, q_ln.bias)
+    if isinstance(q_ln, LPLayerNorm):
+        query = low_precision_groupnorm(query, n_heads, w, b,
+                                        eps=q_ln.eps).to(dtype)
+    elif isinstance(q_ln, nn.LayerNorm):
+        query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
+
+    w = _expand_params(kv_n_heads, k_ln.weight)
+    b = _expand_params(kv_n_heads, k_ln.bias)
+    if isinstance(k_ln, LPLayerNorm):
+        key = low_precision_groupnorm(key, kv_n_heads, w, b,
+                                      eps=k_ln.eps).to(dtype)
+    elif isinstance(k_ln, nn.LayerNorm):
+        key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
+    else:
+        raise ValueError(
+            f'qk_gn not applicable for given k_ln type ({type(k_ln)=}).')
+
+    return query, key
+
+
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
 
@@ -629,16 +670,22 @@ def forward(
 
             key_padding_mask = attention_mask
 
-        if self.qk_ln or self.qk_gn:
+        if self.qk_gn:
+            # Applying groupnorm to qk
+            query, key = _apply_qk_gn(
+                query,
+                key,
+                self.n_heads,
+                self.kv_n_heads,
+                self.q_ln,
+                self.k_ln,
+            )
+
+        if self.qk_ln:
             # Applying layernorm to qk
-            q_shape, k_shape = query.shape, key.shape
-            if self.qk_gn:
-                b, s = query.shape[:2]
-                query = query.view(b, s, self.n_heads, -1)
-                key = key.view(b, s, self.kv_n_heads, -1)
             dtype = query.dtype
-            query = self.q_ln(query).to(dtype).view(q_shape)
-            key = self.k_ln(key).to(dtype).view(k_shape)
+            query = self.q_ln(query).to(dtype)
+            key = self.k_ln(key).to(dtype)
 
         if rotary_emb_w_meta_info is not None:
             rotary_emb = rotary_emb_w_meta_info['rotary_emb']
diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py
index 2ff4eaed0c..313c547b67 100644
--- a/llmfoundry/models/layers/norm.py
+++ b/llmfoundry/models/layers/norm.py
@@ -53,6 +53,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
 
+def low_precision_groupnorm(
+    x: torch.Tensor,
+    groups: int,
+    weight: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    eps: float = 1e-05,
+):
+    device = x.device
+    downcast_x = _cast_if_autocast_enabled(x)
+    downcast_weight = _cast_if_autocast_enabled(
+        weight) if weight is not None else weight
+    downcast_bias = _cast_if_autocast_enabled(
+        bias) if bias is not None else bias
+    with torch.autocast(enabled=False, device_type=device.type):
+        return torch.nn.functional.group_norm(
+            downcast_x,
+            groups,
+            downcast_weight,
+            downcast_bias,
+            eps,
+        )
+
+
 def rms_norm(x: torch.Tensor,
              weight: Optional[torch.Tensor] = None,
              eps: float = 1e-5) -> torch.Tensor:
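
Worked example (not part of the patch): the qk_gn path relies on torch group_norm with n_heads groups reproducing per-head layer norm once the per-head affine parameters are tiled across heads, which is what _expand_params does. The sketch below checks that equivalence; it assumes the q/k norms are built with normalized_shape equal to head_dim when qk_gn is enabled, and it flattens the (batch, seq, d_model) projection to (batch * seq, d_model) so the head channels sit on dim 1, where torch.nn.functional.group_norm expects them. Shapes and names are illustrative only.

import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from the patch.
batch, seq, n_heads, head_dim = 2, 4, 8, 64
d_model = n_heads * head_dim

x = torch.randn(batch, seq, d_model)
weight = torch.randn(head_dim)  # per-head affine, as a head_dim-sized q_ln would hold
bias = torch.randn(head_dim)

# Reference: split into heads and layer-normalize each head_dim slice.
ref = F.layer_norm(x.view(batch, seq, n_heads, head_dim), (head_dim,), weight,
                   bias).view(batch, seq, d_model)

# group_norm route: flatten so channels are on dim 1, tile the affine params
# across heads (what _expand_params does), and use n_heads groups.
out = F.group_norm(
    x.view(batch * seq, d_model),
    n_heads,
    weight.repeat(n_heads),
    bias.repeat(n_heads),
    eps=1e-5,
).view(batch, seq, d_model)

torch.testing.assert_close(ref, out)  # the two normalizations agree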
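
Worked example (not part of the patch): low_precision_groupnorm mirrors LPLayerNorm's approach of downcasting its inputs to the active autocast dtype and then running the functional op with autocast disabled, so the group_norm kernel itself executes in low precision. The sketch below is a standalone approximation; cast_if_autocast_enabled is a stand-in written from the assumed behavior of llmfoundry's _cast_if_autocast_enabled helper, not a copy of it.

import torch
import torch.nn.functional as F

def cast_if_autocast_enabled(t: torch.Tensor) -> torch.Tensor:
    # Downcast to the active autocast dtype; pass through when autocast is off.
    if t.device.type == 'cpu' and torch.is_autocast_cpu_enabled():
        return t.to(dtype=torch.get_autocast_cpu_dtype())
    if t.device.type == 'cuda' and torch.is_autocast_enabled():
        return t.to(dtype=torch.get_autocast_gpu_dtype())
    return t

def low_precision_groupnorm_sketch(x, groups, weight=None, bias=None, eps=1e-5):
    # Same structure as the patched function: cast inputs first, then run
    # group_norm with autocast disabled so no further casts happen inside.
    x = cast_if_autocast_enabled(x)
    weight = cast_if_autocast_enabled(weight) if weight is not None else None
    bias = cast_if_autocast_enabled(bias) if bias is not None else None
    with torch.autocast(enabled=False, device_type=x.device.type):
        return F.group_norm(x, groups, weight, bias, eps)

x = torch.randn(8, 512)  # (batch * seq, d_model): channels on dim 1
w, b = torch.randn(512), torch.randn(512)

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    y = low_precision_groupnorm_sketch(x, groups=8, weight=w, bias=b)

print(y.dtype)  # torch.bfloat16: the norm ran in the autocast dtype
# Callers such as _apply_qk_gn then restore the original activation dtype
# with .to(dtype), matching the .to(dtype) calls in the attention.py hunk.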