-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a new trait to represent types that can be used as output from a tensor #62
base: master
Are you sure you want to change the base?
Conversation
8b6f1b5
to
1bfa0ed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is extracted from a branch adding both flexible output types (for models whose outputs aren't all the same type) and string output types, and it was just getting too big. I think it makes sense to do dynamic output types first, as even my tiny string output sample model has both a string and an i32 output.
@@ -322,154 +322,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel { | |||
} | |||
} | |||
|
|||
// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved as-is into tensor.rs
. I figured with this + the new trait it was making lib.rs
pretty fat, and I think crate::tensor
is a reasonable home for these types. WDYT?
output_tensor_extractor.extract::<TOut>() | ||
.map(|tensor_ptr| { | ||
let dims = unsafe { | ||
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A little pre-cleanup -- in my WIP branch to extract string output, I needed to get multiple things out of the tensor info, so I made this helper to make it hard to forget to clean up tensor_info_ptr
in all error handling cases. Soon this will be getting more than just dims
out of the info ptr.
|
||
// Note: Both tensor and array will point to the same data, nothing is copied. | ||
// As such, there is no need to free the pointer used to create the ArrayView. | ||
assert_ne!(tensor_ptr, std::ptr::null_mut()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
taken from what was OrtOwnedTensorExtractor
, the remaining contents of which now lives mostly in the numeric type impls of TensorDataToType::extract_array
num_dims, | ||
); | ||
status_to_result(status).map_err(OrtError::GetDimensions)?; | ||
call_ort(|ort| { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just converting a few g_ort()'s to call_ort
Ok(node_dims) | ||
} | ||
|
||
unsafe fn extract_data_type( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extracted a function for this logic as-is as I had need of it in one additional callsite for strings
onnxruntime/src/tensor.rs
Outdated
} | ||
|
||
/// Trait used to map onnxruntime types to Rust types | ||
pub trait TensorDataToType: Sized + fmt::Debug { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is where the new stuff begins in this file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've been playing around with this and the String output PR. It looks good from what I've seen so far. One thing I did notice is that this trait doesn't include Clone
as mentioned in your other comment here. This seems like it'd be necessary if your want to get an ArrayView<OwnedRepr<T>, Dim<IxDynImpl>>
.
e.g.
The following complains about not being able to move out of the dereference the ArrayView
let ort_owned_tensor = output.try_extract::<f32>().unwrap();
let tensor = ort_owned_tensor.view().into_owned();
but adding in the Clone
trait here allows you to do the following
let ort_owned_tensor = output.try_extract::<f32>().unwrap();
let tensor = ort_owned_tensor.view().clone().into_owned();
There may be some other solution to this that I'm missing, but I thought I'd point it out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, not that I'm aware of... Having an OwnedRepr
didn't occur to me as a use case but I suppose in practice there's little concern about requiring Clone since the data types would involved are almost certainly already implementing it. I'll add it back in.
/// Trait used to map onnxruntime types to Rust types | ||
pub trait TensorDataToType: Sized + fmt::Debug { | ||
/// The tensor element type that this type can extract from | ||
fn tensor_element_data_type() -> TensorElementDataType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function isn't needed just yet but it will be shortly so I figured it might as well go in now 🤷
D: ndarray::Dimension, | ||
{ | ||
// Get pointer to output tensor float values | ||
let mut output_array_ptr: *mut T = ptr::null_mut(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
previously this was in the extractor
@@ -25,18 +23,18 @@ use crate::{ | |||
#[derive(Debug)] | |||
pub struct OrtOwnedTensor<'t, 'm, T, D> | |||
where | |||
T: TypeToTensorElementDataType + Debug + Clone, | |||
T: TensorDataToType, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TensorDataToType
includes Debug
, and Clone
didn't seem to be needed any longer so I ditched it
…rom a tensor This is some prep work for string output types and tensor types that vary across the model outputs. For now, the supported types are just the basic numeric types. Since strings have to be copied out of a tensor, it only makes sense to have `String` be an output type, not `str`, hence the new type so that we can have more input types supported than output types.
1bfa0ed
to
0d34d23
Compare
Codecov Report
@@ Coverage Diff @@
## master #62 +/- ##
==========================================
+ Coverage 14.58% 17.94% +3.36%
==========================================
Files 18 19 +1
Lines 960 1003 +43
==========================================
+ Hits 140 180 +40
- Misses 820 823 +3
Continue to review full report at Codecov.
|
This is some prep work for string output types and tensor types that vary across the model outputs. For now, the supported types are just the basic numeric types.
Since strings have to be copied out of a tensor, it only makes sense to have
String
be an output type, notstr
, hence the new type so that we can have more input types supported than output types.