
Commit

..
Shashank Rajput committed Dec 25, 2023
1 parent e31cb8a commit fac1c8e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -331,7 +331,9 @@ def flash_attn_fn(
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
     elif is_flash_v2_installed():
-        alibi_kwargs = {'alibi_slopes':attn_bias} if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
+        alibi_kwargs = {
+            'alibi_slopes': attn_bias
+        } if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
         output_unpad = flash_attn_interface.flash_attn_varlen_func(
             q=query_unpad,
             k=key_unpad,
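
The hunk above gates ALiBi support on the installed flash-attn version: with flash-attn >= 2.4.0.post1, the per-head slopes held in attn_bias are forwarded to the fused varlen kernel as alibi_slopes, and on older 2.x builds the kwarg is simply omitted so the call still works. Below is a minimal sketch of that style of version gate; the helper name flash_v2_at_least is hypothetical, and the real is_flash_v2_installed helper in llmfoundry may be implemented differently.

# Sketch only: a version gate in the spirit of is_flash_v2_installed.
# The real helper in llmfoundry may differ.
from packaging import version


def flash_v2_at_least(min_version: str = '2.0.0') -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse(min_version)


# Only pass alibi_slopes when the installed kernel accepts it
# (flash-attn >= 2.4.0.post1); otherwise leave the kwargs empty.
supports_fused_alibi = flash_v2_at_least('2.4.0.post1')
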
@@ -805,7 +807,7 @@ def build_attn_bias(
 def gen_slopes(n_heads: int,
                alibi_bias_max: int = 8,
                device: Optional[torch.device] = None,
-               return_1d: bool=False) -> torch.Tensor:
+               return_1d: bool = False) -> torch.Tensor:
     _n_heads = 2**math.ceil(math.log2(n_heads))
     m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
     m = m.mul(alibi_bias_max / _n_heads)
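
The second hunk only normalizes spacing around the return_1d default, but that flag is what lets gen_slopes hand back a flat per-head slope vector (the shape flash-attn's alibi_slopes argument expects) rather than the broadcastable (1, n_heads, 1, 1) bias tensor. The sketch below reconstructs the full function around the three body lines visible in the diff; the non-power-of-two handling and the final reshape are assumptions about the surrounding, unchanged code.

import math
from typing import Optional

import torch


def gen_slopes(n_heads: int,
               alibi_bias_max: int = 8,
               device: Optional[torch.device] = None,
               return_1d: bool = False) -> torch.Tensor:
    # Round the head count up to a power of two so the slopes form a geometric sequence.
    _n_heads = 2**math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        # Assumed handling for non-power-of-two head counts:
        # interleave and truncate to n_heads slopes.
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    if return_1d:
        # One slope per head, e.g. for flash-attn's alibi_slopes argument.
        return slopes
    return slopes.view(1, n_heads, 1, 1)


slopes_1d = gen_slopes(n_heads=12, return_1d=True)  # shape (12,)
slopes_4d = gen_slopes(n_heads=12)                  # shape (1, 12, 1, 1)
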
