
Commit b855100

..
ShashankMosaicML committed Nov 30, 2023
1 parent 67deef8 commit b855100
Showing 2 changed files with 10 additions and 16 deletions.
6 changes: 0 additions & 6 deletions llmfoundry/models/layers/attention.py
@@ -251,12 +251,6 @@ def flash_attn_fn(

    past_key_value = (key, value)

-    if attn_bias is not None:
-        # clamp to 0 necessary for torch 2.0 compile()
-        _s_q = max(0, attn_bias.size(2) - query.size(1))
-        _s_k = max(0, attn_bias.size(3) - key.size(1))
-        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
-
    if attn_bias is not None:
        raise NotImplementedError(f'attn_bias not implemented for flash attn.')

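For context on the deletion above: flash_attn_fn raises NotImplementedError as soon as attn_bias is not None, so the clamped slicing that preceded the raise could never take effect, and the commit removes it as dead code. The slicing itself is the usual trick for trimming a pre-built bias down to the current query/key window when a KV cache makes the cached lengths exceed the incoming ones; below is a minimal standalone sketch with hypothetical shapes, not llm-foundry's API:

```python
import torch

# Hypothetical sizes: attn_bias was built for the full (16 x 16) context,
# but only 4 new query tokens arrive this step (the rest live in a KV cache).
attn_bias = torch.zeros(1, 8, 16, 16)  # (batch, heads, S_q_bias, S_k_bias)
query = torch.zeros(1, 4, 512)         # (batch, 4 new tokens, d_model)
key = torch.zeros(1, 16, 512)          # (batch, 16 total tokens, d_model)

# max(0, ...) keeps the start index non-negative; per the deleted comment,
# the clamp was needed for torch 2.0 compile().
_s_q = max(0, attn_bias.size(2) - query.size(1))
_s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
print(attn_bias.shape)  # torch.Size([1, 8, 4, 16]): last 4 queries vs all 16 keys
```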
20 changes: 10 additions & 10 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -138,16 +138,16 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
"""Generates the attention mask used for sequence masking in FA v2.
Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
In case of left padding:
1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
Args:
sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
S (int): Sequence length
attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention.
attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
"""
attention_mask_in_length = None
if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
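For intuition on the docstring above: with FA v2, packed sequences in a batch row are described by per-sequence token counts rather than a dense mask. Below is a toy sketch of that length format, assuming contiguous sequence ids starting at 0 and no padding; the real gen_attention_mask_in_length also folds in attention_mask and the padding cases described above, and the helper name here is hypothetical:

```python
import torch
import torch.nn.functional as F

def sketch_mask_in_length(sequence_id: torch.Tensor, S: int) -> torch.Tensor:
    # One-hot over sequence ids, then sum over the token axis: entry (b, i)
    # counts how many tokens in row b belong to sequence id i.
    return F.one_hot(sequence_id, num_classes=S).sum(dim=1)  # (batch, S)

# One row of length 6 packing three sequences of lengths 3, 2, and 1.
seq_id = torch.tensor([[0, 0, 0, 1, 1, 2]])
print(sketch_mask_in_length(seq_id, S=6))  # tensor([[3, 2, 1, 0, 0, 0]])
```

Each row's nonzero prefix gives the per-sequence lengths that FA v2's variable-length kernels consume to keep attention from crossing sequence boundaries.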
