diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index b8e4dde361..cd2897662e 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -330,23 +330,8 @@ def flash_attn_fn(
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
-    elif is_flash_v2_installed(v2_version='2.4.0.post1'):
-        output_unpad = flash_attn_interface.flash_attn_varlen_func(
-            q=query_unpad,
-            k=key_unpad,
-            v=value_unpad,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_q,
-            max_seqlen_k=max_seqlen_k,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            causal=reset_is_causal,
-            return_attn_probs=needs_weights,
-            window_size=(sliding_window_size, sliding_window_size),
-            alibi_slopes=attn_bias,
-        )
     elif is_flash_v2_installed():
+        alibi_kwargs = {'alibi_slopes': attn_bias} if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
         output_unpad = flash_attn_interface.flash_attn_varlen_func(
             q=query_unpad,
             k=key_unpad,
@@ -359,7 +344,8 @@ def flash_attn_fn(
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights,
-            window_size=(sliding_window_size, sliding_window_size))
+            window_size=(sliding_window_size, sliding_window_size),
+            **alibi_kwargs)
     else:
         raise RuntimeError(
             'flash-attn==1.0.9 or flash-attn==2.3.6 is required.')
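
The diff collapses the dedicated flash-attn >= 2.4.0.post1 branch into the generic v2 branch by building an optional kwargs dict and splatting it into a single call site. Below is a minimal, self-contained sketch of that version-gated-kwargs pattern. The names `supports_alibi`, `attend`, and `call_attention` are hypothetical stand-ins for illustration only; the real code gates on `is_flash_v2_installed(v2_version='2.4.0.post1')` and calls `flash_attn_interface.flash_attn_varlen_func`. The sketch assumes the `packaging` library is available for version comparison.

```python
# Sketch of the pattern in the diff above: conditionally build a kwargs dict
# based on the installed library version, then keep one call site instead of
# duplicating the whole argument list per version branch.
from packaging import version


def supports_alibi(installed_version: str) -> bool:
    # flash-attn accepts `alibi_slopes` starting with 2.4.0.post1 (per the
    # version check in the diff); older v2 releases do not.
    return version.parse(installed_version) >= version.parse('2.4.0.post1')


def attend(q, k, v, alibi_slopes=None):
    # Hypothetical attention kernel standing in for flash_attn_varlen_func.
    return (q, k, v, alibi_slopes)


def call_attention(q, k, v, attn_bias, installed_version: str):
    # Only pass `alibi_slopes` when the installed version supports it.
    alibi_kwargs = ({'alibi_slopes': attn_bias}
                    if supports_alibi(installed_version) else {})
    return attend(q, k, v, **alibi_kwargs)


# Usage: the same call site serves both old and new versions.
call_attention('q', 'k', 'v', attn_bias='slopes', installed_version='2.3.6')
call_attention('q', 'k', 'v', attn_bias='slopes', installed_version='2.4.0.post1')
```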