diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 8d90bfa3a3..54df092bde 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -90,13 +90,8 @@ def scaled_multihead_dot_product_attention(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
-    should_repeat_kv_for_gqa: Optional[bool] = True,
-    sliding_window_size: int = -1,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
-    del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
-
     if multiquery:
         warnings.warn(
             DeprecationWarning(
@@ -271,33 +266,28 @@ def flash_attn_fn(
         if key_padding_mask is None:
             key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
         query_padding_mask = key_padding_mask[:, -query.size(1):]
+        unpadding_function = bert_padding.unpad_input
+    else:
+        unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
+        query_padding_mask = attention_mask_in_length
+        key_padding_mask = attention_mask_in_length
 
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(
-            query, query_padding_mask)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
-
-        key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(
-            key, key_padding_mask)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
+        query, query_padding_mask)
+    query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
 
-        value_unpad, _, _, _ = bert_padding.unpad_input(value, key_padding_mask)
-        value_unpad = rearrange(value_unpad,
-                                'nnz (h d) -> nnz h d',
-                                h=kv_n_heads)
-    else:
-        query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input_for_concatenated_sequences(
-            query, attention_mask_in_length)
-        query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
+    key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function(
+        key, key_padding_mask)
+    key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
 
-        key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input_for_concatenated_sequences(
-            key, attention_mask_in_length)
-        key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
+    value_unpad, _, _, _ = unpadding_function(value, key_padding_mask)
+    value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads)
 
-        value_unpad, _, _, _ = bert_padding.unpad_input_for_concatenated_sequences(
-            value, attention_mask_in_length)
-        value_unpad = rearrange(value_unpad,
-                                'nnz (h d) -> nnz h d',
-                                h=kv_n_heads)
+    if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and (
+            not should_repeat_kv_for_gqa):
+        raise ValueError(
+            'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.'
+        )
 
     if should_repeat_kv_for_gqa:
         # multi-query case
@@ -383,12 +373,8 @@ def triton_flash_attn_fn(
     training: bool = False,
     needs_weights: bool = False,
     multiquery: bool = False,
-    attention_mask_in_length: Optional[torch.Tensor] = None,
-    should_repeat_kv_for_gqa: Optional[bool] = True,
-    sliding_window_size: int = -1,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
                                                                 torch.Tensor]]]:
-    del attention_mask_in_length, should_repeat_kv_for_gqa, sliding_window_size
     try:
         from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
     except:
@@ -659,6 +645,14 @@ def forward(
             query = query.view(bsz, seqlen, self.d_model)
             key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
 
+        extra_attn_kwargs = {}
+        if self.attn_impl == 'flash':
+            extra_attn_kwargs = {
+                'attention_mask_in_length': attention_mask_in_length,
+                'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
+                'sliding_window_size': self.sliding_window_size,
+            }
+
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
@@ -673,9 +667,7 @@ def forward(
             dropout_p=self.attn_dropout_p,
             training=self.training,
             needs_weights=needs_weights,
-            attention_mask_in_length=attention_mask_in_length,
-            should_repeat_kv_for_gqa=not is_flash_v2_installed(),
-            sliding_window_size=self.sliding_window_size,
+            **extra_attn_kwargs,
         )
 
         return self.out_proj(context), attn_weights, past_key_value
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index d2b0eeb3d7..6349656983 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -135,11 +135,20 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
 def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
                                  attn_uses_sequence_id: bool, attn_impl: str,
                                  attention_mask: Union[torch.Tensor, None]):
-    # Generates the attention masks used for sequence masking in flash attention
-    # NOTE: Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
-    # In case of left padding:
-    #   1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
-    #   2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
+    """Generates the attention mask used for sequence masking in FA v2.
+
+    Only supports sequence id based sparse attention for no attention masking or attention masking with right padding.
+    In case of left padding:
+        1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407).
+        2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention.
+
+    Args:
+        sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
+        S (int): Sequence length.
+        attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking.
+        attn_impl (str): Attention implementation. This function only creates attention_mask_in_length for flash attention.
+        attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
+    """
     attention_mask_in_length = None
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
                                                                 == 'flash'):
@@ -149,7 +158,10 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
             raise NotImplementedError(
                 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.'
             )
-        assert S == sequence_id.shape[-1]
+        if S != sequence_id.shape[-1]:
+            raise ValueError(
+                f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).'
+            )
         attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
         if attention_mask is not None:
             attention_mask_in_length = attention_mask_in_length.masked_fill(
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 446fa98440..7e282dbc9d 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -225,23 +225,22 @@ def test_sliding_window(sliding_window_size: int):
         torch.ones(seqlen_1, seqlen_1),
         diagonal=-(sliding_window_size + 1)).to(
             dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
     attn_bias_2 = attn_bias_2 + window_mask_2
-    output_2, _, _ = triton_flash_attn_fn(query=query_2,
-                                          key=key_2,
-                                          value=value_2,
-                                          n_heads=n_heads,
-                                          kv_n_heads=n_heads,
-                                          past_key_value=None,
-                                          softmax_scale=1 / math.sqrt(d),
-                                          attn_bias=attn_bias_2,
-                                          key_padding_mask=None,
-                                          is_causal=True,
-                                          dropout_p=0.0,
-                                          training=False,
-                                          needs_weights=False,
-                                          multiquery=False,
-                                          attention_mask_in_length=None,
-                                          should_repeat_kv_for_gqa=False,
-                                          sliding_window_size=-1)
+    output_2, _, _ = triton_flash_attn_fn(
+        query=query_2,
+        key=key_2,
+        value=value_2,
+        n_heads=n_heads,
+        kv_n_heads=n_heads,
+        past_key_value=None,
+        softmax_scale=1 / math.sqrt(d),
+        attn_bias=attn_bias_2,
+        key_padding_mask=None,
+        is_causal=True,
+        dropout_p=0.0,
+        training=False,
+        needs_weights=False,
+        multiquery=False,
+    )
 
     output_2.sum().backward()
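
Reviewer note (not part of the patch): a minimal sketch, with toy values, of the attention_mask_in_length format that bert_padding.unpad_input_for_concatenated_sequences consumes. Each row holds the lengths of the packed sequences in that sample, left-aligned and zero-padded out to the sequence length S. The one_hot/masked_fill steps mirror gen_attention_mask_in_length above; the trailing sum-and-pad is an assumption of this sketch, since the hunk is truncated after masked_fill.

import torch

# Toy sample: S = 8 tokens holding three packed sequences (ids 0, 1, 2),
# with the last two positions being right padding.
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.bool)

# One-hot over sequence ids: (batch, S, num_sequences).
one_hot = torch.nn.functional.one_hot(sequence_id)
# Zero out padded positions so they do not count toward any sequence length.
one_hot = one_hot.masked_fill(~attention_mask.unsqueeze(-1), 0)
# Per-sequence token counts, left-aligned and zero-padded back to width S.
lengths = one_hot.sum(dim=1)  # tensor([[3, 2, 1]])
attention_mask_in_length = torch.nn.functional.pad(
    lengths, (0, sequence_id.shape[-1] - lengths.shape[-1]), value=0)
print(attention_mask_in_length)  # tensor([[3, 2, 1, 0, 0, 0, 0, 0]])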
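
Reviewer note (not part of the patch): the new extra_attn_kwargs dict passes should_repeat_kv_for_gqa=not is_flash_v2_installed() because Flash Attention v2 handles kv_n_heads < n_heads natively, while older kernels need the KV heads expanded to match the query heads. Below is a generic illustration of that expansion with hypothetical shapes, not the repo's exact implementation.

import torch

# Hypothetical GQA shapes: 8 query heads share 2 KV heads, head_dim 64,
# nnz = 10 unpadded tokens after unpadding.
n_heads, kv_n_heads, head_dim, nnz = 8, 2, 64, 10
key_unpad = torch.randn(nnz, kv_n_heads, head_dim)
value_unpad = torch.randn(nnz, kv_n_heads, head_dim)

# Repeat each KV head n_heads // kv_n_heads times so the attention kernel
# sees matching head counts for Q, K, and V.
repeats = n_heads // kv_n_heads
key_unpad = key_unpad.repeat_interleave(repeats, dim=1)
value_unpad = value_unpad.repeat_interleave(repeats, dim=1)
assert key_unpad.shape == value_unpad.shape == (nnz, n_heads, head_dim)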