diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py index 773fe6f65b..3c3f5f314f 100644 --- a/tests/test_mpt_gen.py +++ b/tests/test_mpt_gen.py @@ -145,6 +145,7 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['triton', 'torch']) +@pytest.mark.parametrize('use_alibi', [True, False]) def test_mpt_generate_callback_not_tied( attn_impl: str, build_tiny_mpt: Callable[..., ComposerMPTCausalLM], tiny_ft_dataloader: DataLoader):