Tie the lifetime of returned tensor views to the underlying byte buffer.

`SafeTensors::tensors` and `SafeTensors::tensor` now return `TensorView<'data>`
instead of a view bound to the borrow of `&self`, so a view can be used after
the `SafeTensors` value that produced it has been dropped (see `test_lifetimes`).
Also adds `SafeTensors::iter` and derives `Clone` for `TensorView`.

diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs
index e5b8b4b8..a288663e 100644
--- a/safetensors/src/tensor.rs
+++ b/safetensors/src/tensor.rs
@@ -345,10 +345,10 @@ impl<'data> SafeTensors<'data> {
         Ok(Self { metadata, data })
     }
 
-    /// Allow the user to iterate over tensors within the SafeTensors.
+    /// Returns the tensors contained within the SafeTensors.
     /// The tensors returned are merely views and the data is not owned by this
     /// structure.
-    pub fn tensors(&self) -> Vec<(String, TensorView<'_>)> {
+    pub fn tensors(&self) -> Vec<(String, TensorView<'data>)> {
         let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
         for (name, &index) in &self.metadata.index_map {
             let info = &self.metadata.tensors[index];
@@ -362,10 +362,24 @@ impl<'data> SafeTensors<'data> {
         tensors
     }
 
+    /// Returns an iterator over the tensors contained within the SafeTensors.
+    /// The tensors returned are merely views and the data is not owned by this
+    /// structure.
+    pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a str, TensorView<'data>)> {
+        self.metadata.index_map.iter().map(|(name, &idx)| {
+            let info = &self.metadata.tensors[idx];
+            (name.as_str(), TensorView {
+                dtype: info.dtype,
+                shape: info.shape.clone(),
+                data: &self.data[info.data_offsets.0..info.data_offsets.1],
+            })
+        })
+    }
+
     /// Allow the user to get a specific tensor within the SafeTensors.
     /// The tensor returned is merely a view and the data is not owned by this
     /// structure.
-    pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'_>, SafeTensorError> {
+    pub fn tensor(&self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
         if let Some(index) = &self.metadata.index_map.get(tensor_name) {
             if let Some(info) = &self.metadata.tensors.get(**index) {
                 Ok(TensorView {
@@ -541,7 +555,7 @@ impl Metadata {
 /// A view of a Tensor within the file.
 /// Contains references to data within the full byte-buffer
 /// And is thus a readable view of a single tensor
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub struct TensorView<'data> {
     dtype: Dtype,
     shape: Vec<usize>,
@@ -1038,6 +1052,21 @@ mod tests {
         assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
     }
 
+    #[test]
+    fn test_lifetimes() {
+        let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
+
+        let tensor = {
+            let loaded = SafeTensors::deserialize(serialized).unwrap();
+            loaded.tensor("test").unwrap()
+        };
+
+        assert_eq!(tensor.shape(), vec![2, 2]);
+        assert_eq!(tensor.dtype(), Dtype::I32);
+        // 16 bytes
+        assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
+    }
+
     #[test]
     fn test_json_attack() {
         let mut tensors = HashMap::new();