Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Jan 17, 2024
1 parent 34e4a99 commit 03113a9
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,11 @@ def gen_flash_attn_padding_info(
_, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
torch.zeros(bsz, S, 1).to(device=device), query_padding_mask)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
torch.zeros(bsz, S, 1).to(device=device), key_padding_mask)
torch.zeros(bsz, past_key_len + S, 1).to(device=device),
key_padding_mask)
_, indices_v, _, _ = unpadding_function(
torch.zeros(bsz, S, 1).to(device=device), key_padding_mask)
torch.zeros(bsz, past_key_len + S, 1).to(device=device),
key_padding_mask)

flash_attn_padding_info['indices_q'] = indices_q
flash_attn_padding_info['indices_k'] = indices_k
Expand Down

0 comments on commit 03113a9

Please sign in to comment.