From 8b12d7a4fc64b9bd45a98ce7e6a2a4f4c0b3c646 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 18 Nov 2024 09:41:27 +0100 Subject: [PATCH] revert unnecessary changes --- .github/workflows/test_onnxruntime.yml | 2 +- tests/onnxruntime/test_modeling.py | 22 ---------------- tests/onnxruntime/test_utils.py | 35 ++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 23 deletions(-) create mode 100644 tests/onnxruntime/test_utils.py diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index 998021668ff..6df6407541f 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -17,8 +17,8 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, windows-2019, macos-15] transformers-version: ["latest"] + os: [ubuntu-20.04, windows-2019, macos-13] include: - transformers-version: "4.45.*" os: ubuntu-20.04 diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0270e6bc8b5..fc9c2d76665 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -24,7 +24,6 @@ import numpy as np import onnx import onnxruntime -import onnxruntime as ort import pytest import requests import timm @@ -99,7 +98,6 @@ ) from optimum.onnxruntime.base import ORTDecoderForSeq2Seq, ORTEncoder from optimum.onnxruntime.modeling_ort import ORTModel -from optimum.onnxruntime.utils import get_device_for_provider, get_provider_for_device from optimum.pipelines import pipeline from optimum.utils import ( CONFIG_NAME, @@ -5722,23 +5720,3 @@ def test_find_untested_architectures(self, task: str, test_class): f"For the task `{task}`, the ONNX export supports {supported_export_models}, but only {tested_architectures} are tested.\n" f" Missing {untested_architectures}." ) - - -class ProviderAndDeviceGettersTest(unittest.TestCase): - def test_get_device_for_provider(self): - self.assertEqual( - get_device_for_provider("CPUExecutionProvider", provider_options={}), - torch.device("cpu"), - ) - self.assertEqual( - get_device_for_provider("CUDAExecutionProvider", provider_options={"device_id": 1}), - torch.device("cuda:1"), - ) - - def test_get_provider_for_device(self): - self.assertEqual(get_provider_for_device(torch.device("cpu")), "CPUExecutionProvider") - - if "ROCMExecutionProvider" in ort.get_available_providers(): - self.assertEqual(get_provider_for_device(torch.device("cuda")), "ROCMExecutionProvider") - else: - self.assertEqual(get_provider_for_device(torch.device("cuda")), "CUDAExecutionProvider") diff --git a/tests/onnxruntime/test_utils.py b/tests/onnxruntime/test_utils.py new file mode 100644 index 00000000000..2e30851618c --- /dev/null +++ b/tests/onnxruntime/test_utils.py @@ -0,0 +1,35 @@ +import tempfile +import unittest + +import onnxruntime as ort +import torch + +from optimum.onnxruntime.configuration import AutoQuantizationConfig, OptimizationConfig, ORTConfig +from optimum.onnxruntime.utils import get_device_for_provider, get_provider_for_device + + +class ProviderAndDeviceGettersTest(unittest.TestCase): + def test_get_device_for_provider(self): + self.assertEqual(get_device_for_provider("CPUExecutionProvider", provider_options={}), torch.device("cpu")) + self.assertEqual( + get_device_for_provider("CUDAExecutionProvider", provider_options={"device_id": 1}), torch.device("cuda:1") + ) + + def test_get_provider_for_device(self): + self.assertEqual(get_provider_for_device(torch.device("cpu")), "CPUExecutionProvider") + + if "ROCMExecutionProvider" in ort.get_available_providers(): + self.assertEqual(get_provider_for_device(torch.device("cuda")), "ROCMExecutionProvider") + else: + self.assertEqual(get_provider_for_device(torch.device("cuda")), "CUDAExecutionProvider") + + +class ORTConfigTest(unittest.TestCase): + def test_save_and_load(self): + with tempfile.TemporaryDirectory() as tmp_dir: + quantization_config = AutoQuantizationConfig.arm64(is_static=False, per_channel=False) + optimization_config = OptimizationConfig(optimization_level=2) + ort_config = ORTConfig(opset=11, quantization=quantization_config, optimization=optimization_config) + ort_config.save_pretrained(tmp_dir) + loaded_ort_config = ORTConfig.from_pretrained(tmp_dir) + self.assertEqual(ort_config.to_dict(), loaded_ort_config.to_dict())