diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 70a00470f9..33c3d3c052 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -7,7 +7,8 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
-from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding
+from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
+                                                gen_rotary_embedding)
 
 
 @pytest.mark.gpu
@@ -104,14 +105,20 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
                      attn_bias=None,
                      attention_mask=attention_mask,
                      rotary_emb_w_meta_info=dail_rope_w_meta_info,
-                     is_causal=True)
+                     is_causal=True,
+                     flash_attn_padding_info=gen_flash_attn_padding_info(
+                         batch_size, seq_len, 0, torch.device(device), None,
+                         attention_mask))
 
     y1, _, _ = attn1(x1,
                      past_key_value=None,
                      attn_bias=None,
                      attention_mask=attention_mask,
                      rotary_emb_w_meta_info=hf_rope_w_meta_info,
-                     is_causal=True)
+                     is_causal=True,
+                     flash_attn_padding_info=gen_flash_attn_padding_info(
+                         batch_size, seq_len, 0, torch.device(device), None,
+                         attention_mask))
 
     y0 *= attention_mask.unsqueeze(-1)
     y1 *= attention_mask.unsqueeze(-1)