
Commit

adding todos
ShashankMosaicML committed Dec 8, 2024
1 parent dfde51b commit d1d04ce
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/models/layers/test_flash_attn.py
@@ -167,7 +167,7 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
         pytest.skip(
             'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
         )
-    d = 128  # Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work
+    d = 128  # TODO: Compiled FlexAttention works for d=16 with seqlen=6, but not for d=128 with seqlen=6. For seqlen=128, all d's in [16, 32, 64, 128, 256] work. Probably because this is not yet fixed: https://pytorch.org/blog/flexattention/#limitations-and-future-work
     n_heads = 4
     kv_n_heads = 4
     seqlen_1 = 128
@@ -311,7 +311,7 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
     dtype = torch.bfloat16
     device = 'cuda'
     d = 128
-    seqlen_1 = 6 if attn_impl != 'flex' else 128  # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work
+    seqlen_1 = 6 if attn_impl != 'flex' else 128  # TODO: FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work
     bsz = 2

     query_1 = torch.randn(bsz, seqlen_1,
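For reference, both TODOs added here point at the same FlexAttention limitation. Below is a minimal sketch, not part of this commit, of what the comments describe as the working configuration: compiled FlexAttention with seqlen=128 and d=128. The shapes, the torch.compile step, and the plain (no score_mod/block_mask) call are assumptions for illustration; it presumes torch>=2.5.1 with a CUDA device.

# Minimal sketch (not from this commit): compiled FlexAttention at the
# seqlen=128, d=128 configuration the TODO comments describe as working.
# Assumes torch>=2.5.1 and a CUDA device; bsz and n_heads mirror the test values.
import torch
from torch.nn.attention.flex_attention import flex_attention

bsz, n_heads, seqlen, d = 2, 4, 128, 128  # seqlen kept at a multiple of 128
q = torch.randn(bsz, n_heads, seqlen, d, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(bsz, n_heads, seqlen, d, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(bsz, n_heads, seqlen, d, device='cuda', dtype=torch.bfloat16, requires_grad=True)

compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v)  # no score_mod/block_mask: plain attention
# Per the TODO, the multiple-of-128 seqlen requirement reportedly surfaces in the backward pass.
out.sum().backward()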
