diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index cdca4e6f7d..e140f678bc 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -103,7 +103,7 @@ def test_attn_impl(attn_impl_0: str, if pad_attention_mask: # zero out the last third of the attention mask # to simulate padding - attention_mask[:, (s * 2) // 3:] = 0 + attention_mask[:, -s // 3:] = 0 def gen_bias(attn_impl: str): causal = True