diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 00bc61f7c4..1408dc1ca0 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -232,11 +232,13 @@ def gen_flash_attn_padding_info(
         attention_mask_in_length: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None):
     flash_attn_padding_info = {}
+    dummy_data = torch.ones((bsz, past_key_len + S),
+                            dtype=torch.bool,
+                            device=device)
     if attention_mask_in_length is None:
         key_padding_mask = attention_mask
         if key_padding_mask is None:
-            key_padding_mask = torch.ones((bsz, past_key_len + S),
-                                          dtype=torch.bool, device=device)
+            key_padding_mask = dummy_data
         query_padding_mask = key_padding_mask[:, -S:]
         unpadding_function = bert_padding.unpad_input
     else:
@@ -245,13 +247,11 @@ def gen_flash_attn_padding_info(
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
 
     _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        torch.zeros(bsz, S, 1, device=device), query_padding_mask)
+        dummy_data[:, :S, None], query_padding_mask)
     _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1, device=device),
-        key_padding_mask)
-    _, indices_v, _, _ = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1, device=device),
-        key_padding_mask)
+        dummy_data[:, :, None], key_padding_mask)
+    _, indices_v, _, _ = unpadding_function(dummy_data[:, :, None],
+                                            key_padding_mask)
 
     flash_attn_padding_info['indices_q'] = indices_q
     flash_attn_padding_info['indices_k'] = indices_k
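
The change allocates one boolean dummy_data tensor of shape (bsz, past_key_len + S) and reuses views of it for all three unpadding calls, instead of building three throwaway torch.zeros tensors. This works because the data tensor passed to the unpadding function only needs the right shape; the indices, cu_seqlens, and max_seqlen outputs that gen_flash_attn_padding_info keeps are derived entirely from the padding mask. The sketch below illustrates that equivalence with a simplified, hypothetical _unpad_indices helper; it is not flash-attn's bert_padding.unpad_input, just a minimal stand-in that computes the same mask-derived quantities.

import torch

def _unpad_indices(data: torch.Tensor, padding_mask: torch.Tensor):
    # Simplified stand-in for an unpadding function: the outputs that
    # gen_flash_attn_padding_info stores (indices, cu_seqlens, max_seqlen)
    # depend only on padding_mask, never on the values in `data`.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)  # tokens per row
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # cumulative offsets
    max_seqlen = int(seqlens.max())
    return data.flatten(0, 1)[indices], indices, cu_seqlens, max_seqlen

bsz, S, past_key_len, device = 2, 4, 0, 'cpu'
dummy_data = torch.ones((bsz, past_key_len + S), dtype=torch.bool, device=device)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)

# Same indices/cu_seqlens/max_seqlen whether we pass a fresh zeros tensor
# or a view of the shared boolean dummy_data, as in the diff above.
_, idx_a, cu_a, max_a = _unpad_indices(torch.zeros(bsz, S, 1), attention_mask)
_, idx_b, cu_b, max_b = _unpad_indices(dummy_data[:, :S, None], attention_mask)
assert torch.equal(idx_a, idx_b) and torch.equal(cu_a, cu_b) and max_a == max_b

The practical effect, under that assumption, is simply fewer tensor allocations per forward pass when building the flash-attention padding info.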