add support for onnxruntime-rocm to the onnx backend (LeelaChessZero#1897)

* first attempt at onnx-rocm backend

* support local onnxruntime builds

* rocm needs locking
borg323 authored Jul 30, 2023
1 parent 4e314e5 commit 4e2baee
Showing 2 changed files with 40 additions and 16 deletions.
meson.build (26 changes: 15 additions, 11 deletions)
@@ -576,19 +576,23 @@ if get_option('build_backends')
   ## ONNX
   ## ~~~~~~~~~~
   if get_option('onnx_libdir') != '' and get_option('onnx_include') != ''
-    onnx_lib = cc.find_library('onnxruntime',
-                               dirs: get_option('onnx_libdir'),
-                               required: true)
+    deps += cc.find_library('onnxruntime', dirs: get_option('onnx_libdir'),
+                            required: true)
     includes += include_directories(get_option('onnx_include'), is_system: true)
-    cc.has_header('onnxruntime_cxx_api.h',
-                  required: true,
+    cc.has_header('onnxruntime_cxx_api.h', required: true,
                   args: '-I' + get_option('onnx_include'))
-    deps += [onnx_lib]
-
-    files += [
-      'src/neural/onnx/network_onnx.cc',
-    ]
-
+    if not cc.has_header('cpu_provider_factory.h',
+                         args: '-I' + get_option('onnx_include'))
+      cc.has_header('../providers/cpu/cpu_provider_factory.h', required: true,
+                    args: '-I' + get_option('onnx_include'))
+      includes += include_directories(get_option('onnx_include') + '/../providers/cpu',
+                                      is_system: true)
+    endif
+    files += 'src/neural/onnx/network_onnx.cc'
+    if cc.find_library('onnxruntime_providers_rocm',
+                       dirs: get_option('onnx_libdir'), required: false).found()
+      add_project_arguments('-DUSE_ROCM', language : 'cpp')
+    endif
     has_backends = true
   endif

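With these build changes, a local onnxruntime build (for example one compiled with ROCm support) can be used by pointing meson at it, e.g. meson setup build -Donnx_libdir=/path/to/onnxruntime/lib -Donnx_include=/path/to/onnxruntime/include (paths are illustrative). If libonnxruntime_providers_rocm is found in onnx_libdir, the backend is additionally compiled with -DUSE_ROCM, which enables the onnx-rocm network registration shown below.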
src/neural/onnx/network_onnx.cc (30 changes: 25 additions, 5 deletions)
@@ -52,7 +52,7 @@
 namespace lczero {
 namespace {
 
-enum class OnnxProvider { CPU, CUDA, DML };
+enum class OnnxProvider { CPU, CUDA, DML, ROCM };
 
 class OnnxNetwork;

@@ -240,20 +240,29 @@ void OnnxComputation<DataType>::ComputeBlocking() {
     int batch = batch_size * step;
 
     auto input_tensor = PrepareInputs(i, batch);
-    if (network_->provider_ == OnnxProvider::DML) network_->lock_.lock();
+    // The DML onnxruntime execution provider is documented as not supporting
+    // multi-threaded calls to Run on the same inference session. We found the
+    // same to be true for the ROCm execution provider (at least for CNNs).
+    // TODO: This may be an onnxruntime/ROCm bug, check onnxruntime 1.16 release.
+    if (network_->provider_ == OnnxProvider::DML ||
+        network_->provider_ == OnnxProvider::ROCM) {
+      network_->lock_.lock();
+    }
     network_->session_[step - 1].Run(
         {}, network_->inputs_cstr_.data(), &input_tensor, 1,
         network_->outputs_cstr_.data(), output_tensors_.data(),
         output_tensors_.size());
-    if (network_->provider_ == OnnxProvider::DML) network_->lock_.unlock();
+    if (network_->provider_ == OnnxProvider::DML ||
+        network_->provider_ == OnnxProvider::ROCM) {
+      network_->lock_.unlock();
+    }
     i += batch;
   }
 }
 
 Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int threads,
                                int batch_size) {
   Ort::SessionOptions options;
-  OrtCUDAProviderOptions cuda_options;
   options.SetIntraOpNumThreads(threads);
   options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

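A side note on the locking pattern above: the bare lock()/unlock() pair leaves the mutex held if Run() throws. A scoped-lock sketch of the same serialization idea (SerializedRun and its parameter names are illustrative, not lc0 code):

    // Serialize Ort::Session::Run() for execution providers (DML, ROCm)
    // that cannot handle concurrent Run() calls on one inference session.
    #include <cstddef>
    #include <mutex>
    #include <vector>

    #include "onnxruntime_cxx_api.h"

    std::vector<Ort::Value> SerializedRun(
        Ort::Session& session, std::mutex& mutex, bool needs_lock,
        const char* const* input_names, const Ort::Value* inputs,
        size_t input_count, const char* const* output_names,
        size_t output_count) {
      // defer_lock constructs the guard without acquiring; lock only when
      // the provider requires serialized access.
      std::unique_lock<std::mutex> lock(mutex, std::defer_lock);
      if (needs_lock) lock.lock();
      // If taken, the lock is released when `lock` leaves scope, even if
      // Run() throws.
      return session.Run(Ort::RunOptions{}, input_names, inputs, input_count,
                         output_names, output_count);
    }

Behavior matches the diff's manual locking, but stays exception-safe.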
@@ -276,10 +285,18 @@ Ort::SessionOptions GetOptions(OnnxProvider provider, int gpu, int threads,
       throw Exception("ONNX backend internal error.");
 #endif
       break;
-    case OnnxProvider::CUDA:
+    case OnnxProvider::ROCM: {
+      OrtROCMProviderOptions rocm_options;
+      rocm_options.device_id = gpu;
+      options.AppendExecutionProvider_ROCM(rocm_options);
+      break;
+    }
+    case OnnxProvider::CUDA: {
+      OrtCUDAProviderOptions cuda_options;
       cuda_options.device_id = gpu;
       options.AppendExecutionProvider_CUDA(cuda_options);
       break;
+    }
     case OnnxProvider::CPU:
       auto status = OrtSessionOptionsAppendExecutionProvider_CPU(options, 0);
       if (status) {
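For context, the ROCm branch above in isolation; a minimal sketch assuming a placeholder model file, and that the OrtROCMProviderOptions C++ default constructor fills in sane defaults:

    // Standalone sketch of the ROCm execution-provider setup used in
    // GetOptions(); "model.onnx" and the env name are placeholders.
    #include "onnxruntime_cxx_api.h"

    int main() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "onnx-rocm-example");
      Ort::SessionOptions options;
      OrtROCMProviderOptions rocm_options;  // assumed default-initialized
      rocm_options.device_id = 0;           // first visible ROCm GPU
      options.AppendExecutionProvider_ROCM(rocm_options);
      // Fails here if onnxruntime was built without ROCm support or the
      // provider library cannot be loaded at runtime.
      Ort::Session session(env, "model.onnx", options);
      return 0;
    }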
@@ -426,6 +443,9 @@ std::unique_ptr<Network> MakeOnnxNetwork(const std::optional<WeightsFile>& w,
   }
 }
 
+#ifdef USE_ROCM
+REGISTER_NETWORK("onnx-rocm", MakeOnnxNetwork<OnnxProvider::ROCM>, 64)
+#endif
 #ifdef USE_DML
 REGISTER_NETWORK("onnx-dml", MakeOnnxNetwork<OnnxProvider::DML>, 63)
 #endif
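Once built with USE_ROCM, the new backend is registered as onnx-rocm with priority 64, one above onnx-dml (63), so it appears to take precedence in automatic backend selection when both are compiled in; presumably it can also be requested explicitly in the usual way, e.g. lc0 --backend=onnx-rocm.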
