Commit 0f25b73 (parent: 5063149)
ShashankMosaicML committed Jan 17, 2024
Showing 1 changed file with 5 additions and 0 deletions.
llmfoundry/models/layers/attention.py
@@ -326,10 +326,15 @@ def flash_attn_fn(

     query_unpad = bert_padding.index_first_axis(
         rearrange(query, 'b s ... -> (b s) ...'), indices_q)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)

     key_unpad = bert_padding.index_first_axis(
         rearrange(key, 'b s ... -> (b s) ...'), indices_k)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

     value_unpad = bert_padding.index_first_axis(
         rearrange(value, 'b s ... -> (b s) ...'), indices_v)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

     if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
             not should_repeat_kv_for_gqa):
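For context, a minimal standalone sketch (not llm-foundry's own code; sizes here are illustrative) of what the added rearrange calls do: after bert_padding.index_first_axis strips padding tokens, each tensor is flat over tokens with the head and head-dim axes fused, and rearrange splits that last axis back into the (nnz, heads, head_dim) layout that flash-attn's unpadded (varlen) kernels take.

    # Standalone sketch of the head-splitting step added in this commit.
    # Sizes are illustrative; in the real code the inputs come from flash_attn's
    # bert_padding unpadding utilities.
    import torch
    from einops import rearrange

    n_heads, kv_n_heads, head_dim = 8, 2, 64  # GQA: fewer kv heads than query heads
    nnz = 10  # total non-padding tokens across the batch

    # After unpadding, q/k/v are flat over tokens, with heads and head_dim fused.
    query_unpad = torch.randn(nnz, n_heads * head_dim)
    key_unpad = torch.randn(nnz, kv_n_heads * head_dim)
    value_unpad = torch.randn(nnz, kv_n_heads * head_dim)

    # 'nnz (h d) -> nnz h d' factors the last axis into (heads, head_dim);
    # h must be passed so einops can infer d from the product.
    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)

    assert query_unpad.shape == (nnz, n_heads, head_dim)
    assert key_unpad.shape == value_unpad.shape == (nnz, kv_n_heads, head_dim)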
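The guard visible at the end of the hunk concerns grouped-query attention: when kv_n_heads < n_heads and flash-attn v2 (which handles GQA natively) is not installed, the k/v heads must be tiled so every query head has a matching kv head. The body of the branch is outside the shown hunk, so the following is only a hedged sketch of that repeat step, assuming the (nnz, heads, head_dim) layout above; the helper name is hypothetical, not the exact llm-foundry function.

    # Hedged sketch (assumed shapes, hypothetical helper name) of repeating kv
    # heads for GQA when the attention kernel cannot broadcast them itself.
    import torch

    def repeat_kv_heads(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
        """(nnz, kv_n_heads, d) -> (nnz, kv_n_heads * n_rep, d); no copy is
        made until the final reshape."""
        if n_rep == 1:
            return kv
        nnz, kv_n_heads, d = kv.shape
        # Insert a repeat axis next to the head axis, broadcast, then fold it in.
        return (kv[:, :, None, :]
                .expand(nnz, kv_n_heads, n_rep, d)
                .reshape(nnz, kv_n_heads * n_rep, d))

    n_heads, kv_n_heads = 8, 2
    key_unpad = torch.randn(10, kv_n_heads, 64)
    key_unpad = repeat_kv_heads(key_unpad, n_heads // kv_n_heads)
    assert key_unpad.shape == (10, n_heads, 64)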
