Commit

dummy data
ShashankMosaicML committed Jan 17, 2024
1 parent 3d8cda8 · commit b227bcf
Showing 1 changed file with 8 additions and 8 deletions.
llmfoundry/models/mpt/modeling_mpt.py — 16 changes: 8 additions & 8 deletions
@@ -232,11 +232,13 @@ def gen_flash_attn_padding_info(
         attention_mask_in_length: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None):
     flash_attn_padding_info = {}
+    dummy_data = torch.ones((bsz, past_key_len + S),
+                            dtype=torch.bool,
+                            device=device)
     if attention_mask_in_length is None:
         key_padding_mask = attention_mask
         if key_padding_mask is None:
-            key_padding_mask = torch.ones((bsz, past_key_len + S),
-                                          dtype=torch.bool, device=device)
+            key_padding_mask = dummy_data
         query_padding_mask = key_padding_mask[:, -S:]
         unpadding_function = bert_padding.unpad_input
     else:
@@ -245,13 +247,11 @@ def gen_flash_attn_padding_info(
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

     _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        torch.zeros(bsz, S, 1, device=device), query_padding_mask)
+        dummy_data[:, :S, None], query_padding_mask)
     _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1, device=device),
-        key_padding_mask)
-    _, indices_v, _, _ = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1, device=device),
-        key_padding_mask)
+        dummy_data[:, :, None], key_padding_mask)
+    _, indices_v, _, _ = unpadding_function(dummy_data[:, :, None],
+                                            key_padding_mask)

     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
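In short, the commit allocates one boolean dummy_data tensor of shape (bsz, past_key_len + S) up front and reuses slices of it both as the fallback key padding mask and as the dummy values passed to the three unpadding calls, replacing the previous per-call torch.zeros allocations and the inline torch.ones mask. Below is a minimal, self-contained sketch of that pattern; unpad_sketch is a hypothetical stand-in for flash_attn's bert_padding.unpad_input (not the library code), and the shapes and slicing mirror the diff above.

import torch

def unpad_sketch(values: torch.Tensor, padding_mask: torch.Tensor):
    # Hypothetical stand-in for flash_attn.bert_padding.unpad_input:
    # returns (unpadded values, flat indices of kept tokens,
    # cumulative sequence lengths, longest sequence in the batch).
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    unpadded = values.reshape(-1, *values.shape[2:])[indices]
    return unpadded, indices, cu_seqlens, max_seqlen

bsz, S, past_key_len = 2, 4, 3
device = torch.device('cpu')

# One boolean tensor, allocated once and sliced for every call,
# instead of a fresh torch.zeros per unpadding call.
dummy_data = torch.ones((bsz, past_key_len + S), dtype=torch.bool, device=device)
key_padding_mask = dummy_data                  # fallback when no attention mask is given
query_padding_mask = key_padding_mask[:, -S:]

_, indices_q, cu_seqlens_q, max_seqlen_q = unpad_sketch(
    dummy_data[:, :S, None], query_padding_mask)
_, indices_k, cu_seqlens_k, max_seqlen_k = unpad_sketch(
    dummy_data[:, :, None], key_padding_mask)

print(cu_seqlens_q.tolist(), max_seqlen_q)  # [0, 4, 8] 4
print(cu_seqlens_k.tolist(), max_seqlen_k)  # [0, 7, 14] 7

The underlying design choice, as reflected in the diff, is that a single bool tensor can serve as both the padding mask and the dummy values argument, so only one allocation is needed no matter how many unpadding calls follow.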
