diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py
index 91d881029b..abab741a94 100644
--- a/tests/models/layers/test_flash_attn.py
+++ b/tests/models/layers/test_flash_attn.py
@@ -167,7 +167,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
         pytest.skip(
             'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
-    d = 128  # Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work
+    d = 128  # TODO: Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work
     n_heads = 4
     kv_n_heads = 4
     seqlen_1 = 128
@@ -311,7 +311,7 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
     dtype = torch.bfloat16
     device = 'cuda'
     d = 128
-    seqlen_1 = 6 if attn_impl != 'flex' else 128  # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work
+    seqlen_1 = 6 if attn_impl != 'flex' else 128  # TODO: FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work
     bsz = 2
     query_1 = torch.randn(bsz, seqlen_1,