Skip to content

Commit

Permalink
Fix provider availability check on ORT 1.16.0 release (#1403)
Browse files Browse the repository at this point in the history
fix provider availablity on ORT 1.16.0 release
  • Loading branch information
fxmarty authored and echarlaix committed Sep 21, 2023
1 parent 1446e53 commit 4a1e438
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np
import torch
from packaging import version
from transformers.utils import logging

import onnxruntime as ort
Expand Down Expand Up @@ -218,8 +219,13 @@ def validate_provider_availability(provider: str):
Args:
provider (str): Name of an ONNX Runtime execution provider.
"""
# disable on Windows as reported in https://github.com/huggingface/optimum/issues/769
if os.name != "nt" and provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
# Disable on Windows as reported in https://github.com/huggingface/optimum/issues/769.
# Disable as well for ORT 1.16.0 that has changed changed the way _ld_preload.py is filled: https://github.com/huggingface/optimum/issues/1402.
if (
version.parse(ort.__version__) < version.parse("1.16.0")
and os.name != "nt"
and provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]
):
path_cuda_lib = os.path.join(ort.__path__[0], "capi", "libonnxruntime_providers_cuda.so")
path_trt_lib = os.path.join(ort.__path__[0], "capi", "libonnxruntime_providers_tensorrt.so")
path_dependecy_loading = os.path.join(ort.__path__[0], "capi", "_ld_preload.py")
Expand Down

0 comments on commit 4a1e438

Please sign in to comment.