feat: add directml support #1153

Merged · 7 commits · Jul 22, 2024
22 changes: 22 additions & 0 deletions CMakeLists.txt
@@ -30,6 +30,7 @@ option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI interface" OFF)
option(SHERPA_ONNX_ENABLE_C_API "Whether to build C API" ON)
option(SHERPA_ONNX_ENABLE_WEBSOCKET "Whether to build websocket server/client" ON)
option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF)
option(SHERPA_ONNX_ENABLE_DIRECTML "Enable ONNX Runtime DirectML support" OFF)
option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF)
option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF)
option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF)
@@ -94,6 +95,19 @@ to install CUDA toolkit if you have not installed it.")
endif()
endif()

if(SHERPA_ONNX_ENABLE_DIRECTML)
message(WARNING "\
Compiling with DirectML enabled. Please make sure Windows 10 SDK
is installed on your system. Otherwise, you will get errors at runtime.
Please refer to
https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html#requirements
to install Windows 10 SDK if you have not installed it.")
if(NOT BUILD_SHARED_LIBS)
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_DIRECTML is ON")
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
endif()
endif()

# see https://cmake.org/cmake/help/latest/prop_tgt/MSVC_RUNTIME_LIBRARY.html
# https://stackoverflow.com/questions/14172856/compile-with-mt-instead-of-md-using-cmake
if(MSVC)
@@ -160,6 +174,14 @@ else()
add_definitions(-DSHERPA_ONNX_ENABLE_TTS=0)
endif()

if(SHERPA_ONNX_ENABLE_DIRECTML)
message(STATUS "DirectML is enabled")
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
message(WARNING "DirectML is disabled")
add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()

if(SHERPA_ONNX_ENABLE_WASM_TTS)
if(NOT SHERPA_ONNX_ENABLE_TTS)
message(FATAL_ERROR "Please set SHERPA_ONNX_ENABLE_TTS to ON if you want to build wasm TTS")
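The SHERPA_ONNX_ENABLE_DIRECTML definition added above is consumed at compile time by the C++ sources. A minimal sketch of the gating pattern (MaybeUseDirectML is a hypothetical function for illustration; session.cc later in this diff uses the same guard):

// SHERPA_ONNX_ENABLE_DIRECTML is always defined (to 0 or 1) by CMakeLists.txt,
// so it can be compared directly rather than tested with #ifdef.
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
#include "dml_provider_factory.h"  // DirectML-specific onnxruntime header
#endif

void MaybeUseDirectML() {  // hypothetical, for illustration only
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
  // Windows build configured with -DSHERPA_ONNX_ENABLE_DIRECTML=ON:
  // the DirectML code path is compiled in.
#else
  // All other builds compile the portable fallback.
#endif
}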
161 changes: 161 additions & 0 deletions cmake/onnxruntime-win-x64-directml.cmake
@@ -0,0 +1,161 @@
# Copyright (c) 2022-2023 Xiaomi Corporation
message(STATUS "CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}")
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
message(STATUS "CMAKE_VS_PLATFORM_NAME: ${CMAKE_VS_PLATFORM_NAME}")

if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)
message(FATAL_ERROR "This file is for Windows only. Given: ${CMAKE_SYSTEM_NAME}")
endif()

if(NOT (CMAKE_VS_PLATFORM_NAME STREQUAL X64 OR CMAKE_VS_PLATFORM_NAME STREQUAL x64))
message(FATAL_ERROR "This file is for Windows x64 only. Given: ${CMAKE_VS_PLATFORM_NAME}")
endif()

if(NOT BUILD_SHARED_LIBS)
message(FATAL_ERROR "This file is for building shared libraries. BUILD_SHARED_LIBS: ${BUILD_SHARED_LIBS}")
endif()

if(NOT SHERPA_ONNX_ENABLE_DIRECTML)
message(FATAL_ERROR "This file is for DirectML. Given SHERPA_ONNX_ENABLE_DIRECTML: ${SHERPA_ONNX_ENABLE_DIRECTML}")
endif()

set(onnxruntime_URL "https://globalcdn.nuget.org/packages/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/microsoft.ml.onnxruntime.directml.1.14.1.nupkg")
set(onnxruntime_HASH "SHA256=c8ae7623385b19cd5de968d0df5383e13b97d1b3a6771c9177eac15b56013a5a")

# If you don't have access to the Internet,
# please download onnxruntime to one of the following locations.
# You can add more if you want.
set(possible_file_locations
$ENV{HOME}/Downloads/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
${PROJECT_SOURCE_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
${PROJECT_BINARY_DIR}/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
/tmp/microsoft.ml.onnxruntime.directml.1.14.1.nupkg
)

foreach(f IN LISTS possible_file_locations)
if(EXISTS ${f})
set(onnxruntime_URL "${f}")
file(TO_CMAKE_PATH "${onnxruntime_URL}" onnxruntime_URL)
message(STATUS "Found local downloaded onnxruntime: ${onnxruntime_URL}")
set(onnxruntime_URL2)
break()
endif()
endforeach()

FetchContent_Declare(onnxruntime
URL
${onnxruntime_URL}
${onnxruntime_URL2}
URL_HASH ${onnxruntime_HASH}
)

FetchContent_GetProperties(onnxruntime)
if(NOT onnxruntime_POPULATED)
message(STATUS "Downloading onnxruntime from ${onnxruntime_URL}")
FetchContent_Populate(onnxruntime)
endif()
message(STATUS "onnxruntime is downloaded to ${onnxruntime_SOURCE_DIR}")

find_library(location_onnxruntime onnxruntime
PATHS
"${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native"
NO_CMAKE_SYSTEM_PATH
)

message(STATUS "location_onnxruntime: ${location_onnxruntime}")

add_library(onnxruntime SHARED IMPORTED)

set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION ${location_onnxruntime}
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/build/native/include"
)

set_property(TARGET onnxruntime
PROPERTY
IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.lib"
)

file(COPY ${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.dll
DESTINATION
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
)

file(GLOB onnxruntime_lib_files "${onnxruntime_SOURCE_DIR}/runtimes/win-x64/native/onnxruntime.*")

message(STATUS "onnxruntime lib files: ${onnxruntime_lib_files}")

if(SHERPA_ONNX_ENABLE_PYTHON)
install(FILES ${onnxruntime_lib_files} DESTINATION ..)
else()
install(FILES ${onnxruntime_lib_files} DESTINATION lib)
endif()

install(FILES ${onnxruntime_lib_files} DESTINATION bin)

# Setup DirectML

set(directml_URL "https://www.nuget.org/api/v2/package/Microsoft.AI.DirectML/1.15.0")
set(directml_HASH "SHA256=10d175f8e97447712b3680e3ac020bbb8eafdf651332b48f09ffee2eec801c23")

set(possible_directml_file_locations
$ENV{HOME}/Downloads/Microsoft.AI.DirectML.1.15.0.nupkg
${PROJECT_SOURCE_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
${PROJECT_BINARY_DIR}/Microsoft.AI.DirectML.1.15.0.nupkg
/tmp/Microsoft.AI.DirectML.1.15.0.nupkg
)

foreach(f IN LISTS possible_directml_file_locations)
if(EXISTS ${f})
set(directml_URL "${f}")
file(TO_CMAKE_PATH "${directml_URL}" directml_URL)
message(STATUS "Found local downloaded DirectML: ${directml_URL}")
break()
endif()
endforeach()

FetchContent_Declare(directml
URL
${directml_URL}
URL_HASH ${directml_HASH}
)

FetchContent_GetProperties(directml)
if(NOT directml_POPULATED)
message(STATUS "Downloading DirectML from ${directml_URL}")
FetchContent_Populate(directml)
endif()
message(STATUS "DirectML is downloaded to ${directml_SOURCE_DIR}")

find_library(location_directml DirectML
PATHS
"${directml_SOURCE_DIR}/bin/x64-win"
NO_CMAKE_SYSTEM_PATH
)

message(STATUS "location_directml: ${location_directml}")

add_library(directml SHARED IMPORTED)

set_target_properties(directml PROPERTIES
IMPORTED_LOCATION ${location_directml}
INTERFACE_INCLUDE_DIRECTORIES "${directml_SOURCE_DIR}/bin/x64-win"
)

set_property(TARGET directml
PROPERTY
IMPORTED_IMPLIB "${directml_SOURCE_DIR}/bin/x64-win/DirectML.lib"
)

file(COPY ${directml_SOURCE_DIR}/bin/x64-win/DirectML.dll
DESTINATION
${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE}
)

file(GLOB directml_lib_files "${directml_SOURCE_DIR}/bin/x64-win/DirectML.*")

message(STATUS "DirectML lib files: ${directml_lib_files}")

install(FILES ${directml_lib_files} DESTINATION lib)
install(FILES ${directml_lib_files} DESTINATION bin)
5 changes: 4 additions & 1 deletion cmake/onnxruntime.cmake
@@ -95,7 +95,10 @@ function(download_onnxruntime)
include(onnxruntime-win-arm64)
else()
# for 64-bit windows (x64)
if(BUILD_SHARED_LIBS)
if(SHERPA_ONNX_ENABLE_DIRECTML)
message(STATUS "Use DirectML")
include(onnxruntime-win-x64-directml)
elseif(BUILD_SHARED_LIBS)
message(STATUS "Use dynamic onnxruntime libraries")
if(SHERPA_ONNX_ENABLE_GPU)
include(onnxruntime-win-x64-gpu)
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/provider.cc
@@ -26,6 +26,8 @@ Provider StringToProvider(std::string s) {
return Provider::kNNAPI;
} else if (s == "trt") {
return Provider::kTRT;
} else if (s == "directml") {
return Provider::kDirectML;
} else {
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
return Provider::kCPU;
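With this mapping in place, DirectML is requested the same way as any other provider: by passing the string "directml" wherever a provider name is accepted. A minimal sketch of the new behavior, including the CPU fallback for unknown strings shown above (the caller itself is hypothetical; StringToProvider and Provider are the real symbols from sherpa-onnx/csrc/provider.h in this diff):

// Hypothetical caller exercising the new string-to-provider mapping.
#include <cassert>

#include "sherpa-onnx/csrc/provider.h"

int main() {
  using sherpa_onnx::Provider;
  using sherpa_onnx::StringToProvider;

  // The new branch added in this PR.
  assert(StringToProvider("directml") == Provider::kDirectML);

  // Unknown strings log "Unsupported string ... Fallback to cpu" and
  // return the CPU provider.
  assert(StringToProvider("no-such-provider") == Provider::kCPU);
  return 0;
}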
13 changes: 7 additions & 6 deletions sherpa-onnx/csrc/provider.h
@@ -14,12 +14,13 @@ namespace sherpa_onnx {
// https://github.com/microsoft/onnxruntime/blob/main/java/src/main/java/ai/onnxruntime/OrtProvider.java
// for a list of available providers
enum class Provider {
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
kNNAPI = 4, // NnapiExecutionProvider
kTRT = 5, // TensorRTExecutionProvider
kCPU = 0, // CPUExecutionProvider
kCUDA = 1, // CUDAExecutionProvider
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
kNNAPI = 4, // NnapiExecutionProvider
kTRT = 5, // TensorRTExecutionProvider
kDirectML = 6, // DmlExecutionProvider
};

/**
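Each enum comment names the corresponding onnxruntime execution provider. A hypothetical helper showing the mapping those comments describe (ToOrtProviderName is illustrative only; sherpa-onnx does not ship it):

// Maps the sherpa-onnx Provider enum to the onnxruntime execution-provider
// names listed in the enum comments above. Illustrative only.
const char *ToOrtProviderName(Provider p) {
  switch (p) {
    case Provider::kCPU:      return "CPUExecutionProvider";
    case Provider::kCUDA:     return "CUDAExecutionProvider";
    case Provider::kCoreML:   return "CoreMLExecutionProvider";
    case Provider::kXnnpack:  return "XnnpackExecutionProvider";
    case Provider::kNNAPI:    return "NnapiExecutionProvider";
    case Provider::kTRT:      return "TensorRTExecutionProvider";
    case Provider::kDirectML: return "DmlExecutionProvider";
  }
  return "CPUExecutionProvider";  // unreachable with a valid enum value
}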
22 changes: 22 additions & 0 deletions sherpa-onnx/csrc/session.cc
@@ -19,6 +19,10 @@
#include "nnapi_provider_factory.h" // NOLINT
#endif

#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
#include "dml_provider_factory.h" // NOLINT
#endif

namespace sherpa_onnx {

static void OrtStatusFailure(OrtStatus *status, const char *s) {
@@ -167,6 +171,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(
}
break;
}
case Provider::kDirectML: {
#if defined(_WIN32) && SHERPA_ONNX_ENABLE_DIRECTML == 1
sess_opts.DisableMemPattern();
sess_opts.SetExecutionMode(ORT_SEQUENTIAL);
int32_t device_id = 0;
OrtStatus *status =
OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
if (status) {
const auto &api = Ort::GetApi();
const char *msg = api.GetErrorMessage(status);
SHERPA_ONNX_LOGE("Failed to enable DirectML: %s. Fallback to cpu", msg);
api.ReleaseStatus(status);
}
#else
SHERPA_ONNX_LOGE("DirectML is for Windows only. Fallback to cpu!");
#endif
break;
}
case Provider::kCoreML: {
#if defined(__APPLE__)
uint32_t coreml_flags = 0;
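For context, a minimal standalone sketch of the pattern the new kDirectML case implements, from session-options setup through session creation (the device id and model path are placeholders; this assumes an onnxruntime build that ships dml_provider_factory.h, as required by the includes above):

// Minimal sketch: enable the DirectML execution provider on ORT session
// options, then create a session. Mirrors the kDirectML branch in session.cc.
#include <cstdio>

#include <onnxruntime_cxx_api.h>

#include "dml_provider_factory.h"  // OrtSessionOptionsAppendExecutionProvider_DML

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "dml-demo");
  Ort::SessionOptions sess_opts;

  // The DirectML EP does not support memory patterns or parallel execution,
  // hence the same two calls the PR makes before appending the provider.
  sess_opts.DisableMemPattern();
  sess_opts.SetExecutionMode(ORT_SEQUENTIAL);

  int device_id = 0;  // placeholder: first DirectX 12 capable adapter
  OrtStatus *status =
      OrtSessionOptionsAppendExecutionProvider_DML(sess_opts, device_id);
  if (status) {
    // Mirror the PR: report the error, release it, and stay on CPU.
    const auto &api = Ort::GetApi();
    std::fprintf(stderr, "Failed to enable DirectML: %s\n",
                 api.GetErrorMessage(status));
    api.ReleaseStatus(status);
  }

  Ort::Session session(env, ORT_TSTR("model.onnx"), sess_opts);  // placeholder model
  return 0;
}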