From 34c1f82eb309209ea782304ecdf09faabaa266c5 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 12 Dec 2024 11:43:39 +0100 Subject: [PATCH] test with transformers pipeline --- tests/onnxruntime/test_modeling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index ef93609353..98537a27c5 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -115,6 +115,7 @@ require_hf_token, require_ort_rocm, ) +from transformers import pipeline as transformers_pipeline if is_diffusers_available(): @@ -1747,7 +1748,7 @@ def test_pipeline_ort_model(self, model_arch): @pytest.mark.run_in_series def test_pipeline_model_is_none(self): - pipe = pipeline("text-classification") + pipe = transformers_pipeline("text-classification") text = "My Name is Philipp and i live in Germany." outputs = pipe(text) @@ -2553,7 +2554,7 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo @pytest.mark.run_in_series def test_pipeline_model_is_none(self): - pipe = pipeline("text-generation") + pipe = transformers_pipeline("text-generation") text = "My Name is Philipp and i live" outputs = pipe(text) @@ -3944,20 +3945,20 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac @pytest.mark.run_in_series def test_pipeline_model_is_none(self): # Text2text generation - pipe = pipeline("text2text-generation") + pipe = transformers_pipeline("text2text-generation") text = "This is a test" outputs = pipe(text, min_length=1, max_length=2) # compare model output class self.assertIsInstance(outputs[0]["generated_text"], str) # Summarization - pipe = pipeline("summarization") + pipe = transformers_pipeline("summarization") outputs = pipe(text, min_length=1, max_length=2) # compare model output class self.assertIsInstance(outputs[0]["summary_text"], str) # Translation - pipe = pipeline("translation_en_to_de") + pipe = transformers_pipeline("translation_en_to_de") outputs = pipe(text, min_length=1, max_length=2) # compare model output class self.assertIsInstance(outputs[0]["translation_text"], str)