Update llmfoundry/models/mpt/modeling_mpt.py
Co-authored-by: Vitaliy Chiley <[email protected]>
ShashankMosaicML and vchiley authored Jan 17, 2024
1 parent 3351d23 commit 3d8cda8
Showing 1 changed file with 4 additions and 4 deletions.
llmfoundry/models/mpt/modeling_mpt.py:

```diff
@@ -236,7 +236,7 @@ def gen_flash_attn_padding_info(
         key_padding_mask = attention_mask
         if key_padding_mask is None:
             key_padding_mask = torch.ones((bsz, past_key_len + S),
-                                          dtype=torch.bool).to(device=device)
+                                          dtype=torch.bool, device=device)
         query_padding_mask = key_padding_mask[:, -S:]
         unpadding_function = bert_padding.unpad_input
     else:
@@ -245,12 +245,12 @@ def gen_flash_attn_padding_info(
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

     _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        torch.zeros(bsz, S, 1).to(device=device), query_padding_mask)
+        torch.zeros(bsz, S, 1, device=device), query_padding_mask)
     _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1).to(device=device),
+        torch.zeros(bsz, past_key_len + S, 1, device=device),
         key_padding_mask)
     _, indices_v, _, _ = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1).to(device=device),
+        torch.zeros(bsz, past_key_len + S, 1, device=device),
         key_padding_mask)

     flash_attn_padding_info['indices_q'] = indices_q
```
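The substance of the change: each tensor that was built on the CPU and then moved with `.to(device=device)` is now allocated directly on the target device via the `device=` argument of the factory function. A minimal sketch of the two forms (illustrative sizes, not from the commit):

```python
import torch

# Pick whatever device is available; the helper in the commit receives
# `device` as an argument.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bsz, past_key_len, S = 2, 0, 8  # illustrative sizes

# Before: allocate on the CPU, then copy the tensor to the target device.
mask_via_copy = torch.ones((bsz, past_key_len + S),
                           dtype=torch.bool).to(device=device)

# After: allocate directly on the target device, skipping the intermediate
# CPU tensor and the host-to-device copy.
mask_direct = torch.ones((bsz, past_key_len + S),
                         dtype=torch.bool, device=device)

assert torch.equal(mask_via_copy, mask_direct)  # identical contents
```

Both forms yield the same tensor; the direct form simply avoids an extra allocation and transfer, which can add up for a helper that is invoked on every forward pass.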

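For context on what the `unpadding_function` consumes and returns: in flash-attn, `bert_padding.unpad_input` takes a padded (bsz, seqlen, ...) tensor plus a padding mask and returns the unpadded values together with the flat token indices, cumulative sequence lengths, and the maximum sequence length that the flash-attention kernels need. A rough pure-PyTorch sketch of that contract (an approximation for illustration, not the library's implementation):

```python
import torch
import torch.nn.functional as F

def unpad_input_sketch(hidden_states: torch.Tensor,
                       attention_mask: torch.Tensor):
    """Approximate the (output, indices, cu_seqlens, max_seqlen) contract of
    flash-attn's bert_padding.unpad_input (a sketch, not the library code)."""
    # Valid-token count per sequence; int32 matches what the kernels expect.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of non-padding positions in the (bsz * seqlen) layout.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths with a leading zero:
    # [0, len_0, len_0 + len_1, ...].
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    # Keep only the non-padding rows, collapsing batch and sequence dims.
    unpadded = hidden_states.reshape(-1, *hidden_states.shape[2:])[indices]
    return unpadded, indices, cu_seqlens, max_seqlen

# Example: batch of 2 with true lengths 3 and 2, padded to length 4.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
x = torch.zeros(2, 4, 1)  # mirrors the torch.zeros(bsz, S, 1) probe in the diff
_, indices, cu_seqlens, max_seqlen = unpad_input_sketch(x, mask)
print(indices)     # tensor([0, 1, 2, 4, 5])
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen)  # 3
```

This is also why the zeros tensors in the diff only need the right shape and device: their values are ignored (the first return value is discarded as `_`), and only the returned indices, cumulative lengths, and max lengths are kept in `flash_attn_padding_info`.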