From e31cb8a63af8eef4172171537da26edd26b9dbba Mon Sep 17 00:00:00 2001
From: Shashank Rajput
Date: Mon, 25 Dec 2023 07:20:50 +0000
Subject: [PATCH] ..

---
 tests/models/layers/test_flash_attn.py         | 5 ++---
 tests/models/layers/test_flash_triton_torch.py | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py
index 3376bdeb68..319144577b 100644
--- a/tests/models/layers/test_flash_attn.py
+++ b/tests/models/layers/test_flash_attn.py
@@ -260,8 +260,7 @@ def test_sliding_window(sliding_window_size: int):
 @pytest.mark.gpu
 @pytest.mark.skipif(
     not is_flash_v2_installed(v2_version='v2.4.0.post1'),
-    reason=
-    'ALiBi only supported by Flash Attention after v2.4.0.post1.')
+    reason='ALiBi only supported by Flash Attention after v2.4.0.post1.')
 @pytest.mark.parametrize('n_heads', [1, 6, 8])
 def test_alibi_bias(n_heads: int):
     # Test that sliding window attention works as expected.
@@ -282,7 +281,7 @@ def test_alibi_bias(n_heads: int):
     value_1.requires_grad = True
     attn_bias_1 = gen_alibi_slopes(n_heads=n_heads,
                                    alibi_bias_max=8,
-                                   device=device)
+                                   device=torch.device(device))
     output_1, _, _ = flash_attn_fn(query=query_1,
                                    key=key_1,
                                    value=value_1,
diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py
index 833a97a4b2..9f3df7c75d 100644
--- a/tests/models/layers/test_flash_triton_torch.py
+++ b/tests/models/layers/test_flash_triton_torch.py
@@ -161,7 +161,7 @@ def gen_bias(attn_impl: str,
         if alibi and attn_impl == 'flash':
             attn_bias = gen_alibi_slopes(n_heads=cfg.n_heads,
                                          alibi_bias_max=8,
-                                         device=device)
+                                         device=torch.device(device))
 
         return attn_bias
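
For reviewers skimming the diff: both hunks make the same functional change, passing an explicit torch.device to gen_alibi_slopes instead of a bare device string. The snippet below is a minimal standalone sketch of that conversion, assuming only that the helper's device parameter expects a torch.device object (gen_alibi_slopes itself lives in the repo's attention layers and is not reproduced here):

import torch

# Tests typically pick the device as a plain string such as 'cuda' or 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Wrapping the string in torch.device() produces a proper torch.device object,
# which satisfies a `device: torch.device` annotation while behaving the same
# anywhere PyTorch accepts a device.
alibi_device = torch.device(device)  # e.g. device(type='cuda')

# Call shape mirroring the tests above (gen_alibi_slopes not defined here):
# attn_bias = gen_alibi_slopes(n_heads=8, alibi_bias_max=8, device=alibi_device)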