From 3351d23ea96a05810dc4b8c769ddd6ea35c7e9bc Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Wed, 17 Jan 2024 22:21:53 +0000
Subject: [PATCH] ..

---
 tests/models/layers/test_flash_triton_torch.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py
index 63d869ee35..2f992cd92f 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_triton_torch.py
@@ -327,11 +327,16 @@ def gen_tca_mask():
     x1.requires_grad = True
 
     with torch.autocast(x0.device.type):
+        flash_attn_padding_info = None
+        if attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                n, s, 0, torch.device(device), None, attention_mask)
         y0, _, _ = mmhsa(x0,
                          past_key_value=None,
                          attn_bias=None,
                          attention_mask=attention_mask,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=flash_attn_padding_info)
         y1, _ = tmhsa(x1,
                       x1,
                       x1,
@@ -401,11 +406,16 @@ def test_grouped_attention_heads(attn_impl: str,
     x0.requires_grad = True
 
     with torch.autocast(x0.device.type):
+        flash_attn_padding_info = None
+        if attn_impl == 'flash':
+            flash_attn_padding_info = gen_flash_attn_padding_info(
+                n, s, 0, torch.device(device), None, attention_mask)
         y0, _, _ = mmhsa(x0,
                          past_key_value=None,
                          attn_bias=None,
                          attention_mask=attention_mask,
-                         is_causal=True)
+                         is_causal=True,
+                         flash_attn_padding_info=flash_attn_padding_info)
         y0 *= attention_mask.unsqueeze(-1)
 
         loss0 = y0.sum()
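
Note: below is a minimal, self-contained sketch of the call pattern this patch
adds, for readers who want to reproduce it outside the test file. The import
path, tensor shapes, and mask construction are assumptions for illustration;
the gen_flash_attn_padding_info call itself mirrors the hunks above, with
past_key_len fixed at 0 because these tests use no KV cache.

    import torch
    # Assumed module path; in llm-foundry the helper is defined alongside
    # the MPT modeling code.
    from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

    n, s = 2, 16                    # batch size and sequence length (assumed)
    device = 'cuda'
    # Boolean padding mask: True marks real tokens, False marks padding.
    attention_mask = torch.ones(n, s, dtype=torch.bool, device=device)
    attention_mask[:, -2:] = False  # treat the last two positions as padding

    # Only the flash implementation needs this; other impls pass None,
    # exactly as the attn_impl == 'flash' guard in the patch does.
    flash_attn_padding_info = gen_flash_attn_padding_info(
        n, s, 0, torch.device(device), None, attention_mask)

The resulting flash_attn_padding_info is then forwarded to the attention
module alongside attention_mask, as in the mmhsa calls above.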