diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index c83e0725b8..bd09d3083c 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -191,7 +191,7 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
     device = 'cuda'
     d = 128
     n_heads = 8
-    seqlen_1 = 8
+    seqlen_1 = 8 if attn_impl != 'flex' else 128  # FlexAttention requires seqlen to be a multiple of 128 (to compute gradients I think). More info: https://pytorch.org/blog/flexattention/#limitations-and-future-work
     bsz = 2
     query_1 = torch.randn(bsz, seqlen_1,
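
For context on why the 'flex' path needs seqlen = 128, here is a minimal sketch (not part of the diff) of a sliding-window FlexAttention call, assuming PyTorch >= 2.5 with torch.nn.attention.flex_attention available. The shapes, `sliding_window_mask`, and `head_dim` are illustrative, not the repo's actual test code.

```python
# Minimal sketch: sliding-window attention via FlexAttention with seqlen chosen
# as a multiple of 128, per the constraint referenced in the diff comment.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

bsz, n_heads, head_dim = 2, 8, 64   # illustrative shapes, not the test's values
seqlen = 128                        # multiple of 128 for the flex implementation
sliding_window_size = 4

def sliding_window_mask(b, h, q_idx, kv_idx):
    # Causal attention restricted to the most recent `sliding_window_size` tokens.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= sliding_window_size)

block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=seqlen, KV_LEN=seqlen, device='cuda',
)

q = torch.randn(bsz, n_heads, seqlen, head_dim, device='cuda', requires_grad=True)
k = torch.randn(bsz, n_heads, seqlen, head_dim, device='cuda', requires_grad=True)
v = torch.randn(bsz, n_heads, seqlen, head_dim, device='cuda', requires_grad=True)

out = flex_attention(q, k, v, block_mask=block_mask)
out.sum().backward()  # per the diff comment, the gradient computation is where the multiple-of-128 requirement bites
```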