Skip to content

Commit

Permalink
Support the onnx string type in output tensors
Browse files Browse the repository at this point in the history
This approach allocates owned Strings for each element, which works, but stresses the allocator, and incurs unnecessary copying.

Part of the complication stems from the limitation that in Rust, a field can't be a reference to another field in the same struct. This means that having a Vec<u8> of copied data, referred to by a Vec<&str>, which is then referred to by an ArrayView, requires a sequence of 3 structs to express. Building a Vec<String> gets rid of the references, but also loses the efficiency of 1 allocation with strs pointing into it.
  • Loading branch information
marshallpierce committed Mar 9, 2021
1 parent 555bec7 commit 80b68d1
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 61 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/examples/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ fn run() -> Result<(), Error> {
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;

let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
assert_eq!(output.shape(), output0_shape.as_slice());
assert_eq!(output.view().shape(), output0_shape.as_slice());
for i in 0..5 {
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]);
}

Ok(())
Expand Down
14 changes: 13 additions & 1 deletion onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing error definitions.
use std::{io, path::PathBuf};
use std::{io, path::PathBuf, string};

use thiserror::Error;

Expand Down Expand Up @@ -53,6 +53,12 @@ pub enum OrtError {
/// Error occurred when getting ONNX dimensions
#[error("Failed to get dimensions: {0}")]
GetDimensions(OrtApiError),
/// Error occurred when getting string length
#[error("Failed to get string tensor length: {0}")]
GetStringTensorDataLength(OrtApiError),
/// Error occurred when getting tensor element count
#[error("Failed to get tensor element count: {0}")]
GetTensorShapeElementCount(OrtApiError),
/// Error occurred when creating CPU memory information
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
Expand All @@ -77,6 +83,12 @@ pub enum OrtError {
/// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView`
#[error("Failed to get tensor data: {0}")]
GetTensorMutableData(OrtApiError),
/// Error occurred when extracting string data from an ONNX tensor
#[error("Failed to get tensor string data: {0}")]
GetStringTensorContent(OrtApiError),
/// Error occurred when converting data to a String
#[error("Data was not UTF-8: {0}")]
StringFromUtf8Error(#[from] string::FromUtf8Error),

/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
#[error("Failed to download ONNX model: {0}")]
Expand Down
23 changes: 21 additions & 2 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand Down Expand Up @@ -436,21 +436,40 @@ impl<'a> Session<'a> {
output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| {
let (dims, data_type) = unsafe {
let (dims, data_type, len) = unsafe {
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
get_tensor_dimensions(tensor_info_ptr)
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
.and_then(|dims| {
extract_data_type(tensor_info_ptr)
.map(|data_type| (dims, data_type))
})
.and_then(|(dims, data_type)| {
let mut len = 0_u64;

call_ort(|ort| {
ort.GetTensorShapeElementCount.unwrap()(
tensor_info_ptr,
&mut len,
)
})
.map_err(OrtError::GetTensorShapeElementCount)?;

Ok((
dims,
data_type,
len.try_into()
.expect("u64 length could not fit into usize"),
))
})
})
}?;

Ok(DynOrtTensor::new(
tensor_ptr,
memory_info_ref,
ndarray::IxDyn(&dims),
len,
data_type,
))
})
Expand Down
119 changes: 110 additions & 9 deletions onnxruntime/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ pub mod ort_tensor;
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
pub use ort_tensor::OrtTensor;

use crate::{OrtError, Result};
use crate::tensor::ort_owned_tensor::TensorPointerHolder;
use crate::{error::call_ort, OrtError, Result};
use onnxruntime_sys::{self as sys, OnnxEnumInt};
use std::{fmt, ptr};
use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string};

// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum
// FIXME: Add tests to cover the commented out types
Expand Down Expand Up @@ -188,14 +189,41 @@ pub trait TensorDataToType: Sized + fmt::Debug {
fn tensor_element_data_type() -> TensorElementDataType;

/// Extract an `ArrayView` from the ort-owned tensor.
fn extract_array<'t, D>(
fn extract_data<'t, D>(
shape: D,
tensor: *mut sys::OrtValue,
) -> Result<ndarray::ArrayView<'t, Self, D>>
tensor_element_len: usize,
tensor_ptr: rc::Rc<TensorPointerHolder>,
) -> Result<TensorData<'t, Self, D>>
where
D: ndarray::Dimension;
}

