diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index 4ca5c7b668..2d0e7ad05b 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -28,6 +28,11 @@ def allclose_helper(t0: torch.Tensor, ]) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) +@pytest.mark.parametrize('qk_ln, qk_gn', [ + (True, False), + (False, True), + (False, False), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False