
Commit a478efa
Revert "perf improvement"
This reverts commit 2b62d5e.
vchiley committed Jan 20, 2024
1 parent 2b62d5e commit a478efa
Showing 2 changed files with 10 additions and 29 deletions.
llmfoundry/models/layers/attention.py: 18 changes (6 additions, 12 deletions)
@@ -514,29 +514,23 @@ def _apply_qk_gn(
 ):
     dtype = query.dtype
 
+    w = _expand_params(n_heads, q_ln.weight)
+    b = _expand_params(n_heads, q_ln.bias)
     if isinstance(q_ln, LPLayerNorm):
-        query = low_precision_groupnorm(query,
-                                        n_heads,
-                                        q_ln.weight,
-                                        q_ln.bias,
+        query = low_precision_groupnorm(query, n_heads, w, b,
                                         eps=q_ln.eps).to(dtype)
     elif isinstance(q_ln, nn.LayerNorm):
-        w = _expand_params(n_heads, q_ln.weight)
-        b = _expand_params(n_heads, q_ln.bias)
         query = nn.functional.group_norm(query, n_heads, w, b, eps=q_ln.eps)
     else:
         raise ValueError(
             f'qk_gn not applicable for given q_ln type ({type(q_ln)=}).')
 
+    w = _expand_params(kv_n_heads, k_ln.weight)
+    b = _expand_params(kv_n_heads, k_ln.bias)
     if isinstance(k_ln, LPLayerNorm):
-        key = low_precision_groupnorm(key,
-                                      kv_n_heads,
-                                      k_ln.weight,
-                                      k_ln.bias,
+        key = low_precision_groupnorm(key, kv_n_heads, w, b,
                                       eps=k_ln.eps).to(dtype)
     elif isinstance(k_ln, nn.LayerNorm):
-        w = _expand_params(kv_n_heads, k_ln.weight)
-        b = _expand_params(kv_n_heads, k_ln.bias)
         key = nn.functional.group_norm(key, kv_n_heads, w, b, eps=k_ln.eps)
     else:
         raise ValueError(
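For readers skimming the restored call pattern: the sketch below shows, with made-up shapes, how per-head norm parameters get tiled out to channel size and passed to group_norm so each head's slice is normalized independently. The sizes, the per-head LayerNorm, and the inline repeat are illustrative assumptions standing in for _expand_params and the surrounding llm-foundry code, not the actual implementation.

import torch
import torch.nn as nn

# Illustrative sizes; llm-foundry derives these from its model config.
n_heads, head_dim, batch, seq = 8, 64, 2, 16
d_model = n_heads * head_dim

q_ln = nn.LayerNorm(head_dim)              # affine params stored per head (assumption)
query = torch.randn(batch, seq, d_model)   # projected queries, heads flattened into the last dim

# group_norm expects one affine value per channel, so tile the per-head params
# up to d_model values; this mirrors the repeat done by _expand_params in the diff.
w = q_ln.weight.repeat(d_model // q_ln.weight.shape[-1])
b = q_ln.bias.repeat(d_model // q_ln.bias.shape[-1])

# groups=n_heads normalizes each head's head_dim slice independently per token,
# which is the "qk_gn" behavior the reverted code applies to query and key.
flat = query.reshape(batch * seq, d_model)
normed = nn.functional.group_norm(flat, n_heads, w, b, eps=q_ln.eps)
query = normed.reshape(batch, seq, d_model)
print(query.shape)  # torch.Size([2, 16, 512])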
llmfoundry/models/layers/norm.py: 21 changes (4 additions, 17 deletions)
@@ -53,15 +53,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
 
-def _expand_params(x: torch.Tensor, param: Optional[torch.Tensor] = None):
-    # repeat param if params are applied per group
-    if param is None:
-        return None
-    if x.shape[-1] == param.shape[-1]:
-        return param
-    return param.repeat(x.shape[-1] // param.shape[-1])
-
-
 def low_precision_groupnorm(
     x: torch.Tensor,
     groups: int,
@@ -71,14 +62,10 @@ def low_precision_groupnorm(
 ):
     device = x.device
     downcast_x = _cast_if_autocast_enabled(x)
-    downcast_weight, downcast_bias = None, None
-    if weight is not None:
-        downcast_weight = _cast_if_autocast_enabled(weight)
-        downcast_weight = _expand_params(x, downcast_weight)
-    if bias is not None:
-        downcast_bias = _cast_if_autocast_enabled(bias)
-        downcast_bias = _expand_params(x, downcast_bias)
-
+    downcast_weight = _cast_if_autocast_enabled(
+        weight) if weight is not None else weight
+    downcast_bias = _cast_if_autocast_enabled(
+        bias) if bias is not None else bias
     with torch.autocast(enabled=False, device_type=device.type):
         return torch.nn.functional.group_norm(
             downcast_x,
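As a side note on what the restored norm.py lines do: the pattern is to downcast the input and affine parameters by hand, then run group_norm with autocast disabled so the op executes in the already-downcast dtype (the .to(dtype) back in attention.py restores the activation dtype afterwards). Below is a minimal sketch of that pattern; the _cast_if_autocast_enabled stand-in is an assumption about the helper's behavior, not the repo's actual code.

from typing import Optional

import torch


def _cast_if_autocast_enabled_sketch(t: torch.Tensor) -> torch.Tensor:
    # Assumed behavior of the repo helper: when CUDA autocast is active, move
    # the tensor to the autocast dtype up front instead of letting autocast
    # re-cast it inside the op.
    if t.is_cuda and torch.is_autocast_enabled():
        return t.to(torch.get_autocast_gpu_dtype())
    return t


def low_precision_groupnorm_sketch(
    x: torch.Tensor,
    groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    downcast_x = _cast_if_autocast_enabled_sketch(x)
    downcast_weight = _cast_if_autocast_enabled_sketch(weight) if weight is not None else weight
    downcast_bias = _cast_if_autocast_enabled_sketch(bias) if bias is not None else bias
    # Disable autocast so group_norm runs directly on the downcast tensors
    # rather than being wrapped and cast again by the autocast context.
    with torch.autocast(enabled=False, device_type=x.device.type):
        return torch.nn.functional.group_norm(
            downcast_x, groups, downcast_weight, downcast_bias, eps)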
