diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs
index 3fbc2670..5563ab4b 100644
--- a/onnxruntime/examples/sample.rs
+++ b/onnxruntime/examples/sample.rs
@@ -66,9 +66,9 @@ fn run() -> Result<(), Error> {
     let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
 
     let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
-    assert_eq!(output.shape(), output0_shape.as_slice());
+    assert_eq!(output.view().shape(), output0_shape.as_slice());
     for i in 0..5 {
-        println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
+        println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]);
     }
 
     Ok(())
diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs
index f49613fe..86280f74 100644
--- a/onnxruntime/src/error.rs
+++ b/onnxruntime/src/error.rs
@@ -1,6 +1,6 @@
 //! Module containing error definitions.
 
-use std::{io, path::PathBuf};
+use std::{io, path::PathBuf, string};
 
 use thiserror::Error;
 
@@ -53,6 +53,12 @@ pub enum OrtError {
     /// Error occurred when getting ONNX dimensions
     #[error("Failed to get dimensions: {0}")]
     GetDimensions(OrtApiError),
+    /// Error occurred when getting string length
+    #[error("Failed to get string tensor length: {0}")]
+    GetStringTensorDataLength(OrtApiError),
+    /// Error occurred when getting tensor element count
+    #[error("Failed to get tensor element count: {0}")]
+    GetTensorShapeElementCount(OrtApiError),
     /// Error occurred when creating CPU memory information
     #[error("Failed to get dimensions: {0}")]
     CreateCpuMemoryInfo(OrtApiError),
@@ -77,6 +83,12 @@ pub enum OrtError {
     /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
     #[error("Failed to get tensor data: {0}")]
     GetTensorMutableData(OrtApiError),
+    /// Error occurred when extracting string data from an ONNX tensor
+    #[error("Failed to get tensor string data: {0}")]
+    GetStringTensorContent(OrtApiError),
+    /// Error occurred when converting data to a String
+    #[error("Data was not UTF-8: {0}")]
+    StringFromUtf8Error(#[from] string::FromUtf8Error),
 
     /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
     #[error("Failed to download ONNX model: {0}")]
diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs
index 232d188d..98099fd6 100644
--- a/onnxruntime/src/session.rs
+++ b/onnxruntime/src/session.rs
@@ -1,6 +1,6 @@
 //! Module containing session types
 
-use std::{ffi::CString, fmt::Debug, path::Path};
+use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path};
 
 #[cfg(not(target_family = "windows"))]
 use std::os::unix::ffi::OsStrExt;
@@ -436,7 +436,7 @@ impl<'a> Session<'a> {
         output_tensor_ptrs
             .into_iter()
             .map(|tensor_ptr| {
-                let (dims, data_type) = unsafe {
+                let (dims, data_type, len) = unsafe {
                     call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
                         get_tensor_dimensions(tensor_info_ptr)
                            .map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
                            .and_then(|dims| {
                                extract_data_type(tensor_info_ptr)
                                    .map(|data_type| (dims, data_type))
                            })
+                            .and_then(|(dims, data_type)| {
+                                let mut len = 0_u64;
+
+                                call_ort(|ort| {
+                                    ort.GetTensorShapeElementCount.unwrap()(
+                                        tensor_info_ptr,
+                                        &mut len,
+                                    )
+                                })
+                                .map_err(OrtError::GetTensorShapeElementCount)?;
+
+                                Ok((
+                                    dims,
+                                    data_type,
+                                    len.try_into()
+                                        .expect("u64 length could not fit into usize"),
+                                ))
+                            })
                     })
                 }?;
@@ -451,6 +469,7 @@
                     tensor_ptr,
                     memory_info_ref,
                     ndarray::IxDyn(&dims),
+                    len,
                     data_type,
                 ))
             })
diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs
index df85e1ed..74e8329c 100644
--- a/onnxruntime/src/tensor.rs
+++ b/onnxruntime/src/tensor.rs
@@ -30,9 +30,10 @@ pub mod ort_tensor;
 pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
 pub use ort_tensor::OrtTensor;
 
-use crate::{OrtError, Result};
+use crate::tensor::ort_owned_tensor::TensorPointerHolder;
+use crate::{error::call_ort, OrtError, Result};
 use onnxruntime_sys::{self as sys, OnnxEnumInt};
-use std::{fmt, ptr};
+use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string};
 
 // FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
 // FIXME: Add tests to cover the commented out types
@@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
     fn tensor_element_data_type() -> TensorElementDataType;
 
     /// Extract an `ArrayView` from the ort-owned tensor.
-    fn extract_array<'t, D>(
+    fn extract_data<'t, D>(
         shape: D,
-        tensor: *mut sys::OrtValue,
-    ) -> Result<ndarray::ArrayView<'t, Self, D>>
+        tensor_element_len: usize,
+        tensor_ptr: rc::Rc<TensorPointerHolder>,
+    ) -> Result<TensorData<'t, Self, D>>
     where
         D: ndarray::Dimension;
 }
 
+/// Represents the possible ways tensor data can be accessed.
+///
+/// This should only be used internally.
+#[derive(Debug)]
+pub enum TensorData<'t, T, D>
+where
+    D: ndarray::Dimension,
+{
+    /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid.
+    /// This is used for data types whose in-memory form from ort is compatible with Rust's, like
+    /// primitive numeric types.
+    TensorPtr {
+        /// The pointer ort produced. Kept alive so that `array_view` is valid.
+        ptr: rc::Rc<TensorPointerHolder>,
+        /// A view into `ptr`
+        array_view: ndarray::ArrayView<'t, T, D>,
+    },
+    /// String data is output differently by ort, and of course is also variable size, so it cannot
+    /// use the same simple pointer representation.
+    // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done.
+    Strings {
+        /// Owned Strings copied out of ort's output
+        strings: ndarray::Array<T, D>,
+    },
+}
+
 /// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData`
 macro_rules! impl_prim_type_from_ort_trait {
     ($type_:ty, $variant:ident) => {
@@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
                 TensorElementDataType::$variant
             }
 
-            fn extract_array<'t, D>(
+            fn extract_data<'t, D>(
                 shape: D,
-                tensor: *mut sys::OrtValue,
-            ) -> Result<ndarray::ArrayView<'t, Self, D>>
+                _tensor_element_len: usize,
+                tensor_ptr: rc::Rc<TensorPointerHolder>,
+            ) -> Result<TensorData<'t, Self, D>>
             where
                 D: ndarray::Dimension,
             {
-                extract_primitive_array(shape, tensor)
+                extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| {
+                    TensorData::TensorPtr {
+                        ptr: tensor_ptr,
+                        array_view: v,
+                    }
+                })
             }
         }
     };
@@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
 impl_prim_type_from_ort_trait!(f64, Double);
 impl_prim_type_from_ort_trait!(u32, Uint32);
 impl_prim_type_from_ort_trait!(u64, Uint64);
+
+impl TensorDataToType for String {
+    fn tensor_element_data_type() -> TensorElementDataType {
+        TensorElementDataType::String
+    }
+
+    fn extract_data<'t, D: ndarray::Dimension>(
+        shape: D,
+        tensor_element_len: usize,
+        tensor_ptr: rc::Rc<TensorPointerHolder>,
+    ) -> Result<TensorData<'t, Self, D>> {
+        // Total length of string data, not including \0 suffix
+        let mut total_length = 0_u64;
+        unsafe {
+            call_ort(|ort| {
+                ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length)
+            })
+            .map_err(OrtError::GetStringTensorDataLength)?
+        }
+
+        // In the JNI impl of this, tensor_element_len was included in addition to total_length,
+        // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
+        // don't seem to be written to in practice either.
+        // If the string data actually did go farther, it would panic below when using the offset
+        // data to get slices for each string.
+        let mut string_contents = vec![0_u8; total_length as usize];
+        // one extra slot so that the total length can go in the last one, making all per-string
+        // length calculations easy
+        let mut offsets = vec![0_u64; tensor_element_len as usize + 1];
+
+        unsafe {
+            call_ort(|ort| {
+                ort.GetStringTensorContent.unwrap()(
+                    tensor_ptr.tensor_ptr,
+                    string_contents.as_mut_ptr() as *mut ffi::c_void,
+                    total_length,
+                    offsets.as_mut_ptr(),
+                    tensor_element_len as u64,
+                )
+            })
+            .map_err(OrtError::GetStringTensorContent)?
+        }
+
+        // final offset = overall length so that per-string length calculations work for the last
+        // string
+        debug_assert_eq!(0, offsets[tensor_element_len]);
+        offsets[tensor_element_len] = total_length;
+
+        let strings = offsets
+            // offsets has 1 extra offset past the end so that all windows work
+            .windows(2)
+            .map(|w| {
+                let start: usize = w[0].try_into().expect("Offset didn't fit into usize");
+                let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize");
+
+                let slice = &string_contents[start..next_start];
+                String::from_utf8(slice.into())
+            })
+            .collect::<result::Result<Vec<String>, string::FromUtf8Error>>()
+            .map_err(OrtError::StringFromUtf8Error)?;
+
+        let array = ndarray::Array::from_shape_vec(shape, strings)
+            .expect("Shape extracted from tensor didn't match tensor contents");
+
+        Ok(TensorData::Strings { strings: array })
+    }
+}
diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs
index 48f48308..f782df1b 100644
--- a/onnxruntime/src/tensor/ort_owned_tensor.rs
+++ b/onnxruntime/src/tensor/ort_owned_tensor.rs
@@ -2,7 +2,7 @@
 
 use std::{fmt::Debug, ops::Deref, ptr, rc, result};
 
-use ndarray::{Array, ArrayView};
+use ndarray::ArrayView;
 use thiserror::Error;
 use tracing::debug;
 
@@ -12,7 +12,7 @@ use crate::{
     error::call_ort,
     g_ort,
     memory::MemoryInfo,
-    tensor::{ndarray_tensor::NdArrayTensor, TensorDataToType, TensorElementDataType},
+    tensor::{TensorData, TensorDataToType, TensorElementDataType},
     OrtError,
 };
 
@@ -46,9 +46,12 @@ pub struct DynOrtTensor<'m, D>
 where
     D: ndarray::Dimension,
 {
-    tensor_ptr_holder: rc::Rc<TensorPointerDropper>,
+    // TODO could this also hold a Vec<String> for strings so that the extracted tensor could then
+    // hold a Vec<&str>?
+    tensor_ptr_holder: rc::Rc<TensorPointerHolder>,
     memory_info: &'m MemoryInfo,
     shape: D,
+    tensor_element_len: usize,
     data_type: TensorElementDataType,
 }
 
@@ -60,12 +63,14 @@ where
         tensor_ptr: *mut sys::OrtValue,
         memory_info: &'m MemoryInfo,
         shape: D,
+        tensor_element_len: usize,
         data_type: TensorElementDataType,
     ) -> DynOrtTensor<'m, D> {
         DynOrtTensor {
-            tensor_ptr_holder: rc::Rc::from(TensorPointerDropper { tensor_ptr }),
+            tensor_ptr_holder: rc::Rc::from(TensorPointerHolder { tensor_ptr }),
             memory_info,
             shape,
+            tensor_element_len,
             data_type,
         }
     }
@@ -87,6 +92,8 @@ where
     where
         T: TensorDataToType + Clone + Debug,
        'm: 't, // mem info outlives tensor
+        D: 't, // not clear why this is needed since we clone the shape, but it doesn't make
+               // a difference in practice since the shape is extracted from the tensor
     {
         if self.data_type != T::tensor_element_data_type() {
             Err(TensorExtractError::DataTypeMismatch {
@@ -107,13 +114,13 @@ where
                 .map_err(OrtError::IsTensor)?;
             assert_eq!(is_tensor, 1);
 
-            let array_view =
-                T::extract_array(self.shape.clone(), self.tensor_ptr_holder.tensor_ptr)?;
+            let data = T::extract_data(
+                self.shape.clone(),
+                self.tensor_element_len,
+                rc::Rc::clone(&self.tensor_ptr_holder),
+            )?;
 
-            Ok(OrtOwnedTensor::new(
-                self.tensor_ptr_holder.clone(),
-                array_view,
-            ))
+            Ok(OrtOwnedTensor { data })
         }
     }
 }
@@ -134,45 +141,69 @@
 where
     T: TensorDataToType,
     D: ndarray::Dimension,
 {
-    /// Keep the pointer alive
-    tensor_ptr_holder: rc::Rc<TensorPointerDropper>,
-    array_view: ArrayView<'t, T, D>,
+    data: TensorData<'t, T, D>,
 }
 
-impl<'t, T, D> Deref for OrtOwnedTensor<'t, T, D>
+impl<'t, T, D> OrtOwnedTensor<'t, T, D>
 where
     T: TensorDataToType,
-    D: ndarray::Dimension,
+    D: ndarray::Dimension + 't,
 {
-    type Target = ArrayView<'t, T, D>;
-
-    fn deref(&self) -> &Self::Target {
-        &self.array_view
+    /// Produce a [ViewHolder] for the underlying data, which dereferences to an `ndarray::ArrayView`.
+    pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D>
+    where
+        't: 's, // tensor ptr can outlive the TensorData
+    {
+        ViewHolder::new(&self.data)
     }
 }
 
-impl<'t, T, D> OrtOwnedTensor<'t, T, D>
+/// An intermediate step on the way to an ArrayView.
+// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must
+// be a field in a struct. This struct exists only to hold that field.
+// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor
+// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer.
+pub struct ViewHolder<'s, T, D>
 where
     T: TensorDataToType,
     D: ndarray::Dimension,
 {
-    pub(crate) fn new(
-        tensor_ptr_holder: rc::Rc<TensorPointerDropper>,
-        array_view: ArrayView<'t, T, D>,
-    ) -> OrtOwnedTensor<'t, T, D> {
-        OrtOwnedTensor {
-            tensor_ptr_holder,
-            array_view,
-        }
-    }
+    array_view: ndarray::ArrayView<'s, T, D>,
+}
 
-    /// Apply a softmax on the specified axis
-    pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
+impl<'s, T, D> ViewHolder<'s, T, D>
+where
+    T: TensorDataToType,
+    D: ndarray::Dimension,
+{
+    fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D>
     where
-        D: ndarray::RemoveAxis,
-        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
+        't: 's, // underlying tensor ptr lives at least as long as TensorData
     {
-        self.array_view.softmax(axis)
+        match data {
+            TensorData::TensorPtr { array_view, .. } => ViewHolder {
+                // we already have a view, but creating a view from a view is cheap
+                array_view: array_view.view(),
+            },
+            TensorData::Strings { strings } => ViewHolder {
+                // This view creation has to happen here, not at new()'s callsite, because
+                // a field can't be a reference to another field in the same struct. Thus, we have
+                // this separate struct to hold the view that refers to the `Array`.
+                array_view: strings.view(),
+            },
+        }
+    }
+}
+
+impl<'t, T, D> Deref for ViewHolder<'t, T, D>
+where
+    T: TensorDataToType,
+    D: ndarray::Dimension,
+{
+    type Target = ArrayView<'t, T, D>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.array_view
     }
 }
@@ -183,11 +214,11 @@ where
 /// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be
 /// inconvenient.
 #[derive(Debug)]
-pub(crate) struct TensorPointerDropper {
-    tensor_ptr: *mut sys::OrtValue,
+pub struct TensorPointerHolder {
+    pub(crate) tensor_ptr: *mut sys::OrtValue,
 }
 
-impl Drop for TensorPointerDropper {
+impl Drop for TensorPointerHolder {
     #[tracing::instrument]
     fn drop(&mut self) {
         debug!("Dropping OrtOwnedTensor.");
diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs
index 2a2ea164..a2ccd49b 100644
--- a/onnxruntime/tests/integration_tests.rs
+++ b/onnxruntime/tests/integration_tests.rs
@@ -18,6 +18,7 @@ mod download {
         tensor::{DynOrtTensor, OrtOwnedTensor},
         GraphOptimizationLevel, LoggingLevel,
     };
+    use onnxruntime::tensor::ndarray_tensor::NdArrayTensor;
 
     #[test]
     fn squeezenet_mushroom() {
@@ -63,7 +64,7 @@ mod download {
             input0_shape[3] as u32,
             FilterType::Nearest,
         )
-        .to_rgb();
+        .to_rgb8();
 
         // Python:
         // # image[y, x, RGB]
@@ -101,6 +102,7 @@ mod download {
         // and iterate on resulting probabilities, creating an index to later access labels.
         let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap();
         let mut probabilities: Vec<(usize, f32)> = output
+            .view()
             .softmax(ndarray::Axis(1))
             .into_iter()
             .copied()
@@ -171,7 +173,7 @@ mod download {
             input0_shape[3] as u32,
             FilterType::Nearest,
         )
-        .to_luma();
+        .to_luma8();
 
         let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| {
             let pixel = image_buffer.get_pixel(i as u32, j as u32);
@@ -190,6 +192,7 @@ mod download {
 
         let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap();
         let mut probabilities: Vec<(usize, f32)> = output
+            .view()
             .softmax(ndarray::Axis(1))
             .into_iter()
             .copied()
@@ -269,7 +272,7 @@ mod download {
             .join(IMAGE_TO_LOAD),
         )
         .unwrap()
-        .to_rgb();
+        .to_rgb8();
 
         let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| {
             let pixel = image_buffer.get_pixel(i as u32, j as u32);
@@ -291,7 +294,7 @@ mod download {
             outputs[0].try_extract().unwrap();
 
         // The image should have doubled in size
-        assert_eq!(output.shape(), [1, 448, 448, 3]);
+        assert_eq!(output.view().shape(), [1, 448, 448, 3]);
     }
 }
diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs
index fe4c0da9..e07fe5f6 100644
--- a/onnxruntime/tests/string_type.rs
+++ b/onnxruntime/tests/string_type.rs
@@ -5,7 +5,7 @@ use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType};
 use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel};
 
 #[test]
-fn run_model_with_string_input_output() -> Result<(), Box<dyn std::error::Error>> {
+fn run_model_with_string_1d_input_output() -> Result<(), Box<dyn std::error::Error>> {
     let environment = Environment::builder()
         .with_name("test")
         .with_log_level(LoggingLevel::Verbose)
@@ -30,7 +30,7 @@ fn run_model_with_string_input_output() -> Result<(), Box<dyn std::error::Error>> {
     // type = String
    // dimensions = [None]
 
-    let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo"]);
+    let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo", "baz"]);
     let input_tensor_values = vec![array];
 
     let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;
@@ -39,10 +39,19 @@
     assert_eq!(TensorElementDataType::String, outputs[1].data_type());
 
     let int_output: OrtOwnedTensor<i64, _> = outputs[0].try_extract()?;
-
-    assert_eq!(&[0, 1, 0, 0], int_output.as_slice().unwrap());
-
-    // TODO get the string output once string extraction is implemented
+    let string_output: OrtOwnedTensor<String, _> = outputs[1].try_extract()?;
+
+    assert_eq!(&[5], int_output.view().shape());
+    assert_eq!(&[3], string_output.view().shape());
+
+    assert_eq!(&[0, 1, 0, 0, 2], int_output.view().as_slice().unwrap());
+    assert_eq!(
+        vec!["foo", "bar", "baz"]
+            .into_iter()
+            .map(|s| s.to_owned())
+            .collect::<Vec<String>>(),
+        string_output.view().as_slice().unwrap()
+    );
 
     Ok(())
 }
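
Usage note (not part of the change above): the string_type.rs test exercises the new extraction path end to end. Below is a minimal, hypothetical sketch of the same flow against an already-built `Session`, assuming the model takes a 1-D string tensor and that output 0 is a string tensor (in the test above the string output is `outputs[1]`); the helper name `print_string_output` and the literal input values are placeholders, not part of this diff.

    use onnxruntime::tensor::{DynOrtTensor, OrtOwnedTensor};

    fn print_string_output(
        session: &mut onnxruntime::session::Session<'_>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // A 1-D tensor of &str, mirroring the string_type.rs test input.
        let input = ndarray::Array::from(vec!["foo", "bar", "baz"]);

        // run() returns DynOrtTensor values; the element type is only checked
        // when try_extract() is called.
        let outputs: Vec<DynOrtTensor<_>> = session.run(vec![input])?;

        // Assumption: output 0 of this hypothetical model is the string tensor.
        // String tensors are copied out of ort's buffer into owned Strings,
        // while primitive tensors keep borrowing ort's memory; either way,
        // view() yields something that derefs to an ndarray::ArrayView.
        let strings: OrtOwnedTensor<String, _> = outputs[0].try_extract()?;
        let view = strings.view();
        for s in view.iter() {
            println!("{}", s);
        }

        Ok(())
    }

The extra `.view()` step (rather than `Deref` directly on `OrtOwnedTensor`) is what lets string tensors return owned data while primitive tensors keep returning zero-copy views, as the `ViewHolder` comments in the diff explain.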