Skip to content

Commit

Permalink
Add support for registring custom op libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
marshallpierce committed Mar 9, 2021
1 parent 1871703 commit 668a0d3
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 1 deletion.
9 changes: 9 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
*
!/Cargo.*
!/onnxruntime/Cargo.toml
!/onnxruntime/src
!/onnxruntime/tests
!/onnxruntime-sys/Cargo.toml
!/onnxruntime-sys/build.rs
!/onnxruntime-sys/src
!/test-models/tensorflow/*.onnx
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Add `String` datatype ([#58](https://github.com/nbigaouette/onnxruntime-rs/pull/58))
- Support custom operator libraries

## [0.0.11] - 2021-02-22

Expand Down
118 changes: 118 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# onnxruntime requires execinfo.h to build, which only works on glibc-based systems, so alpine is out...
FROM debian:bullseye-slim as base

RUN apt-get update && apt-get -y dist-upgrade

FROM base AS onnxruntime

RUN apt-get install -y \
git \
bash \
python3 \
cmake \
git \
build-essential \
llvm \
locales

# onnxruntime built in tests need en_US.UTF-8 available
# Uncomment en_US.UTF-8, then generate
RUN sed -i 's/^# *\(en_US.UTF-8\)/\1/' /etc/locale.gen && locale-gen

# build onnxruntime
RUN mkdir -p /opt/onnxruntime/tmp
# onnxruntime build relies on being in a git repo, so can't just get a tarball
# it's a big repo, so fetch shallowly
RUN cd /opt/onnxruntime/tmp && \
git clone --recursive --depth 1 --shallow-submodules https://github.com/Microsoft/onnxruntime

# use version that onnxruntime-sys expects
RUN cd /opt/onnxruntime/tmp/onnxruntime && \
git fetch --depth 1 origin tag v1.6.0 && \
git checkout v1.6.0

RUN /opt/onnxruntime/tmp/onnxruntime/build.sh --config RelWithDebInfo --build_shared_lib --parallel

# Build ort-customops, linked against the onnxruntime built above.
# No tags / releases yet - that commit is from 2021-02-16
RUN mkdir -p /opt/ort-customops/tmp && \
cd /opt/ort-customops/tmp && \
git clone --recursive https://github.com/microsoft/ort-customops.git && \
cd ort-customops && \
git checkout 92f6b51106c9e9143c452e537cb5e41d2dcaa266

RUN cd /opt/ort-customops/tmp/ort-customops && \
./build.sh -D ONNXRUNTIME_LIB_DIR=/opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo


# install rust toolchain
FROM base AS rust-toolchain

ARG RUST_VERSION=1.50.0

RUN apt-get install -y \
curl

# install rust toolchain
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain $RUST_VERSION

ENV PATH $PATH:/root/.cargo/bin


# build onnxruntime-rs
FROM rust-toolchain as onnxruntime-rs
# clang & llvm needed by onnxruntime-sys
RUN apt-get install -y \
build-essential \
llvm-dev \
libclang-dev \
clang

RUN mkdir -p \
/onnxruntime-rs/build/onnxruntime-sys/src/ \
/onnxruntime-rs/build/onnxruntime/src/ \
/onnxruntime-rs/build/onnxruntime/tests/ \
/opt/onnxruntime/lib \
/opt/ort-customops/lib

COPY --from=onnxruntime /opt/onnxruntime/tmp/onnxruntime/build/Linux/RelWithDebInfo/libonnxruntime.so /opt/onnxruntime/lib/
COPY --from=onnxruntime /opt/ort-customops/tmp/ort-customops/out/Linux/libortcustomops.so /opt/ort-customops/lib/

WORKDIR /onnxruntime-rs/build

ENV ORT_STRATEGY=system
# this has /lib/ appended to it and is used as a lib search path in onnxruntime-sys's build.rs
ENV ORT_LIB_LOCATION=/opt/onnxruntime/

ENV ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB=/opt/ort-customops/lib/libortcustomops.so

# create enough of an empty project that dependencies can build
COPY /Cargo.lock /Cargo.toml /onnxruntime-rs/build/
COPY /onnxruntime/Cargo.toml /onnxruntime-rs/build/onnxruntime/
COPY /onnxruntime-sys/Cargo.toml /onnxruntime-sys/build.rs /onnxruntime-rs/build/onnxruntime-sys/

CMD cargo test

# build dependencies and clean the bogus contents of our two packages
RUN touch \
onnxruntime/src/lib.rs \
onnxruntime/tests/integration_tests.rs \
onnxruntime-sys/src/lib.rs \
&& cargo build --tests \
&& cargo clean --package onnxruntime-sys \
&& cargo clean --package onnxruntime \
&& rm -rf \
onnxruntime/src/ \
onnxruntime/tests/ \
onnxruntime-sys/src/

# now build the actual source
COPY /test-models test-models
COPY /onnxruntime-sys/src onnxruntime-sys/src
COPY /onnxruntime/src onnxruntime/src
COPY /onnxruntime/tests onnxruntime/tests

RUN ln -s /opt/onnxruntime/lib/libonnxruntime.so /opt/onnxruntime/lib/libonnxruntime.so.1.6.0
ENV LD_LIBRARY_PATH=/opt/onnxruntime/lib

RUN cargo build --tests
6 changes: 6 additions & 0 deletions onnxruntime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ ndarray = "0.13"
thiserror = "1.0"
tracing = "0.1"

[target.'cfg(unix)'.dependencies]
libc = "0.2.88"

[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3.9", features = ["std"] }

# Enabled with 'model-fetching' feature
ureq = {version = "1.5.1", optional = true}

Expand Down
55 changes: 54 additions & 1 deletion onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryInto as _, ffi, ffi::CString, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -64,11 +64,16 @@ pub struct SessionBuilder<'a> {

allocator: AllocatorType,
memory_type: MemType,
custom_runtime_handles: Vec<*mut ::std::os::raw::c_void>,
}

impl<'a> Drop for SessionBuilder<'a> {
#[tracing::instrument]
fn drop(&mut self) {
for &handle in self.custom_runtime_handles.iter() {
close_lib_handle(handle);
}

debug!("Dropping the session options.");
assert_ne!(self.session_options_ptr, std::ptr::null_mut());
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
Expand All @@ -89,6 +94,7 @@ impl<'a> SessionBuilder<'a> {
session_options_ptr,
allocator: AllocatorType::Arena,
memory_type: MemType::Default,
custom_runtime_handles: Vec::new(),
})
}

Expand Down Expand Up @@ -136,6 +142,39 @@ impl<'a> SessionBuilder<'a> {
Ok(self)
}

/// Registers a custom ops library with the given library path in the session.
pub fn with_custom_op_lib(mut self, lib_path: &str) -> Result<SessionBuilder<'a>> {
let path_cstr = ffi::CString::new(lib_path)?;

let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut();

let status = unsafe {
g_ort().RegisterCustomOpsLibrary.unwrap()(
self.session_options_ptr,
path_cstr.as_ptr(),
&mut handle,
)
};

// per RegisterCustomOpsLibrary docs, release handle if there was an error and the handle
// is non-null
match status_to_result(status).map_err(OrtError::SessionOptions) {
Ok(_) => {}
Err(e) => {
if handle != std::ptr::null_mut() {
// handle was written to, should release it
close_lib_handle(handle);
}

return Err(e);
}
}

self.custom_runtime_handles.push(handle);

Ok(self)
}

/// Download an ONNX pre-trained model from the [ONNX Model Zoo](https://github.com/onnx/models) and commit the session
#[cfg(feature = "model-fetching")]
pub fn with_model_downloaded<M>(self, model: M) -> Result<Session<'a>>
Expand Down Expand Up @@ -619,6 +658,20 @@ where
res
}

#[cfg(unix)]
fn close_lib_handle(handle: *mut ::std::os::raw::c_void) {
unsafe {
libc::dlclose(handle);
}
}

#[cfg(windows)]
fn close_lib_handle(handle: *mut ::std::os::raw::c_void) {
unsafe {
winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE)
};
}

/// This module contains dangerous functions working on raw pointers.
/// Those functions are only to be used from inside the
/// `SessionBuilder::with_model_from_file()` method.
Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/tests/custom_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::error::Error;

use ndarray;
use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor};
use onnxruntime::{environment::Environment, LoggingLevel};

#[test]
fn run_model_with_ort_customops() -> Result<(), Box<dyn Error>> {
let lib_path = match std::env::var("ONNXRUNTIME_RS_TEST_ORT_CUSTOMOPS_LIB") {
Ok(s) => s,
Err(_e) => {
println!("Skipping ort_customops test -- no lib specified");
return Ok(());
}
};

let environment = Environment::builder()
.with_name("test")
.with_log_level(LoggingLevel::Verbose)
.build()?;

let mut session = environment
.new_session_builder()?
.with_custom_op_lib(&lib_path)?
.with_model_from_file("../test-models/tensorflow/regex_model.onnx")?;

//Inputs:
// 0:
// name = input_1:0
// type = String
// dimensions = [None]
// Outputs:
// 0:
// name = Identity:0
// type = String
// dimensions = [None]

let array = ndarray::Array::from(vec![String::from("Hello world!")]);
let input_tensor_values = vec![array];

let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
let strings: OrtOwnedTensor<String, _> = outputs[0].try_extract()?;

// ' ' replaced with '_'
assert_eq!(
&[String::from("Hello_world!")],
strings.view().as_slice().unwrap()
);

Ok(())
}
9 changes: 9 additions & 0 deletions test-models/tensorflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,12 @@ This supports strings, and doesn't require custom operators.
pipenv run python src/unique_model.py
pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11
```

# Model: Regex (uses `ort_customops`)

A TensorFlow model that applies a regex, which requires the onnxruntime custom ops in `ort-customops`.

```
pipenv run python src/regex_model.py
pipenv run python -m tf2onnx.convert --saved-model models/regex_model --output regex_model.onnx --extra_opset ai.onnx.contrib:1
```
19 changes: 19 additions & 0 deletions test-models/tensorflow/regex_model.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
tf2onnx1.9.0:�

input_1:0

pattern__7

rewrite__8
Identity:0)PartitionedCall/model1/StaticRegexReplace"StringRegexReplace:ai.onnx.contribtf2onnx*2_B
rewrite__8*2 B
pattern__7R!converted from models/regex_modelZ
input_1:0


unk__9b

Identity:0

unk__10B B
ai.onnx.contrib
19 changes: 19 additions & 0 deletions test-models/tensorflow/src/regex_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import tensorflow as tf
import numpy as np
import tf2onnx


class RegexModel(tf.keras.Model):

def __init__(self, name='model1', **kwargs):
super(RegexModel, self).__init__(name=name, **kwargs)

def call(self, inputs):
return tf.strings.regex_replace(inputs, " ", "_", replace_global=True)


model1 = RegexModel()

print(model1(tf.constant(["Hello world!"])))

model1.save("models/regex_model")

0 comments on commit 668a0d3

Please sign in to comment.