From fac1c8e39cde5497b8d07fbd769638c5efeaf906 Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Mon, 25 Dec 2023 07:44:31 +0000
Subject: [PATCH] ..

---
 llmfoundry/models/layers/attention.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index cd2897662e..c51037c6f8 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -331,7 +331,9 @@ def flash_attn_fn(
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
     elif is_flash_v2_installed():
-        alibi_kwargs = {'alibi_slopes':attn_bias} if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
+        alibi_kwargs = {
+            'alibi_slopes': attn_bias
+        } if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
         output_unpad = flash_attn_interface.flash_attn_varlen_func(
             q=query_unpad,
             k=key_unpad,
@@ -805,7 +807,7 @@ def build_attn_bias(
 def gen_slopes(n_heads: int,
                alibi_bias_max: int = 8,
                device: Optional[torch.device] = None,
-               return_1d: bool=False) -> torch.Tensor:
+               return_1d: bool = False) -> torch.Tensor:
     _n_heads = 2**math.ceil(math.log2(n_heads))
     m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
     m = m.mul(alibi_bias_max / _n_heads)
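
Note (not part of the patch): the guarded `alibi_kwargs` above passes a 1-D ALiBi slope tensor to flash-attn's `alibi_slopes` argument, which is only available from flash-attn 2.4.0.post1 onward. Below is a minimal sketch, assuming flash-attn >= 2.4.0.post1, of how such a slope tensor of shape (n_heads,) can be built; the helper name `alibi_slopes_1d` and the commented call site are illustrative assumptions, not code from this PR.

# Illustrative sketch only; requires torch. The helper name is hypothetical.
import math
from typing import Optional

import torch


def alibi_slopes_1d(n_heads: int,
                    alibi_bias_max: int = 8,
                    device: Optional[torch.device] = None) -> torch.Tensor:
    # Round the head count up to the next power of two so the geometric
    # slope schedule stays evenly spaced.
    _n_heads = 2**math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        # Interleave so the steepest slopes survive truncation to the
        # true head count when n_heads is not a power of two.
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes  # shape (n_heads,), suitable for `alibi_slopes=...`


# Hypothetical call site mirroring the guarded kwarg in the hunk above:
# alibi_kwargs = ({'alibi_slopes': alibi_slopes_1d(n_heads)}
#                 if is_flash_v2_installed(v2_version='2.4.0.post1') else {})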