Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt at progressively feeding the Session to bypass type checking in #69

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions onnxruntime/examples/issue22.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ fn main() {
let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

let outputs: Vec<OrtOwnedTensor<f32, _>> =
session.run(vec![input_ids, attention_mask]).unwrap();
let outputs: Vec<OrtOwnedTensor<f32, _>> = session
.run(vec![input_ids, attention_mask])
.unwrap()
.into_iter()
.map(|dyn_tensor| dyn_tensor.try_extract())
.collect::<Result<_, _>>()
.unwrap();
print!("outputs: {:#?}", outputs);
}
3 changes: 1 addition & 2 deletions onnxruntime/examples/print_structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use std::error::Error;
fn main() -> Result<(), Box<dyn Error>> {
// provide path to .onnx model on disk
let path = std::env::args()
.skip(1)
.next()
.nth(1)
.expect("Must provide an .onnx file as the first arg");

let environment = environment::Environment::builder()
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/examples/sample.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![forbid(unsafe_code)]

use onnxruntime::{
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
LoggingLevel,
environment::Environment,
ndarray::Array,
tensor::{DynOrtTensor, OrtOwnedTensor},
GraphOptimizationLevel, LoggingLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
Expand Down Expand Up @@ -61,11 +63,12 @@ fn run() -> Result<(), Error> {
.unwrap();
let input_tensor_values = vec![array];

let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;

assert_eq!(outputs[0].shape(), output0_shape.as_slice());
let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
assert_eq!(output.view().shape(), output0_shape.as_slice());
for i in 0..5 {
println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]);
}

Ok(())
Expand Down
24 changes: 18 additions & 6 deletions onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing error definitions.

use std::{io, path::PathBuf};
use std::{io, path::PathBuf, string};

use thiserror::Error;

Expand Down Expand Up @@ -53,6 +53,12 @@ pub enum OrtError {
/// Error occurred when getting ONNX dimensions
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
/// Error occurred when getting string length
#[error("Failed to get string tensor length: {0}")]
GetStringTensorDataLength(OrtApiError),
/// Error occurred when getting tensor element count
#[error("Failed to get tensor element count: {0}")]
GetTensorShapeElementCount(OrtApiError),
/// Error occurred when creating CPU memory information
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
Expand All @@ -77,6 +83,12 @@ pub enum OrtError {
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
/// Error occurred when extracting string data from an ONNX tensor
#[error("Failed to get tensor string data: {0}")]
GetStringTensorContent(OrtApiError),
/// Error occurred when converting data to a String
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),

/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
#[error("Failed to download ONNX model: {0}")]
Expand Down Expand Up @@ -108,16 +120,16 @@ pub enum OrtError {
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match number of inputs from inference call
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")]
#[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?}")]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>,
// Input dimensions used by inference call
// inference_input: Vec<Vec<usize>>,
// Input dimensions defined in model
// model_input: Vec<Vec<Option<u32>>>,
},
}

Expand Down
178 changes: 18 additions & 160 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ to download.
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! // Multiple inputs and outputs are possible
//! let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)?
//! .into_iter()
//! .map(|dyn_tensor| dyn_tensor.try_extract())
//! .collect::<Result<_, _>>()?;
//! # Ok(())
//! # }
//! ```
Expand All @@ -115,7 +118,10 @@ to download.
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs)
//! example for more details.

use std::sync::{atomic::AtomicPtr, Arc, Mutex};
use std::{
ffi, ptr,
sync::{atomic::AtomicPtr, Arc, Mutex},
};

use lazy_static::lazy_static;

Expand All @@ -142,7 +148,7 @@ lazy_static! {
// } as *mut sys::OrtApi)));
static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
assert_ne!(base, std::ptr::null());
assert_ne!(base, ptr::null());
let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi =
unsafe { (*base).GetApi.unwrap() };
let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
Expand All @@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi {
let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;

assert_ne!(api_ptr_mut, std::ptr::null_mut());
assert_ne!(api_ptr_mut, ptr::null_mut());

unsafe { *api_ptr_mut }
}

fn char_p_to_string(raw: *const i8) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() };

match c_string.into_string() {
Ok(string) => Ok(string),
Expand All @@ -176,7 +182,7 @@ mod onnxruntime {
//! Module containing a custom logger, used to catch the runtime's own logging and send it
//! to Rust's tracing logging instead.

use std::ffi::CStr;
use std::{ffi, ffi::CStr, ptr};
use tracing::{debug, error, info, span, trace, warn, Level};

use onnxruntime_sys as sys;
Expand Down Expand Up @@ -212,7 +218,7 @@ mod onnxruntime {

/// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate.
pub(crate) extern "C" fn custom_logger(
_params: *mut std::ffi::c_void,
_params: *mut ffi::c_void,
severity: sys::OrtLoggingLevel,
category: *const i8,
logid: *const i8,
Expand All @@ -227,16 +233,16 @@ mod onnxruntime {
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
};

assert_ne!(category, std::ptr::null());
assert_ne!(category, ptr::null());
let category = unsafe { CStr::from_ptr(category) };
assert_ne!(code_location, std::ptr::null());
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }
.to_str()
.unwrap_or("unknown");
assert_ne!(message, std::ptr::null());
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) };

assert_ne!(logid, std::ptr::null());
assert_ne!(logid, ptr::null());
let logid = unsafe { CStr::from_ptr(logid) };

// Parse the code location
Expand Down Expand Up @@ -322,154 +328,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel {
}
}

// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
// FIXME: Add tests to cover the commented out types
/// Enum mapping ONNX Runtime's supported tensor types
// NOTE(review): the repr mirrors the underlying C enum, whose integer type
// differs by platform (i32 on Windows, u32 elsewhere); `OnnxEnumInt` is
// presumably the matching platform-dependent alias — TODO confirm.
// Commented-out variants exist in the C API but are not yet supported here.
#[derive(Debug)]
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
/// 32-bit floating point, equivalent to Rust's `f32`
Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
/// Unsigned 8-bit int, equivalent to Rust's `u8`
Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
/// Signed 8-bit int, equivalent to Rust's `i8`
Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
/// Unsigned 16-bit int, equivalent to Rust's `u16`
Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
/// Signed 16-bit int, equivalent to Rust's `i16`
Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
/// Signed 32-bit int, equivalent to Rust's `i32`
Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
/// Signed 64-bit int, equivalent to Rust's `i64`
Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
/// String, equivalent to Rust's `String`
String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
// /// Boolean, equivalent to Rust's `bool`
// Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
// /// 16-bit floating point, equivalent to Rust's `f16`
// Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
/// 64-bit floating point, equivalent to Rust's `f64`
Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
/// Unsigned 32-bit int, equivalent to Rust's `u32`
Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
/// Unsigned 64-bit int, equivalent to Rust's `u64`
Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
// /// Complex 64-bit floating point, equivalent to Rust's `???`
// Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt,
// /// Complex 128-bit floating point, equivalent to Rust's `???`
// Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt,
// /// Brain 16-bit floating point
// Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt,
}

// Implement `From` rather than a hand-written `Into` (clippy::from_over_into):
// the std blanket impl `impl<T, U: From<T>> Into<U> for T` still provides
// `TensorElementDataType: Into<sys::ONNXTensorElementDataType>`, so every
// existing `.into()` call site keeps compiling unchanged.
impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
/// Map the crate-level enum variant onto the raw ONNX Runtime C enum value.
/// Commented-out arms mirror the variants not yet supported by the enum.
fn from(val: TensorElementDataType) -> Self {
use TensorElementDataType::*;
match val {
Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
// Bool => {
//     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
// }
// Float16 => {
//     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
// }
Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
// Complex64 => {
//     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
// }
// Complex128 => {
//     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
// }
// Bfloat16 => {
//     sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
// }
}
}
}

/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
// NOTE(review): `try_utf8_bytes` takes `&self` (a value), while
// `tensor_element_data_type` is an associated function on the type —
// presumably so string tensors can expose per-element bytes; confirm at call sites.
pub trait TypeToTensorElementDataType {
/// Return the ONNX type for a Rust type
fn tensor_element_data_type() -> TensorElementDataType;

/// If the type is `String`, returns `Some` with utf8 contents, else `None`.
fn try_utf8_bytes(&self) -> Option<&[u8]>;
}

// Generates a `TypeToTensorElementDataType` impl for a primitive numeric type:
// `$type_` maps to `TensorElementDataType::$variant`, and `try_utf8_bytes`
// always returns `None` because numeric tensors carry no string payload.
macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl TypeToTensorElementDataType for $type_ {
fn tensor_element_data_type() -> TensorElementDataType {
// unsafe { std::mem::transmute(TensorElementDataType::$variant) }
TensorElementDataType::$variant
}

fn try_utf8_bytes(&self) -> Option<&[u8]> {
None
}
}
};
}

// One impl per supported primitive; commented lines track the enum's
// commented-out (unsupported) variants.
impl_type_trait!(f32, Float);
impl_type_trait!(u8, Uint8);
impl_type_trait!(i8, Int8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// impl_type_trait!(bool, Bool);
// impl_type_trait!(f16, Float16);
impl_type_trait!(f64, Double);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);

/// Adapter for common Rust string types to Onnx strings.
///
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it
/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric
/// types (which might implement `AsRef<str>` at some point in the future).
pub trait Utf8Data {
/// Returns the utf8 contents.
fn utf8_bytes(&self) -> &[u8];
}

/// `String` already stores valid UTF-8; expose its bytes directly.
impl Utf8Data for String {
fn utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}

/// Borrowed string slices are also valid UTF-8 by construction.
impl<'a> Utf8Data for &'a str {
fn utf8_bytes(&self) -> &[u8] {
self.as_bytes()
}
}

/// Blanket impl: any `Utf8Data` type is a string tensor element.
/// This cannot conflict with the numeric impls above because none of the
/// primitive types implement `Utf8Data` (see the trait's doc comment).
impl<T: Utf8Data> TypeToTensorElementDataType for T {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}

fn try_utf8_bytes(&self) -> Option<&[u8]> {
Some(self.utf8_bytes())
}
}

/// Allocator type
#[derive(Debug, Clone)]
#[repr(i32)]
Expand Down Expand Up @@ -524,7 +382,7 @@ mod test {

#[test]
fn test_char_p_to_string() {
let s = std::ffi::CString::new("foo").unwrap();
let s = ffi::CString::new("foo").unwrap();
let ptr = s.as_c_str().as_ptr();
assert_eq!("foo", char_p_to_string(ptr).unwrap());
}
Expand Down
Loading