Skip to content

Commit

Permalink
fix lifetimes, add tensors iterator (#518)
Browse files Browse the repository at this point in the history
* fix lifetimes, add tensors iterator

* add test for lifetimes
  • Loading branch information
gvilums authored Sep 5, 2024
1 parent cafcd3d commit ba2e397
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions safetensors/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,10 @@ impl<'data> SafeTensors<'data> {
Ok(Self { metadata, data })
}

/// Allow the user to iterate over tensors within the SafeTensors.
/// Returns the tensors contained within the SafeTensors.
/// The tensors returned are merely views and the data is not owned by this
/// structure.
pub fn tensors(&self) -> Vec<(String, TensorView<'_>)> {
pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> {
let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
for (name, &index) in &self.metadata.index_map {
let info = &self.metadata.tensors[index];
Expand All @@ -362,10 +362,24 @@ impl<'data> SafeTensors<'data> {
tensors
}

/// Returns an iterator over the tensors contained within the SafeTensors.
/// The tensors returned are merely views and the data is not owned by this
/// structure.
pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a str, TensorView<'data>)> {
self.metadata.index_map.iter().map(|(name, &idx)| {
let info = &self.metadata.tensors[idx];
(name.as_str(), TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data: &self.data[info.data_offsets.0..info.data_offsets.1],
})
})
}

/// Allow the user to get a specific tensor within the SafeTensors.
/// The tensor returned is merely a view and the data is not owned by this
/// structure.
pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'_>, SafeTensorError> {
pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
if let Some(index) = &self.metadata.index_map.get(tensor_name) {
if let Some(info) = &self.metadata.tensors.get(**index) {
Ok(TensorView {
Expand Down Expand Up @@ -541,7 +555,7 @@ impl Metadata {
/// A view of a Tensor within the file.
/// Contains references to data within the full byte-buffer
/// And is thus a readable view of a single tensor
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TensorView<'data> {
dtype: Dtype,
shape: Vec<usize>,
Expand Down Expand Up @@ -1038,6 +1052,21 @@ mod tests {
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}

#[test]
fn test_lifetimes() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";

let tensor = {
let loaded = SafeTensors::deserialize(serialized).unwrap();
loaded.tensor("test").unwrap()
};

assert_eq!(tensor.shape(), vec![2, 2]);
assert_eq!(tensor.dtype(), Dtype::I32);
// 16 bytes
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}

#[test]
fn test_json_attack() {
let mut tensors = HashMap::new();
Expand Down

0 comments on commit ba2e397

Please sign in to comment.