From 38a36298c4b66bd845029d0fa4b41ac716f7a9dd Mon Sep 17 00:00:00 2001 From: ZC <390810839@qq.com> Date: Tue, 7 Jan 2025 19:48:53 +0800 Subject: [PATCH] support no_std (#556) * support no_std (#544) * Simpler clippy check (no features in safetensors really). --------- Co-authored-by: Nicolas Patry --- .github/workflows/rust.yml | 5 ++++- safetensors/Cargo.toml | 10 +++++++-- safetensors/src/lib.rs | 40 ++++++++++++++++++++++++++++++++- safetensors/src/slice.rs | 10 ++++----- safetensors/src/tensor.rs | 45 ++++++++++++++++++++++---------------- 5 files changed, 82 insertions(+), 28 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 1c816497..e7ac0684 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -36,11 +36,14 @@ jobs: run: cargo build --all-targets --verbose - name: Lint with Clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy --all-targets -- -D warnings - name: Run Tests run: cargo test --verbose + - name: Run No-STD Tests + run: cargo test --no-default-features --features alloc --verbose + - name: Run Audit # RUSTSEC-2021-0145 is criterion so only within benchmarks run: cargo audit -D warnings --ignore RUSTSEC-2021-0145 diff --git a/safetensors/Cargo.toml b/safetensors/Cargo.toml index d59a6e18..02bcc2f0 100644 --- a/safetensors/Cargo.toml +++ b/safetensors/Cargo.toml @@ -21,14 +21,20 @@ exclude = [ "rust-toolchain", "target/*", "Cargo.lock"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -serde = {version = "1.0", features = ["derive"]} -serde_json = "1.0" +hashbrown = { version = "0.15.2", features = ["serde"], optional = true } +serde = { version = "1.0", default-features = false, features = ["derive"] } +serde_json = { version = "1.0", default-features = false } [dev-dependencies] criterion = "0.5" memmap2 = "0.9" proptest = "1.4" +[features] +default = ["std"] +std = ["serde/default", "serde_json/default"] +alloc = ["serde/alloc", "serde_json/alloc", "hashbrown"] + [[bench]] name = "benchmark" harness = false diff --git a/safetensors/src/lib.rs b/safetensors/src/lib.rs index 48d8d521..4020c9dc 100644 --- a/safetensors/src/lib.rs +++ b/safetensors/src/lib.rs @@ -1,5 +1,43 @@ #![deny(missing_docs)] #![doc = include_str!("../README.md")] +#![cfg_attr(not(feature = "std"), no_std)] pub mod slice; pub mod tensor; -pub use tensor::{serialize, serialize_to_file, Dtype, SafeTensorError, SafeTensors, View}; +/// serialize_to_file only valid in std +#[cfg(feature = "std")] +pub use tensor::serialize_to_file; +pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View}; + +#[cfg(feature = "alloc")] +#[macro_use] +extern crate alloc; + +#[cfg(all(feature = "std", feature = "alloc"))] +compile_error!("must choose either the `std` or `alloc` feature, but not both."); +#[cfg(all(not(feature = "std"), not(feature = "alloc")))] +compile_error!("must choose either the `std` or `alloc` feature"); + +/// A facade around all the types we need from the `std`, `core`, and `alloc` +/// crates. This avoids elaborate import wrangling having to happen in every +/// module. +mod lib { + #[cfg(not(feature = "std"))] + mod no_stds { + pub use alloc::borrow::Cow; + pub use alloc::string::{String, ToString}; + pub use alloc::vec::Vec; + pub use hashbrown::HashMap; + } + #[cfg(feature = "std")] + mod stds { + pub use std::borrow::Cow; + pub use std::collections::HashMap; + pub use std::string::{String, ToString}; + pub use std::vec::Vec; + } + /// choose std or no_std to export by feature flag + #[cfg(not(feature = "std"))] + pub use no_stds::*; + #[cfg(feature = "std")] + pub use stds::*; +} diff --git a/safetensors/src/slice.rs b/safetensors/src/slice.rs index d19b4b59..91087170 100644 --- a/safetensors/src/slice.rs +++ b/safetensors/src/slice.rs @@ -1,7 +1,7 @@ //! Module handling lazy loading via iterating on slices on the original buffer. +use crate::lib::{String, ToString, Vec}; use crate::tensor::TensorView; -use std::fmt; -use std::ops::{ +use core::ops::{ Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, }; @@ -40,8 +40,8 @@ fn display_bound(bound: &Bound) -> String { } /// Intended for Python users mostly or at least for its conventions -impl fmt::Display for TensorIndexer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl core::fmt::Display for TensorIndexer { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { TensorIndexer::Select(n) => { write!(f, "{n}") @@ -77,7 +77,7 @@ macro_rules! impl_from_range { ($range_type:ty) => { impl From<$range_type> for TensorIndexer { fn from(range: $range_type) -> Self { - use std::ops::Bound::*; + use core::ops::Bound::*; let start = match range.start_bound() { Included(idx) => Included(*idx), diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index 596fa367..bee71782 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -1,11 +1,9 @@ //! Module Containing the most important structures +use crate::lib::{Cow, HashMap, String, ToString, Vec}; use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer}; use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::fs::File; -use std::io::{BufWriter, Write}; -use std::path::Path; +#[cfg(feature = "std")] +use std::io::Write; const MAX_HEADER_SIZE: usize = 100_000_000; @@ -32,6 +30,7 @@ pub enum SafeTensorError { /// The offsets declared for tensor with name `String` in the header are invalid InvalidOffset(String), /// IoError + #[cfg(feature = "std")] IoError(std::io::Error), /// JSON error JsonError(serde_json::Error), @@ -46,6 +45,7 @@ pub enum SafeTensorError { ValidationOverflow, } +#[cfg(feature = "std")] impl From for SafeTensorError { fn from(error: std::io::Error) -> SafeTensorError { SafeTensorError::IoError(error) @@ -58,13 +58,13 @@ impl From for SafeTensorError { } } -impl std::fmt::Display for SafeTensorError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl core::fmt::Display for SafeTensorError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{self:?}") } } -impl std::error::Error for SafeTensorError {} +impl core::error::Error for SafeTensorError {} struct PreparedData { n: u64, @@ -164,7 +164,7 @@ pub trait View { fn data_len(&self) -> usize; } -fn prepare + Ord + std::fmt::Display, V: View, I: IntoIterator>( +fn prepare + Ord + core::fmt::Display, V: View, I: IntoIterator>( data: I, data_info: &Option>, // ) -> Result<(Metadata, Vec<&'hash TensorView<'data>>, usize), SafeTensorError> { @@ -212,7 +212,7 @@ fn prepare + Ord + std::fmt::Display, V: View, I: IntoIterator + Ord + std::fmt::Display, + S: AsRef + Ord + core::fmt::Display, V: View, I: IntoIterator, >( @@ -240,14 +240,15 @@ pub fn serialize< /// Serialize to a regular file the dictionnary of tensors. /// Writing directly to file reduces the need to allocate the whole amount to /// memory. +#[cfg(feature = "std")] pub fn serialize_to_file< - S: AsRef + Ord + std::fmt::Display, + S: AsRef + Ord + core::fmt::Display, V: View, I: IntoIterator, >( data: I, data_info: &Option>, - filename: &Path, + filename: &std::path::Path, ) -> Result<(), SafeTensorError> { let ( PreparedData { @@ -255,7 +256,7 @@ pub fn serialize_to_file< }, tensors, ) = prepare(data, data_info)?; - let mut f = BufWriter::new(File::create(filename)?); + let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?); f.write_all(n.to_le_bytes().as_ref())?; f.write_all(&header_bytes)?; for tensor in tensors { @@ -303,7 +304,7 @@ impl<'data> SafeTensors<'data> { return Err(SafeTensorError::InvalidHeaderLength); } let string = - std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?; + core::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?; // Assert the string starts with { // NOTE: Add when we move to 0.4.0 // if !string.starts_with('{') { @@ -719,6 +720,9 @@ mod tests { use super::*; use crate::slice::IndexOp; use proptest::prelude::*; + #[cfg(not(feature = "std"))] + extern crate std; + use std::io::Write; const MAX_DIMENSION: usize = 8; const MAX_SIZE: usize = 8; @@ -1021,10 +1025,13 @@ mod tests { std::fs::remove_file(&filename).unwrap(); // File api - serialize_to_file(&metadata, &None, Path::new(&filename)).unwrap(); - let raw = std::fs::read(&filename).unwrap(); - let _deserialized = SafeTensors::deserialize(&raw).unwrap(); - std::fs::remove_file(&filename).unwrap(); + #[cfg(feature = "std")] + { + serialize_to_file(&metadata, &None, std::path::Path::new(&filename)).unwrap(); + let raw = std::fs::read(&filename).unwrap(); + let _deserialized = SafeTensors::deserialize(&raw).unwrap(); + std::fs::remove_file(&filename).unwrap(); + } } #[test] @@ -1097,7 +1104,7 @@ mod tests { let n = serialized.len(); let filename = "out.safetensors"; - let mut f = BufWriter::new(File::create(filename).unwrap()); + let mut f = std::io::BufWriter::new(std::fs::File::create(filename).unwrap()); f.write_all(n.to_le_bytes().as_ref()).unwrap(); f.write_all(serialized).unwrap(); f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();