Commit

Shashank Rajput committed Dec 25, 2023
1 parent fcb59d4 commit e31cb8a
Showing 2 changed files with 3 additions and 4 deletions.
tests/models/layers/test_flash_attn.py (5 changes: 2 additions & 3 deletions)
@@ -260,8 +260,7 @@ def test_sliding_window(sliding_window_size: int):
 @pytest.mark.gpu
 @pytest.mark.skipif(
     not is_flash_v2_installed(v2_version='v2.4.0.post1'),
-    reason=
-    'ALiBi only supported by Flash Attention after v2.4.0.post1.')
+    reason='ALiBi only supported by Flash Attention after v2.4.0.post1.')
 @pytest.mark.parametrize('n_heads', [1, 6, 8])
 def test_alibi_bias(n_heads: int):
     # Test that sliding window attention works as expected.
@@ -282,7 +281,7 @@ def test_alibi_bias(n_heads: int):
     value_1.requires_grad = True
     attn_bias_1 = gen_alibi_slopes(n_heads=n_heads,
                                    alibi_bias_max=8,
-                                   device=device)
+                                   device=torch.device(device))
     output_1, _, _ = flash_attn_fn(query=query_1,
                                    key=key_1,
                                    value=value_1,
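
The substantive change in this file is in the second hunk: gen_alibi_slopes now receives torch.device(device) instead of the raw device value, which suggests that device arrives as a plain string such as 'cuda:0'. As a hedged illustration only, a minimal stand-in for such a helper is sketched below; the signature is inferred from the call sites in the diff, the slope formula is the standard ALiBi geometric sequence, and the repository's actual implementation may differ (for example in how it handles non-power-of-two head counts):

import torch


def gen_alibi_slopes(n_heads: int, alibi_bias_max: int,
                     device: torch.device) -> torch.Tensor:
    # Sketch only, not the repository's implementation.
    # Standard ALiBi slopes: a geometric sequence 2**(-alibi_bias_max * i / n_heads)
    # for i = 1..n_heads, allocated directly on the requested device.
    exponents = torch.arange(1, n_heads + 1, dtype=torch.float32, device=device)
    slopes = 1.0 / torch.pow(2, exponents * alibi_bias_max / n_heads)
    # Shaped to broadcast against (batch, n_heads, seq_len_q, seq_len_k) attention scores.
    return slopes.view(1, n_heads, 1, 1)
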
tests/models/layers/test_flash_triton_torch.py (2 changes: 1 addition & 1 deletion)
@@ -161,7 +161,7 @@ def gen_bias(attn_impl: str,
     if alibi and attn_impl == 'flash':
         attn_bias = gen_alibi_slopes(n_heads=cfg.n_heads,
                                      alibi_bias_max=8,
-                                     device=device)
+                                     device=torch.device(device))

     return attn_bias
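
The change here mirrors the one in test_flash_attn.py. Constructing a torch.device from a string is cheap, and recent PyTorch releases also accept an already-constructed torch.device, so the wrapping is safe whichever form the device value takes. A small CPU-only illustration (not from the repository):

import torch

# torch.device accepts a device string such as 'cpu' or 'cuda:0'; recent
# PyTorch releases also accept an existing torch.device unchanged.
dev = torch.device('cpu')
same = torch.device(dev)
assert dev == same

# Tensors requested on that device are allocated there, which is what a
# helper like gen_alibi_slopes relies on when building the slope tensor.
slopes = torch.arange(1, 9, dtype=torch.float32, device=dev)
print(slopes.device)  # cpu
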
