diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py index 68aa0fe7fe..05350b059b 100644 --- a/llmfoundry/models/layers/__init__.py +++ b/llmfoundry/models/layers/__init__.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from llmfoundry.models.layers.attention import ( - ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention, - attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn, - scaled_multihead_dot_product_attention, triton_flash_attn_fn) + ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention, + MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias, + flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY @@ -17,6 +17,7 @@ 'triton_flash_attn_fn', 'MultiheadAttention', 'MultiQueryAttention', + 'GroupedQueryAttention', 'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias', diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 0fb6c0a042..fecd79553f 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -228,12 +228,17 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, should_repeat_kv_for_gqa: Optional[bool] = True, sliding_window_size: int = -1, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: + if key_padding_mask is not None: + raise ValueError('key_padding_mask should be None for flash attn.') + del key_padding_mask + if flash_attn_padding_info is None: + raise ValueError('flash_attn_padding_info is required for flash attn.') try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: @@ -267,25 +272,24 @@ def flash_attn_fn( batch_size, seqlen = query.shape[:2] - if attention_mask_in_length is None: - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] - unpadding_function = bert_padding.unpad_input - else: - key_padding_mask = attention_mask_in_length - query_padding_mask = attention_mask_in_length - unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + indices_q = flash_attn_padding_info['indices_q'] + indices_k = flash_attn_padding_info['indices_k'] + indices_v = flash_attn_padding_info['indices_v'] + cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'] + cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'] + max_seqlen_q = flash_attn_padding_info['max_seqlen_q'] + max_seqlen_k = flash_attn_padding_info['max_seqlen_k'] - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( - query, query_padding_mask) + query_unpad = bert_padding.index_first_axis( + rearrange(query, 'b s ... -> (b s) ...'), indices_q) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( - key, key_padding_mask) + key_unpad = bert_padding.index_first_axis( + rearrange(key, 'b s ... -> (b s) ...'), indices_k) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) + value_unpad = bert_padding.index_first_axis( + rearrange(value, 'b s ... -> (b s) ...'), indices_v) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( @@ -599,8 +603,8 @@ def forward( rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -666,11 +670,12 @@ def forward( extra_attn_kwargs = {} if self.attn_impl == 'flash': + key_padding_mask = None extra_attn_kwargs = { - 'attention_mask_in_length': attention_mask_in_length, 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, + 'flash_attn_padding_info': flash_attn_padding_info, } context, attn_weights, past_key_value = self.attn_fn( diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index e5032998dc..036a4e7cd2 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -122,8 +122,8 @@ def forward( attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, - attention_mask_in_length: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -135,8 +135,8 @@ def forward( attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, - attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 067b9435ee..31ea5ba71b 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -25,15 +25,23 @@ from composer.utils import dist from llmfoundry.eval.metrics.nlp import InContextLearningGenerationF1Score -from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.attention import (is_flash_v1_installed, + is_flash_v2_installed) if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding from flash_attn.layers.rotary import \ RotaryEmbedding as DAILRotaryEmbedding except Exception as e: raise e +if is_flash_v1_installed(): + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn import bert_padding + except Exception as e: + raise e + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -217,6 +225,44 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, return attention_mask_in_length +def gen_flash_attn_padding_info( + bsz: int, + S: int, + past_key_len: int, + device: torch.device, + attention_mask_in_length: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None): + flash_attn_padding_info = {} + if attention_mask_in_length is None: + key_padding_mask = attention_mask + if key_padding_mask is None: + key_padding_mask = torch.ones((bsz, past_key_len + S), + dtype=torch.bool, + device=device) + query_padding_mask = key_padding_mask[:, -S:] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences + + _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( + torch.empty(bsz, S, 1, device=device), query_padding_mask) + _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + _, indices_v, _, _ = unpadding_function( + torch.empty(bsz, past_key_len + S, 1, device=device), key_padding_mask) + + flash_attn_padding_info['indices_q'] = indices_q + flash_attn_padding_info['indices_k'] = indices_k + flash_attn_padding_info['indices_v'] = indices_v + flash_attn_padding_info['cu_seqlens_q'] = cu_seqlens_q + flash_attn_padding_info['cu_seqlens_k'] = cu_seqlens_k + flash_attn_padding_info['max_seqlen_q'] = max_seqlen_q + flash_attn_padding_info['max_seqlen_k'] = max_seqlen_k + return flash_attn_padding_info + + def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, max_seq_len: int) -> torch.Tensor: seq_len = sequence_id.shape[-1] @@ -516,10 +562,12 @@ def forward( raise ValueError( 'You cannot specify both input_ids and inputs_embeds.') elif input_ids is not None: + bsz = input_ids.size(0) S = input_ids.size(1) x = self.wte(input_ids) input_device = input_ids.device elif inputs_embeds is not None: + bsz = inputs_embeds.size(0) S = inputs_embeds.size(1) x = inputs_embeds input_device = inputs_embeds.device @@ -531,22 +579,23 @@ def forward( ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - if self.learned_pos_emb or self.rope: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' - + - f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' - ) - # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). - # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). - # Here we shift position embedding using the `seq` dim of the past key - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError( + f'past_key_values must provide a past_key_value for each attention ' + + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' + ) + # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). + # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq). + # Here we shift position embedding using the `seq` dim of the past key + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + + if self.learned_pos_emb or self.rope: if self.learned_pos_emb and (S + past_position > self.config.max_seq_len): raise ValueError( @@ -624,6 +673,12 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + flash_attn_padding_info = {} + if self.attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + bsz, S, past_position, x.device, attention_mask_in_length, + attention_mask) + for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright @@ -638,8 +693,8 @@ def forward( attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), - attention_mask_in_length=attention_mask_in_length, alibi_slopes=alibi_slopes, + flash_attn_padding_info=flash_attn_padding_info, ) if presents is not None: presents += (present,) diff --git a/scripts/data_prep/convert_delta_to_json.py b/scripts/data_prep/convert_delta_to_json.py index 326b8e912f..9d73cb0e9f 100644 --- a/scripts/data_prep/convert_delta_to_json.py +++ b/scripts/data_prep/convert_delta_to_json.py @@ -407,7 +407,7 @@ def validate_and_get_cluster_info(cluster_id: str, stripped_runtime = re.sub( r'[a-zA-Z]', '', res.spark_version.split('-scala')[0].replace('x-snapshot', '')) - runtime_version = re.sub(r'.-+$', '', stripped_runtime) + runtime_version = re.sub(r'[.-]*$', '', stripped_runtime) if version.parse(runtime_version) < version.parse( MINIMUM_SQ_CONNECT_DBR_VERSION): raise ValueError( diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index 3e1ec37b2e..9471cdac6a 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -12,6 +12,7 @@ flash_attn_fn, gen_slopes, is_flash_v2_installed, triton_flash_attn_fn) +from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.gpu @@ -35,22 +36,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): kv_n_heads * d).to(torch.bfloat16).cuda() value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True) output_1.sum().backward() @@ -61,22 +64,24 @@ def test_gqa_kv_repetition(kv_n_heads: int): value_2 = value_1.detach().clone() value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=False) + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_2.device, None, None), + should_repeat_kv_for_gqa=False) output_2.sum().backward() assert torch.allclose(output_1, output_2) @@ -114,6 +119,9 @@ def test_seq_id_masking_FA_v2(): [3, 2, 1, 0, 0, 0]]).to(torch.int64).cuda() + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, attention_mask_in_length_1, None) + output_1, _, _ = flash_attn_fn( query=query_1, key=key_1, @@ -129,7 +137,7 @@ def test_seq_id_masking_FA_v2(): training=False, needs_weights=False, multiquery=False, - attention_mask_in_length=attention_mask_in_length_1) + flash_attn_padding_info=flash_attn_padding_info_1) output_1.sum().backward() @@ -141,21 +149,25 @@ def test_seq_id_masking_FA_v2(): value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] value_2.requires_grad = True - output_2, _, _ = flash_attn_fn(query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=kv_n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None) + flash_attn_padding_info_2 = gen_flash_attn_padding_info( + bsz, seq_range[1] - seq_range[0], 0, query_2.device, None, None) + + output_2, _, _ = flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=flash_attn_padding_info_2) output_2.sum().backward() assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], @@ -196,23 +208,25 @@ def test_sliding_window(sliding_window_size: int): device=device) value_1.requires_grad = True - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) output_1.sum().backward() @@ -284,23 +298,25 @@ def test_alibi_bias(n_heads: int): alibi_bias_max=8, device=torch.device(device), return_1d=True) - output_1, _, _ = flash_attn_fn(query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, - attention_mask_in_length=None, - should_repeat_kv_for_gqa=True, - alibi_slopes=alibi_slopes_1) + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + flash_attn_padding_info=gen_flash_attn_padding_info( + bsz, seqlen_1, 0, query_1.device, None, None), + should_repeat_kv_for_gqa=True, + alibi_slopes=alibi_slopes_1) output_1.sum().backward() diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 4ca5c7b668..2f992cd92f 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -10,6 +10,7 @@ is_flash_v2_installed) from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, gen_attention_mask_in_length, + gen_flash_attn_padding_info, gen_rotary_embedding) @@ -164,6 +165,13 @@ def gen_bias(attn_impl: str): attn_uses_sequence_id=attn_uses_sequence_id, attn_impl=attn_impl_0, attention_mask=attention_mask) + + flash_attn_padding_info_0 = {} + if attn_impl_0 == 'flash': + flash_attn_padding_info_0 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_0, + attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( sequence_id=sequence_id, S=s, @@ -171,6 +179,12 @@ def gen_bias(attn_impl: str): attn_impl=attn_impl_1, attention_mask=attention_mask) + flash_attn_padding_info_1 = {} + if attn_impl_1 == 'flash': + flash_attn_padding_info_1 = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), attention_mask_in_length_1, + attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True @@ -216,7 +230,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_0, + flash_attn_padding_info=flash_attn_padding_info_0, alibi_slopes=alibi_slopes_0) attn_bias_1 = gen_bias(attn_impl_1) alibi_slopes_1 = None @@ -231,7 +245,7 @@ def gen_bias(attn_impl: str): attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True, - attention_mask_in_length=attention_mask_in_length_1, + flash_attn_padding_info=flash_attn_padding_info_1, alibi_slopes=alibi_slopes_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) @@ -313,11 +327,16 @@ def gen_tca_mask(): x1.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y1, _ = tmhsa(x1, x1, x1, @@ -387,11 +406,16 @@ def test_grouped_attention_heads(attn_impl: str, x0.requires_grad = True with torch.autocast(x0.device.type): + flash_attn_padding_info = None + if attn_impl == 'flash': + flash_attn_padding_info = gen_flash_attn_padding_info( + n, s, 0, torch.device(device), None, attention_mask) y0, _, _ = mmhsa(x0, past_key_value=None, attn_bias=None, attention_mask=attention_mask, - is_causal=True) + is_causal=True, + flash_attn_padding_info=flash_attn_padding_info) y0 *= attention_mask.unsqueeze(-1) loss0 = y0.sum() diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py index 70a00470f9..33c3d3c052 100644 --- a/tests/models/test_rope_dail_vs_hf.py +++ b/tests/models/test_rope_dail_vs_hf.py @@ -7,7 +7,8 @@ from omegaconf import OmegaConf as om from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info, + gen_rotary_embedding) @pytest.mark.gpu @@ -104,14 +105,20 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'): attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=dail_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=None, attention_mask=attention_mask, rotary_emb_w_meta_info=hf_rope_w_meta_info, - is_causal=True) + is_causal=True, + flash_attn_padding_info=gen_flash_attn_padding_info( + batch_size, seq_len, 0, torch.device(device), None, + attention_mask)) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1)