
ShashankMosaicML committed Jan 17, 2024
1 parent e98a01d · commit 5a9e1e8
Showing 1 changed file with 7 additions and 6 deletions.
llmfoundry/models/mpt/modeling_mpt.py: 13 changes (7 additions & 6 deletions)
@@ -627,25 +627,26 @@ def forward(
         flash_attn_padding_info = {}
         if self.attn_impl == 'flash':
             if attention_mask_in_length is None:
-                if attention_mask is None:
+                key_padding_mask = attention_mask
+                if key_padding_mask is None:
                     past_key_len = past_key_values[0].shape[
                         1] if past_key_values is not None else 0
-                    attention_mask = torch.ones(
+                    key_padding_mask = torch.ones(
                         (x.shape[0], past_key_len + x.shape[1]),
                         dtype=torch.bool)
-                query_padding_mask = attention_mask[:, -x.shape[1]:]
+                query_padding_mask = key_padding_mask[:, -x.shape[1]:]
                 unpadding_function = bert_padding.unpad_input
             else:
-                attention_mask = attention_mask_in_length
+                key_padding_mask = attention_mask_in_length
                 query_padding_mask = attention_mask_in_length
                 unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
 
             _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
                 torch.zeros(1, 1), query_padding_mask)
             _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-                torch.zeros(1, 1), attention_mask)
+                torch.zeros(1, 1), key_padding_mask)
             _, indices_v, _, _ = unpadding_function(torch.zeros(1, 1),
-                                                    attention_mask)
+                                                    key_padding_mask)
 
             flash_attn_padding_info['indices_q'] = indices_q
             flash_attn_padding_info['indices_k'] = indices_k
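
For context: this hunk copies the incoming attention_mask into a local key_padding_mask (falling back to an all-ones mask when none is provided), so the forward pass no longer overwrites its attention_mask argument while the flash-attention unpadding metadata is built. Below is a minimal, self-contained sketch, not the flash-attn library code, of the kind of metadata this unpadding step produces from a boolean key padding mask; the helper name toy_unpad_metadata is hypothetical and only mirrors the indices, cumulative sequence lengths, and max sequence length that end up in flash_attn_padding_info.

import torch
import torch.nn.functional as F


def toy_unpad_metadata(padding_mask: torch.Tensor):
    """Hypothetical helper: unpadding metadata from a (batch, seq_len) bool mask."""
    # Number of real (non-padded) tokens in each row of the batch.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the real tokens in the flattened (batch * seq_len) layout.
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen = int(seqlens_in_batch.max())
    # Cumulative sequence lengths prefixed with 0, the boundary format flash attention consumes.
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen


# Example: batch of 2 rows padded to length 4; the second row has 2 real tokens.
key_padding_mask = torch.tensor([[True, True, True, True],
                                 [True, True, False, False]])
indices_k, cu_seqlens_k, max_seqlen_k = toy_unpad_metadata(key_padding_mask)
print(indices_k)     # tensor([0, 1, 2, 3, 4, 5])
print(cu_seqlens_k)  # tensor([0, 4, 6], dtype=torch.int32)
print(max_seqlen_k)  # 4

Under this reading, query_padding_mask is simply the last x.shape[1] columns of key_padding_mask, so the query-side metadata is computed the same way but only over the tokens in the current input.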
