From 1871703c50051d4ca4a53343e9b807537ecfb924 Mon Sep 17 00:00:00 2001 From: Marshall Pierce Date: Mon, 8 Mar 2021 11:54:50 -0700 Subject: [PATCH] Support the onnx string type in output tensors This approach allocates owned Strings for each element, which works, but stresses the allocator, and incurs unnecessary copying. Part of the complication stems from the limitation that in Rust, a field can't be a reference to another field in the same struct. This means that having a Vec of copied data, referred to by a Vec<&str>, which is then referred to by an ArrayView, requires a sequence of 3 structs to express. Building a Vec gets rid of the references, but also loses the efficiency of 1 allocation with strs pointing into it. --- onnxruntime/examples/sample.rs | 4 +- onnxruntime/src/error.rs | 14 ++- onnxruntime/src/session.rs | 23 +++- onnxruntime/src/tensor.rs | 119 +++++++++++++++++++-- onnxruntime/src/tensor/ort_owned_tensor.rs | 105 +++++++++++------- onnxruntime/tests/string_type.rs | 21 ++-- 6 files changed, 229 insertions(+), 57 deletions(-) diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index 3fbc2670..5563ab4b 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -66,9 +66,9 @@ fn run() -> Result<(), Error> { let outputs: Vec> = session.run(input_tensor_values)?; let output: OrtOwnedTensor = outputs[0].try_extract().unwrap(); - assert_eq!(output.shape(), output0_shape.as_slice()); + assert_eq!(output.view().shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]); } Ok(()) diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index f49613fe..86280f74 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -1,6 +1,6 @@ //! Module containing error definitions. -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, string}; use thiserror::Error; @@ -53,6 +53,12 @@ pub enum OrtError { /// Error occurred when getting ONNX dimensions #[error("Failed to get dimensions: {0}")] GetDimensions(OrtApiError), + /// Error occurred when getting string length + #[error("Failed to get string tensor length: {0}")] + GetStringTensorDataLength(OrtApiError), + /// Error occurred when getting tensor element count + #[error("Failed to get tensor element count: {0}")] + GetTensorShapeElementCount(OrtApiError), /// Error occurred when creating CPU memory information #[error("Failed to get dimensions: {0}")] CreateCpuMemoryInfo(OrtApiError), @@ -77,6 +83,12 @@ pub enum OrtError { /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView` #[error("Failed to get tensor data: {0}")] GetTensorMutableData(OrtApiError), + /// Error occurred when extracting string data from an ONNX tensor + #[error("Failed to get tensor string data: {0}")] + GetStringTensorContent(OrtApiError), + /// Error occurred when converting data to a String + #[error("Data was not UTF-8: {0}")] + StringFromUtf8Error(#[from] string::FromUtf8Error), /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models) #[error("Failed to download ONNX model: {0}")] diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 232d188d..98099fd6 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! Module containing session types -use std::{ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -436,7 +436,7 @@ impl<'a> Session<'a> { output_tensor_ptrs .into_iter() .map(|tensor_ptr| { - let (dims, data_type) = unsafe { + let (dims, data_type, len) = unsafe { call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { get_tensor_dimensions(tensor_info_ptr) .map(|dims| dims.iter().map(|&n| n as usize).collect::>()) @@ -444,6 +444,24 @@ impl<'a> Session<'a> { extract_data_type(tensor_info_ptr) .map(|data_type| (dims, data_type)) }) + .and_then(|(dims, data_type)| { + let mut len = 0_u64; + + call_ort(|ort| { + ort.GetTensorShapeElementCount.unwrap()( + tensor_info_ptr, + &mut len, + ) + }) + .map_err(OrtError::GetTensorShapeElementCount)?; + + Ok(( + dims, + data_type, + len.try_into() + .expect("u64 length could not fit into usize"), + )) + }) }) }?; @@ -451,6 +469,7 @@ impl<'a> Session<'a> { tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), + len, data_type, )) }) diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index df85e1ed..74e8329c 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -30,9 +30,10 @@ pub mod ort_tensor; pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor}; pub use ort_tensor::OrtTensor; -use crate::{OrtError, Result}; +use crate::tensor::ort_owned_tensor::TensorPointerHolder; +use crate::{error::call_ort, OrtError, Result}; use onnxruntime_sys::{self as sys, OnnxEnumInt}; -use std::{fmt, ptr}; +use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string}; // FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum // FIXME: Add tests to cover the commented out types @@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug { fn tensor_element_data_type() -> TensorElementDataType; /// Extract an `ArrayView` from the ort-owned tensor. - fn extract_array<'t, D>( + fn extract_data<'t, D>( shape: D, - tensor: *mut sys::OrtValue, - ) -> Result> + tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> where D: ndarray::Dimension; } +/// Represents the possible ways tensor data can be accessed. +/// +/// This should only be used internally. +#[derive(Debug)] +pub enum TensorData<'t, T, D> +where + D: ndarray::Dimension, +{ + /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid. + /// This is used for data types whose in-memory form from ort is compatible with Rust's, like + /// primitive numeric types. + TensorPtr { + /// The pointer ort produced. Kept alive so that `array_view` is valid. + ptr: rc::Rc, + /// A view into `ptr` + array_view: ndarray::ArrayView<'t, T, D>, + }, + /// String data is output differently by ort, and of course is also variable size, so it cannot + /// use the same simple pointer representation. + // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done. + Strings { + /// Owned Strings copied out of ort's output + strings: ndarray::Array, + }, +} + /// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData` macro_rules! impl_prim_type_from_ort_trait { ($type_:ty, $variant:ident) => { @@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait { TensorElementDataType::$variant } - fn extract_array<'t, D>( + fn extract_data<'t, D>( shape: D, - tensor: *mut sys::OrtValue, - ) -> Result> + _tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> where D: ndarray::Dimension, { - extract_primitive_array(shape, tensor) + extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| { + TensorData::TensorPtr { + ptr: tensor_ptr, + array_view: v, + } + }) } } }; @@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64); impl_prim_type_from_ort_trait!(f64, Double); impl_prim_type_from_ort_trait!(u32, Uint32); impl_prim_type_from_ort_trait!(u64, Uint64); + +impl TensorDataToType for String { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn extract_data<'t, D: ndarray::Dimension>( + shape: D, + tensor_element_len: usize, + tensor_ptr: rc::Rc, + ) -> Result> { + // Total length of string data, not including \0 suffix + let mut total_length = 0_u64; + unsafe { + call_ort(|ort| { + ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length) + }) + .map_err(OrtError::GetStringTensorDataLength)? + } + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0_u8; total_length as usize]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0_u64; tensor_element_len as usize + 1]; + + unsafe { + call_ort(|ort| { + ort.GetStringTensorContent.unwrap()( + tensor_ptr.tensor_ptr, + string_contents.as_mut_ptr() as *mut ffi::c_void, + total_length, + offsets.as_mut_ptr(), + tensor_element_len as u64, + ) + }) + .map_err(OrtError::GetStringTensorContent)? + } + + // final offset = overall length so that per-string length calculations work for the last + // string + debug_assert_eq!(0, offsets[tensor_element_len]); + offsets[tensor_element_len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let start: usize = w[0].try_into().expect("Offset didn't fit into usize"); + let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize"); + + let slice = &string_contents[start..next_start]; + String::from_utf8(slice.into()) + }) + .collect::, string::FromUtf8Error>>() + .map_err(OrtError::StringFromUtf8Error)?; + + let array = ndarray::Array::from_shape_vec(shape, strings) + .expect("Shape extracted from tensor didn't match tensor contents"); + + Ok(TensorData::Strings { strings: array }) + } +} diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 48f48308..f782df1b 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -2,7 +2,7 @@ use std::{fmt::Debug, ops::Deref, ptr, rc, result}; -use ndarray::{Array, ArrayView}; +use ndarray::ArrayView; use thiserror::Error; use tracing::debug; @@ -12,7 +12,7 @@ use crate::{ error::call_ort, g_ort, memory::MemoryInfo, - tensor::{ndarray_tensor::NdArrayTensor, TensorDataToType, TensorElementDataType}, + tensor::{TensorData, TensorDataToType, TensorElementDataType}, OrtError, }; @@ -46,9 +46,12 @@ pub struct DynOrtTensor<'m, D> where D: ndarray::Dimension, { - tensor_ptr_holder: rc::Rc, + // TODO could this also hold a Vec for strings so that the extracted tensor could then + // hold a Vec<&str>? + tensor_ptr_holder: rc::Rc, memory_info: &'m MemoryInfo, shape: D, + tensor_element_len: usize, data_type: TensorElementDataType, } @@ -60,12 +63,14 @@ where tensor_ptr: *mut sys::OrtValue, memory_info: &'m MemoryInfo, shape: D, + tensor_element_len: usize, data_type: TensorElementDataType, ) -> DynOrtTensor<'m, D> { DynOrtTensor { - tensor_ptr_holder: rc::Rc::from(TensorPointerDropper { tensor_ptr }), + tensor_ptr_holder: rc::Rc::from(TensorPointerHolder { tensor_ptr }), memory_info, shape, + tensor_element_len, data_type, } } @@ -87,6 +92,8 @@ where where T: TensorDataToType + Clone + Debug, 'm: 't, // mem info outlives tensor + D: 't, // not clear why this is needed since we clone the shape, but it doesn't make + // a difference in practice since the shape is extracted from the tensor { if self.data_type != T::tensor_element_data_type() { Err(TensorExtractError::DataTypeMismatch { @@ -107,13 +114,13 @@ where .map_err(OrtError::IsTensor)?; assert_eq!(is_tensor, 1); - let array_view = - T::extract_array(self.shape.clone(), self.tensor_ptr_holder.tensor_ptr)?; + let data = T::extract_data( + self.shape.clone(), + self.tensor_element_len, + rc::Rc::clone(&self.tensor_ptr_holder), + )?; - Ok(OrtOwnedTensor::new( - self.tensor_ptr_holder.clone(), - array_view, - )) + Ok(OrtOwnedTensor { data }) } } } @@ -134,45 +141,69 @@ where T: TensorDataToType, D: ndarray::Dimension, { - /// Keep the pointer alive - tensor_ptr_holder: rc::Rc, - array_view: ArrayView<'t, T, D>, + data: TensorData<'t, T, D>, } -impl<'t, T, D> Deref for OrtOwnedTensor<'t, T, D> +impl<'t, T, D> OrtOwnedTensor<'t, T, D> where T: TensorDataToType, - D: ndarray::Dimension, + D: ndarray::Dimension + 't, { - type Target = ArrayView<'t, T, D>; - - fn deref(&self) -> &Self::Target { - &self.array_view + /// Produce a [ViewHolder] for the underlying data, which + pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D> + where + 't: 's, // tensor ptr can outlive the TensorData + { + ViewHolder::new(&self.data) } } -impl<'t, T, D> OrtOwnedTensor<'t, T, D> +/// An intermediate step on the way to an ArrayView. +// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must +// be a field in a struct. This struct exists only to hold that field. +// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor +// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer. +pub struct ViewHolder<'s, T, D> where T: TensorDataToType, D: ndarray::Dimension, { - pub(crate) fn new( - tensor_ptr_holder: rc::Rc, - array_view: ArrayView<'t, T, D>, - ) -> OrtOwnedTensor<'t, T, D> { - OrtOwnedTensor { - tensor_ptr_holder, - array_view, - } - } + array_view: ndarray::ArrayView<'s, T, D>, +} - /// Apply a softmax on the specified axis - pub fn softmax(&self, axis: ndarray::Axis) -> Array +impl<'s, T, D> ViewHolder<'s, T, D> +where + T: TensorDataToType, + D: ndarray::Dimension, +{ + fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D> where - D: ndarray::RemoveAxis, - T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign, + 't: 's, // underlying tensor ptr lives at least as long as TensorData { - self.array_view.softmax(axis) + match data { + TensorData::TensorPtr { array_view, .. } => ViewHolder { + // we already have a view, but creating a view from a view is cheap + array_view: array_view.view(), + }, + TensorData::Strings { strings } => ViewHolder { + // This view creation has to happen here, not at new()'s callsite, because + // a field can't be a reference to another field in the same struct. Thus, we have + // this separate struct to hold the view that refers to the `Array`. + array_view: strings.view(), + }, + } + } +} + +impl<'t, T, D> Deref for ViewHolder<'t, T, D> +where + T: TensorDataToType, + D: ndarray::Dimension, +{ + type Target = ArrayView<'t, T, D>; + + fn deref(&self) -> &Self::Target { + &self.array_view } } @@ -183,11 +214,11 @@ where /// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be /// inconvenient. #[derive(Debug)] -pub(crate) struct TensorPointerDropper { - tensor_ptr: *mut sys::OrtValue, +pub struct TensorPointerHolder { + pub(crate) tensor_ptr: *mut sys::OrtValue, } -impl Drop for TensorPointerDropper { +impl Drop for TensorPointerHolder { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping OrtOwnedTensor."); diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs index fe4c0da9..e07fe5f6 100644 --- a/onnxruntime/tests/string_type.rs +++ b/onnxruntime/tests/string_type.rs @@ -5,7 +5,7 @@ use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType}; use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel}; #[test] -fn run_model_with_string_input_output() -> Result<(), Box> { +fn run_model_with_string_1d_input_output() -> Result<(), Box> { let environment = Environment::builder() .with_name("test") .with_log_level(LoggingLevel::Verbose) @@ -30,7 +30,7 @@ fn run_model_with_string_input_output() -> Result<(), Box> { // type = String // dimensions = [None] - let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo"]); + let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo", "baz"]); let input_tensor_values = vec![array]; let outputs: Vec> = session.run(input_tensor_values)?; @@ -39,10 +39,19 @@ fn run_model_with_string_input_output() -> Result<(), Box> { assert_eq!(TensorElementDataType::String, outputs[1].data_type()); let int_output: OrtOwnedTensor = outputs[0].try_extract()?; - - assert_eq!(&[0, 1, 0, 0], int_output.as_slice().unwrap()); - - // TODO get the string output once string extraction is implemented + let string_output: OrtOwnedTensor = outputs[1].try_extract()?; + + assert_eq!(&[5], int_output.view().shape()); + assert_eq!(&[3], string_output.view().shape()); + + assert_eq!(&[0, 1, 0, 0, 2], int_output.view().as_slice().unwrap()); + assert_eq!( + vec!["foo", "bar", "baz"] + .into_iter() + .map(|s| s.to_owned()) + .collect::>(), + string_output.view().as_slice().unwrap() + ); Ok(()) }