
Commit

Shashank Rajput committed Dec 24, 2023
1 parent 325c996 commit fcb59d4
Showing 1 changed file with 3 additions and 17 deletions.
llmfoundry/models/layers/attention.py (20 changes: 3 additions & 17 deletions)
@@ -330,23 +330,8 @@ def flash_attn_fn(
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights)
-    elif is_flash_v2_installed(v2_version='2.4.0.post1'):
-        output_unpad = flash_attn_interface.flash_attn_varlen_func(
-            q=query_unpad,
-            k=key_unpad,
-            v=value_unpad,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_q,
-            max_seqlen_k=max_seqlen_k,
-            dropout_p=dropout_p,
-            softmax_scale=softmax_scale,
-            causal=reset_is_causal,
-            return_attn_probs=needs_weights,
-            window_size=(sliding_window_size, sliding_window_size),
-            alibi_slopes=attn_bias,
-        )
     elif is_flash_v2_installed():
+        alibi_kwargs = {'alibi_slopes': attn_bias} if is_flash_v2_installed(v2_version='2.4.0.post1') else {}
         output_unpad = flash_attn_interface.flash_attn_varlen_func(
             q=query_unpad,
             k=key_unpad,
@@ -359,7 +344,8 @@ def flash_attn_fn(
             softmax_scale=softmax_scale,
             causal=reset_is_causal,
             return_attn_probs=needs_weights,
-            window_size=(sliding_window_size, sliding_window_size))
+            window_size=(sliding_window_size, sliding_window_size),
+            **alibi_kwargs)
     else:
         raise RuntimeError(
             'flash-attn==1.0.9 or flash-attn==2.3.6 is required.')
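The net effect of the commit is to collapse the two flash-attn v2 branches into one: instead of a separate branch for versions at or above 2.4.0.post1, the ALiBi keyword is built conditionally and unpacked into a single flash_attn_varlen_func call, so alibi_slopes is only passed when the installed flash-attn accepts it. Below is a minimal sketch of that conditional-kwargs pattern; fake_varlen_attn and kernel_accepts_alibi are hypothetical stand-ins, not llmfoundry's actual helpers (the real code checks is_flash_v2_installed(v2_version='2.4.0.post1') and calls flash_attn_interface.flash_attn_varlen_func).

# Minimal sketch of the conditional-kwargs pattern used in this commit.
# `fake_varlen_attn` and `kernel_accepts_alibi` are stand-ins for
# flash_attn_varlen_func and the is_flash_v2_installed version check.

def fake_varlen_attn(q, k, v, *, causal=False, alibi_slopes=None):
    # Pretend kernel: an older release would not accept `alibi_slopes` at all,
    # which is why the caller must omit the keyword rather than pass None.
    return {'causal': causal, 'alibi_slopes': alibi_slopes}

def run_attention(q, k, v, attn_bias, kernel_accepts_alibi: bool):
    # Build the optional keyword only when the installed kernel supports it,
    # so one call site serves both older and newer flash-attn v2 releases.
    alibi_kwargs = {'alibi_slopes': attn_bias} if kernel_accepts_alibi else {}
    return fake_varlen_attn(q, k, v, causal=True, **alibi_kwargs)

# On older versions the keyword is simply never passed:
print(run_attention('q', 'k', 'v', attn_bias=[0.5], kernel_accepts_alibi=False))
print(run_attention('q', 'k', 'v', attn_bias=[0.5], kernel_accepts_alibi=True))

This keeps a single call site for the varlen kernel while preserving the old behavior: on flash-attn v2 builds earlier than 2.4.0.post1 the unpacked dict is empty, so the call is identical to the previous non-ALiBi branch.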
