
Commit 416525a

ShashankMosaicML committed Jan 17, 2024
1 parent e236305
Showing 1 changed file with 7 additions and 7 deletions.
llmfoundry/models/layers/attention.py: 14 changes (7 additions & 7 deletions)
@@ -236,7 +236,7 @@ def flash_attn_fn(
                     torch.Tensor]]]:
     del key_padding_mask
     try:
-        from flash_attn import bert_padding, flash_attn_interface, index_first_axis  # type: ignore # yapf: disable # isort: skip
+        from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
         raise RuntimeError(
             'Please install flash-attn==1.0.9 or flash-attn==2.3.6')
@@ -276,12 +276,12 @@ def flash_attn_fn(
     max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
     max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
 
-    query_unpad = index_first_axis(rearrange(query, 'b s ... -> (b s) ...'),
-                                   indices_q)
-    key_unpad = index_first_axis(rearrange(key, 'b s ... -> (b s) ...'),
-                                 indices_k)
-    value_unpad = index_first_axis(rearrange(value, 'b s ... -> (b s) ...'),
-                                   indices_v)
+    query_unpad = bert_padding.index_first_axis(
+        rearrange(query, 'b s ... -> (b s) ...'), indices_q)
+    key_unpad = bert_padding.index_first_axis(
+        rearrange(key, 'b s ... -> (b s) ...'), indices_k)
+    value_unpad = bert_padding.index_first_axis(
+        rearrange(value, 'b s ... -> (b s) ...'), indices_v)

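For context: index_first_axis is defined in flash_attn.bert_padding, and this commit routes the calls through that module instead of relying on a top-level re-export from the flash_attn package (presumably because not every supported flash-attn version exposes it there). The sketch below illustrates the unpad/re-pad round trip these helpers implement. It is a minimal example, not llm-foundry code: it assumes a working flash-attn 2.x install, and the toy shapes and attention_mask are invented for illustration.

# Minimal sketch of the unpadding pattern from the diff above; assumes a
# flash-attn 2.x install (bert_padding itself is plain torch/einops code).
import torch
from einops import rearrange
from flash_attn import bert_padding

batch, seqlen, n_heads, head_dim = 2, 8, 4, 16
query = torch.randn(batch, seqlen, n_heads, head_dim)

# Toy padding mask: the second sequence has 3 padding positions at the end.
attention_mask = torch.ones(batch, seqlen, dtype=torch.bool)
attention_mask[1, 5:] = False

# unpad_input flattens (b, s, ...) -> (total_tokens, ...) and returns the
# bookkeeping needed by the varlen kernels; newer flash-attn releases return
# an extra element, so only the first four are unpacked here.
query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
    query, attention_mask)[:4]

# index_first_axis gathers just the non-padding rows, as in the commit.
same_unpad = bert_padding.index_first_axis(
    rearrange(query, 'b s ... -> (b s) ...'), indices_q)
assert torch.equal(query_unpad, same_unpad)

# pad_input restores the (batch, seqlen, ...) layout after attention.
query_repad = bert_padding.pad_input(query_unpad, indices_q, batch, seqlen)
assert query_repad.shape == query.shape

Accessing the helper through the bert_padding namespace keeps the guarded try/except import above limited to names that the supported flash-attn versions all expose at the package level.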