/// Represents the possible ways tensor data can be accessed.
///
/// This should only be used internally.
#[derive(Debug)]
pub enum TensorData<'t, T, D>
where
D: ndarray::Dimension,
{
/// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid.
/// This is used for data types whose in-memory form from ort is compatible with Rust's, like
/// primitive numeric types.
TensorPtr {
/// The pointer ort produced. Kept alive so that `array_view` is valid.
ptr: rc::Rc<TensorPointerHolder>,
/// A view into `ptr`
array_view: ndarray::ArrayView<'t, T, D>,
},
/// String data is output differently by ort, and of course is also variable size, so it cannot
/// use the same simple pointer representation.
// Since 't outlives this struct, the 't lifetime is more than we need, but no harm done.
Strings {
/// Owned Strings copied out of ort's output
strings: ndarray::Array<T, D>,
},
}

/// Implements `OwnedTensorDataToType` for primitives, which can use `GetTensorMutableData`
macro_rules! impl_prim_type_from_ort_trait {
($type_:ty, $variant:ident) => {
Expand All @@ -204,14 +232,20 @@ macro_rules! impl_prim_type_from_ort_trait {
TensorElementDataType::$variant
}

fn extract_array<'t, D>(
fn extract_data<'t, D>(
shape: D,
tensor: *mut sys::OrtValue,
) -> Result<ndarray::ArrayView<'t, Self, D>>
_tensor_element_len: usize,
tensor_ptr: rc::Rc<TensorPointerHolder>,
) -> Result<TensorData<'t, Self, D>>
where
D: ndarray::Dimension,
{
extract_primitive_array(shape, tensor)
extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| {
TensorData::TensorPtr {
ptr: tensor_ptr,
array_view: v,
}
})
}
}
};
Expand Down Expand Up @@ -255,3 +289,70 @@ impl_prim_type_from_ort_trait!(i64, Int64);
impl_prim_type_from_ort_trait!(f64, Double);
impl_prim_type_from_ort_trait!(u32, Uint32);
impl_prim_type_from_ort_trait!(u64, Uint64);

impl TensorDataToType for String {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}

fn extract_data<'t, D: ndarray::Dimension>(
shape: D,
tensor_element_len: usize,
tensor_ptr: rc::Rc<TensorPointerHolder>,
) -> Result<TensorData<'t, Self, D>> {
// Total length of string data, not including \0 suffix
let mut total_length = 0_u64;
unsafe {
call_ort(|ort| {
ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length)
})
.map_err(OrtError::GetStringTensorDataLength)?
}

// In the JNI impl of this, tensor_element_len was included in addition to total_length,
// but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes
// don't seem to be written to in practice either.
// If the string data actually did go farther, it would panic below when using the offset
// data to get slices for each string.
let mut string_contents = vec![0_u8; total_length as usize];
// one extra slot so that the total length can go in the last one, making all per-string
// length calculations easy
let mut offsets = vec![0_u64; tensor_element_len as usize + 1];

unsafe {
call_ort(|ort| {
ort.GetStringTensorContent.unwrap()(
tensor_ptr.tensor_ptr,
string_contents.as_mut_ptr() as *mut ffi::c_void,
total_length,
offsets.as_mut_ptr(),
tensor_element_len as u64,
)
})
.map_err(OrtError::GetStringTensorContent)?
}

// final offset = overall length so that per-string length calculations work for the last
// string
debug_assert_eq!(0, offsets[tensor_element_len]);
offsets[tensor_element_len] = total_length;

let strings = offsets
// offsets has 1 extra offset past the end so that all windows work
.windows(2)
.map(|w| {
let start: usize = w[0].try_into().expect("Offset didn't fit into usize");
let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize");

let slice = &string_contents[start..next_start];
String::from_utf8(slice.into())
})
.collect::<result::Result<Vec<String>, string::FromUtf8Error>>()
.map_err(OrtError::StringFromUtf8Error)?;

let array = ndarray::Array::from_shape_vec(shape, strings)
.expect("Shape extracted from tensor didn't match tensor contents");

Ok(TensorData::Strings { strings: array })
}
}
Loading

0 comments on commit 80b68d1

Please sign in to comment.