From c7e48053aab09ad11efa2ad12513e9ab56f29563 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 18 Dec 2024 17:14:22 +0800 Subject: [PATCH] [tests] make cuda-only tests device-agnostic (#35222) fix cuda-only tests --- tests/models/rag/test_modeling_rag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/rag/test_modeling_rag.py b/tests/models/rag/test_modeling_rag.py index 3e3f7b9c457589..b219d5c74edff0 100644 --- a/tests/models/rag/test_modeling_rag.py +++ b/tests/models/rag/test_modeling_rag.py @@ -33,7 +33,7 @@ require_sentencepiece, require_tokenizers, require_torch, - require_torch_non_multi_gpu, + require_torch_non_multi_accelerator, slow, torch_device, ) @@ -678,7 +678,7 @@ def config_and_inputs(self): @require_retrieval @require_sentencepiece @require_tokenizers -@require_torch_non_multi_gpu +@require_torch_non_multi_accelerator class RagModelIntegrationTests(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1002,7 +1002,7 @@ def test_rag_token_generate_batch(self): torch_device ) - if torch_device == "cuda": + if torch_device != "cpu": rag_token.half() input_dict = tokenizer(