From 0474e0595abed7da5d6e074d22e60d4d31ecb8ab Mon Sep 17 00:00:00 2001 From: Shashank Rajput Date: Wed, 17 Jan 2024 19:11:59 +0000 Subject: [PATCH] .. --- llmfoundry/models/layers/attention.py | 2 + llmfoundry/models/mpt/modeling_mpt.py | 49 ++--- tests/models/layers/test_flash_attn.py | 180 ++++++++++-------- .../models/layers/test_flash_triton_torch.py | 16 +- 4 files changed, 140 insertions(+), 107 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 18d1e5152b..11cb5a7a72 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -235,6 +235,8 @@ def flash_attn_fn( ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: del key_padding_mask + if flash_attn_padding_info is None: + raise ValueError('flash_attn_padding_info is required for flash attn.') try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 8e9ec6435d..ebc9f35779 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -217,7 +217,9 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, return attention_mask_in_length -def get_flash_attn_padding_info( +def gen_flash_attn_padding_info( + bsz: int, + S: int, past_key_len: int, attention_mask_in_length: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None): @@ -225,9 +227,9 @@ def get_flash_attn_padding_info( if attention_mask_in_length is None: key_padding_mask = attention_mask if key_padding_mask is None: - key_padding_mask = torch.ones( - (x.shape[0], past_key_len + x.shape[1]), dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -x.shape[1]:] + key_padding_mask = torch.ones((bsz, past_key_len + S), + dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -S:] unpadding_function = bert_padding.unpad_input else: key_padding_mask = attention_mask_in_length @@ -550,10 +552,12 @@ def forward( raise ValueError( 'You cannot specify both input_ids and inputs_embeds.') elif input_ids is not None: + bsz = input_ids.size(0) S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: + bsz = inputs_embeds.size(0) S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device @@ -565,22 +569,23 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - if self.learned_pos_emb or self.rope: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' - ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). 
- # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). + # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + if self.learned_pos_emb or self.rope: if self.learned_pos_emb and (S + past_position > self.config.max_seq_len): raise ValueError( @@ -660,10 +665,8 @@ def forward( all_self_attns = () if output_attentions else None flash_attn_padding_info = {} if self.attn_impl == 'flash': - past_key_len = past_key_values[0].shape[ - 1] if past_key_values is not None else 0 - flash_attn_padding_info = get_flash_attn_padding_info( - past_key_len, attention_mask_in_length, attention_mask) + flash_attn_padding_info = gen_flash_attn_padding_info( + bsz, S, past_position, attention_mask_in_length, attention_mask) for b_idx, block in enumerate(self.blocks): if output_hidden_states: diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 3e1ec37b2e..51265bc8b8 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -12,6 +12,7 @@ flash_attn_fn, gen_slopes, is_flash_v2_installed, triton_flash_attn_fn) +from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.gpu @@ -35,22 +36,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, None, None), + should_repeat_kv_for_gqa=True) output_1.sum().backward() @@ -61,22 +64,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): value_2 = value_1.detach().clone() value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + 
past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, None, None), + should_repeat_kv_for_gqa=False) output_2.sum().backward() assert torch.allclose(output_1, output_2) @@ -114,6 +119,9 @@ def test_seq_id_masking_FA_v2(): [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + bsz, seqlen_1, 0, attention_mask_in_length_1, None) + output_1, _, _ = flash_attn_fn( query=query_1, key=key_1, @@ -129,7 +137,7 @@ def test_seq_id_masking_FA_v2(): training=False, needs_weights=False, multiquery=False, - attention_mask_in_length=attention_mask_in_length_1) + flash_attn_padding_info=flash_attn_padding_info_1) output_1.sum().backward() @@ -141,21 +149,25 @@ def test_seq_id_masking_FA_v2(): value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None) + flash_attn_padding_info_2 = gen_flash_attn_padding_info( + bsz, seq_range[1] - seq_range[0], 0, None, None) + + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=flash_attn_padding_info_2) output_2.sum().backward() assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], @@ -196,23 +208,25 @@ def test_sliding_window(sliding_window_size: int): device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, None, None), + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) output_1.sum().backward() @@ -284,23 +298,25 @@ def test_alibi_bias(n_heads: int): alibi_bias_max=8, device=torch.device(device), return_1d=True) - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - 
alibi_slopes=alibi_slopes_1) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, None, None), + should_repeat_kv_for_gqa=True, + alibi_slopes=alibi_slopes_1) output_1.sum().backward() diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 4ca5c7b668..e6b31c0d6f 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -10,6 +10,7 @@ is_flash_v2_installed) from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, + gen_flash_attn_padding_info, gen_rotary_embedding) @@ -164,6 +165,12 @@ def gen_bias(attn_impl: str): attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_0, attention_mask=attention_mask) + + flash_attn_padding_info_0 = {} + if attn_impl_0 == 'flash': + flash_attn_padding_info_0 = gen_flash_attn_padding_info( + n, s, 0, attention_mask_in_length_0, attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, @@ -171,6 +178,11 @@ def gen_bias(attn_impl: str): attn_impl=attn_impl_1, attention_mask=attention_mask) + flash_attn_padding_info_1 = {} + if attn_impl_1 == 'flash': + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + n, s, 0, attention_mask_in_length_1, attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -216,7 +228,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_0, + flash_attn_padding_info=flash_attn_padding_info_0, alibi_slopes=alibi_slopes_0) attn_bias_1 = gen_bias(attn_impl_1) alibi_slopes_1 = None @@ -231,7 +243,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_1, + flash_attn_padding_info=flash_attn_padding_info_1, alibi_slopes=alibi_slopes_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1)
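
Note: below is a minimal usage sketch of the new gen_flash_attn_padding_info helper and the flash_attn_padding_info argument introduced by this patch. It mirrors the calls in the updated tests; it assumes flash-attn v2 is installed and a CUDA device is available, and the tensor shapes, zero past_key_len, and absence of padding/sequence-id masks are illustrative choices, not requirements of the API.

    import math
    import torch

    from llmfoundry.models.layers.attention import flash_attn_fn
    from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

    # Illustrative sizes (bsz, seqlen, heads, head dim); bfloat16 on CUDA as in the tests.
    bsz, seqlen, n_heads, d = 2, 6, 8, 128
    query = torch.randn(bsz, seqlen, n_heads * d, dtype=torch.bfloat16, device='cuda')
    key = torch.randn(bsz, seqlen, n_heads * d, dtype=torch.bfloat16, device='cuda')
    value = torch.randn(bsz, seqlen, n_heads * d, dtype=torch.bfloat16, device='cuda')

    # No cached keys (past_key_len=0) and no attention_mask_in_length or attention_mask,
    # so the helper builds an all-ones key padding mask internally before unpadding.
    padding_info = gen_flash_attn_padding_info(bsz, seqlen, 0, None, None)

    # flash_attn_fn now takes flash_attn_padding_info instead of attention_mask_in_length
    # and raises a ValueError if it is None.
    output, _, _ = flash_attn_fn(
        query=query,
        key=key,
        value=value,
        n_heads=n_heads,
        kv_n_heads=n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        multiquery=False,
        flash_attn_padding_info=padding_info,
    )

When decoding with a KV cache, the third argument would instead be the past key length (past_position in MPTModel.forward), which is why the patch hoists the past_position computation out of the learned_pos_emb/rope branch.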