Commit
ShashankMosaicML committed Jan 17, 2024
1 parent 5a9e1e8 commit 61d8ade
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -216,7 +216,36 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,

    return attention_mask_in_length


def get_flash_attn_padding_info(x, attention_mask_in_length, attention_mask,
                                past_key_len):
    # Collect the unpadding indices, cumulative sequence lengths, and maximum
    # sequence lengths that flash attention expects. x supplies the batch size
    # and sequence length when attention_mask_in_length is not given.
    flash_attn_padding_info = {}
    if attention_mask_in_length is None:
        key_padding_mask = attention_mask
        if key_padding_mask is None:
            key_padding_mask = torch.ones(
                (x.shape[0], past_key_len + x.shape[1]), dtype=torch.bool)
        query_padding_mask = key_padding_mask[:, -x.shape[1]:]
        unpadding_function = bert_padding.unpad_input
    else:
        # Packed sequences: the same per-sequence-length tensor serves as the
        # mask for both queries and keys.
        key_padding_mask = attention_mask_in_length
        query_padding_mask = attention_mask_in_length
        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

    _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
        torch.zeros(1, 1), query_padding_mask)
    _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
        torch.zeros(1, 1), key_padding_mask)
    _, indices_v, _, _ = unpadding_function(torch.zeros(1, 1),
                                            key_padding_mask)

    flash_attn_padding_info['indices_q'] = indices_q
    flash_attn_padding_info['indices_k'] = indices_k
    flash_attn_padding_info['indices_v'] = indices_v
    flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
    flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
    flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
    flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
    return flash_attn_padding_info

def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor,
                      max_seq_len: int) -> torch.Tensor:
    seq_len = sequence_id.shape[-1]
@@ -626,35 +655,9 @@ def forward(
        all_self_attns = () if output_attentions else None
        flash_attn_padding_info = {}
        if self.attn_impl == 'flash':
            if attention_mask_in_length is None:
                key_padding_mask = attention_mask
                if key_padding_mask is None:
                    past_key_len = past_key_values[0].shape[
                        1] if past_key_values is not None else 0
                    key_padding_mask = torch.ones(
                        (x.shape[0], past_key_len + x.shape[1]),
                        dtype=torch.bool)
                query_padding_mask = key_padding_mask[:, -x.shape[1]:]
                unpadding_function = bert_padding.unpad_input
            else:
                key_padding_mask = attention_mask_in_length
                query_padding_mask = attention_mask_in_length
                unpadding_function = bert_padding.unpad_input_for_concatenated_sequences

            _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
                torch.zeros(1, 1), query_padding_mask)
            _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
                torch.zeros(1, 1), key_padding_mask)
            _, indices_v, _, _ = unpadding_function(torch.zeros(1, 1),
                                                    key_padding_mask)

            flash_attn_padding_info['indices_q'] = indices_q
            flash_attn_padding_info['indices_k'] = indices_k
            flash_attn_padding_info['indices_v'] = indices_v
            flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q
            flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k
            flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q
            flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k
            past_key_len = past_key_values[0].shape[
                1] if past_key_values is not None else 0
            flash_attn_padding_info = get_flash_attn_padding_info(
                x, attention_mask_in_length, attention_mask, past_key_len)

        for b_idx, block in enumerate(self.blocks):
            if output_hidden_states:
                assert all_hidden_states is not None  # pyright
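Note: the refactor above moves the flash-attention unpadding bookkeeping from forward() into a single helper. As a rough, self-contained illustration of the quantities that helper collects, the PyTorch-only sketch below computes the same kind of indices, cumulative sequence lengths, and maximum sequence length from a toy padding mask. It is not the commit's own code (which defers to flash-attn's bert_padding utilities); the function name and the toy mask are invented for the example.

import torch
import torch.nn.functional as F

def sketch_flash_attn_padding_info(attention_mask: torch.Tensor) -> dict:
    # attention_mask: (batch, seq_len), True/1 for real tokens, False/0 for padding.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)      # tokens per row
    indices = torch.nonzero(attention_mask.flatten()).flatten()  # flat positions of real tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    # Queries, keys, and values share the same mask here, mirroring the dict
    # layout that get_flash_attn_padding_info returns.
    return {
        'indices_q': indices, 'indices_k': indices, 'indices_v': indices,
        'cu_seqlens_q': cu_seqlens, 'cu_seqlens_k': cu_seqlens,
        'max_seqlen_q': max_seqlen, 'max_seqlen_k': max_seqlen,
    }

# Two sequences of lengths 3 and 2, right-padded to length 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)
info = sketch_flash_attn_padding_info(mask)
print(info['cu_seqlens_q'])  # tensor([0, 3, 5], dtype=torch.int32)
print(info['max_seqlen_q'])  # 3

In the packed-sequence path, attention_mask_in_length carries per-sequence lengths instead of a padding mask, and bert_padding.unpad_input_for_concatenated_sequences produces the analogous indices and cumulative sequence lengths.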
