Skip to content

Commit

Permalink
test with transformers pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Dec 12, 2024
1 parent 633fefc commit 34c1f82
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
require_hf_token,
require_ort_rocm,
)
from transformers import pipeline as transformers_pipeline


if is_diffusers_available():
Expand Down Expand Up @@ -1747,7 +1748,7 @@ def test_pipeline_ort_model(self, model_arch):

@pytest.mark.run_in_series
def test_pipeline_model_is_none(self):
pipe = pipeline("text-classification")
pipe = transformers_pipeline("text-classification")
text = "My Name is Philipp and i live in Germany."
outputs = pipe(text)

Expand Down Expand Up @@ -2553,7 +2554,7 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo

@pytest.mark.run_in_series
def test_pipeline_model_is_none(self):
pipe = pipeline("text-generation")
pipe = transformers_pipeline("text-generation")
text = "My Name is Philipp and i live"
outputs = pipe(text)

Expand Down Expand Up @@ -3944,20 +3945,20 @@ def test_pipeline_text_generation(self, test_name: str, model_arch: str, use_cac
@pytest.mark.run_in_series
def test_pipeline_model_is_none(self):
# Text2text generation
pipe = pipeline("text2text-generation")
pipe = transformers_pipeline("text2text-generation")
text = "This is a test"
outputs = pipe(text, min_length=1, max_length=2)
# compare model output class
self.assertIsInstance(outputs[0]["generated_text"], str)

# Summarization
pipe = pipeline("summarization")
pipe = transformers_pipeline("summarization")
outputs = pipe(text, min_length=1, max_length=2)
# compare model output class
self.assertIsInstance(outputs[0]["summary_text"], str)

# Translation
pipe = pipeline("translation_en_to_de")
pipe = transformers_pipeline("translation_en_to_de")
outputs = pipe(text, min_length=1, max_length=2)
# compare model output class
self.assertIsInstance(outputs[0]["translation_text"], str)
Expand Down

0 comments on commit 34c1f82

Please sign in to comment.