diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index eef20e5b32..f6eb329e0e 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -514,29 +514,23 @@ def _apply_qk_gn(
 ):
     dtype = query.dtype
 
+    w = _expand_params(n_heads, q_ln.weight)
+    b = _expand_params(n_heads, q_ln.bias)
     if isinstance(q_ln, LPLayerNorm):
-        query = low_precision_groupnorm(query,
-                                        n_heads,
-                                        q_ln.weight,
-                                        q_ln.bias,
+        query = low_precision_groupnorm(query, n_heads, w, b,
                                         eps=q_ln.eps).to(dtype)
     elif isinstance(q_ln, nn.LayerNorm):
-        w = _expand_params(n_heads, q_ln.weight)
-        b = _expand_params(n_heads, q_ln.bias)
         query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
     else:
         raise ValueError(
             f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
 
+    w = _expand_params(kv_n_heads, k_ln.weight)
+    b = _expand_params(kv_n_heads, k_ln.bias)
     if isinstance(k_ln, LPLayerNorm):
-        key = low_precision_groupnorm(key,
-                                      kv_n_heads,
-                                      k_ln.weight,
-                                      k_ln.bias,
+        key = low_precision_groupnorm(key, kv_n_heads, w, b,
                                       eps=k_ln.eps).to(dtype)
     elif isinstance(k_ln, nn.LayerNorm):
-        w = _expand_params(kv_n_heads, k_ln.weight)
-        b = _expand_params(kv_n_heads, k_ln.bias)
         key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
     else:
         raise ValueError(
diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py
index 09070f2418..313c547b67 100644
--- a/llmfoundry/models/layers/norm.py
+++ b/llmfoundry/models/layers/norm.py
@@ -53,15 +53,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
 
-def _expand_params(x: torch.Tensor, param: Optional[torch.Tensor] = None):
-    # repeat param if params are applied per group
-    if param is None:
-        return None
-    if x.shape[-1] == param.shape[-1]:
-        return param
-    return param.repeat(x.shape[-1] // param.shape[-1])
-
-
 def low_precision_groupnorm(
     x: torch.Tensor,
     groups: int,
@@ -71,14 +62,10 @@ def low_precision_groupnorm(
 ):
     device = x.device
     downcast_x = _cast_if_autocast_enabled(x)
-    downcast_weight, downcast_bias = None, None
-    if weight is not None:
-        downcast_weight = _cast_if_autocast_enabled(weight)
-        downcast_weight = _expand_params(x, downcast_weight)
-    if bias is not None:
-        downcast_bias = _cast_if_autocast_enabled(bias)
-        downcast_bias = _expand_params(x, downcast_bias)
-
+    downcast_weight = _cast_if_autocast_enabled(
+        weight) if weight is not None else weight
+    downcast_bias = _cast_if_autocast_enabled(
+        bias) if bias is not None else bias
     with torch.autocast(enabled=False, device_type=device.type):
         return torch.nn.functional.group_norm(
             downcast_x,
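
Note on the change above: `_apply_qk_gn` now expands the per-head affine parameters once, before branching on the norm type, so both the `LPLayerNorm` and plain `nn.LayerNorm` paths receive weights and biases already sized for `group_norm`; correspondingly, `low_precision_groupnorm` only downcasts its parameters and no longer carries its own `_expand_params` helper. The hunk that defines the `_expand_params` used by `attention.py` is not included in this excerpt, so the snippet below is only a minimal, hypothetical sketch of such a helper, assuming it always receives a per-head parameter of length `head_dim` and must tile it to the `n_heads * head_dim` channels that `torch.nn.functional.group_norm` expects (the removed `norm.py` version also passed through parameters that were already full-size, which this sketch does not replicate).

```python
from typing import Optional

import torch


def _expand_params(
    n_heads: int,
    param: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    # Hypothetical sketch: tile a per-head (per-group) affine parameter so its
    # length matches the full channel count that F.group_norm expects.
    if param is None:
        return None
    return param.repeat(n_heads)


# Usage with made-up shapes: a per-head weight of size head_dim=4, expanded for
# n_heads=8, gives the (n_heads * head_dim,) = (32,) vector group_norm needs.
w = _expand_params(8, torch.ones(4))
assert w is not None and w.shape == (32,)
```

Expanding once in the caller avoids repeating the work for weight and bias in both norm branches and keeps `low_precision_groupnorm` a thin autocast wrapper around `F.group_norm`.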