diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 54df092bde..47e5437dbb 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -251,12 +251,6 @@ def flash_attn_fn(
 
     past_key_value = (key, value)
 
-    if attn_bias is not None:
-        # clamp to 0 necessary for torch 2.0 compile()
-        _s_q = max(0, attn_bias.size(2) - query.size(1))
-        _s_k = max(0, attn_bias.size(3) - key.size(1))
-        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
-
     if attn_bias is not None:
         raise NotImplementedError(f'attn_bias not implemented for flash attn.')
 
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 6349656983..077306bddc 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -138,16 +138,16 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     """Generates the attention mask used for sequence masking in FA v2.
 
     Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
-    In case of left padding:
-        1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
-        2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
-
-    Args:
-        sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
-        S (int): Sequence length
-        attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
-        attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention.
-        attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
+    In case of left padding:
+        1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
+        2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
+
+    Args:
+        sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
+        S (int): Sequence length
+        attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
+        attn_impl (str): Attention implementation. This function only creates attention_mask_in_length for flash attention.
+        attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
     """
     attention_mask_in_length = None
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
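
Illustrative note (not part of the patch): below is a minimal sketch of the attention_mask_in_length layout the docstring above describes, assuming sequence ids are small non-negative integers and ignoring padding and the attention_mask argument. Each row of the (batch_size, seq_len) result lists the packed sub-sequence lengths for that sample, left-aligned and zero-padded, which is the form FlashAttention v2's variable-length path consumes. The helper name build_attention_mask_in_length is hypothetical, not the library's gen_attention_mask_in_length.

import torch
import torch.nn.functional as F

def build_attention_mask_in_length(sequence_id: torch.Tensor, S: int) -> torch.Tensor:
    # sequence_id: (batch_size, S) integer ids marking which packed sub-sequence
    # each token belongs to, e.g. [0, 0, 0, 1, 1, 2, 2, 2].
    # Returns (batch_size, S): per-sample sub-sequence lengths, left-aligned and
    # zero-padded, e.g. [3, 2, 3, 0, 0, 0, 0, 0].
    counts = F.one_hot(sequence_id, num_classes=S).sum(dim=1)  # tokens per sequence id
    out = torch.zeros_like(counts)
    for b in range(counts.size(0)):
        nonzero = counts[b][counts[b] > 0]  # drop ids that do not occur in this sample
        out[b, :nonzero.numel()] = nonzero
    return out

# Example: one sample packing three sequences of lengths 3, 2 and 3.
seq_id = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])
print(build_attention_mask_in_length(seq_id, S=8))  # tensor([[3, 2, 3, 0, 0, 0, 0, 0]])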