Commit
support no_std (#556)
* support no_std (#544)

* Simpler clippy check (no features in safetensors really).

---------

Co-authored-by: Nicolas Patry <[email protected]>
ivila and Narsil authored Jan 7, 2025
1 parent a481b07 commit 38a3629
Showing 5 changed files with 82 additions and 28 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/rust.yml
@@ -36,11 +36,14 @@ jobs:
run: cargo build --all-targets --verbose

- name: Lint with Clippy
run: cargo clippy --all-targets --all-features -- -D warnings
run: cargo clippy --all-targets -- -D warnings

- name: Run Tests
run: cargo test --verbose

- name: Run No-STD Tests
run: cargo test --no-default-features --features alloc --verbose

- name: Run Audit
# RUSTSEC-2021-0145 is criterion so only within benchmarks
run: cargo audit -D warnings --ignore RUSTSEC-2021-0145
10 changes: 8 additions & 2 deletions safetensors/Cargo.toml
@@ -21,14 +21,20 @@ exclude = [ "rust-toolchain", "target/*", "Cargo.lock"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
serde = {version = "1.0", features = ["derive"]}
serde_json = "1.0"
hashbrown = { version = "0.15.2", features = ["serde"], optional = true }
serde = { version = "1.0", default-features = false, features = ["derive"] }
serde_json = { version = "1.0", default-features = false }

[dev-dependencies]
criterion = "0.5"
memmap2 = "0.9"
proptest = "1.4"

[features]
default = ["std"]
std = ["serde/default", "serde_json/default"]
alloc = ["serde/alloc", "serde_json/alloc", "hashbrown"]

[[bench]]
name = "benchmark"
harness = false
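
For context, a minimal sketch of what the new feature split enables — a hypothetical downstream crate, not part of this commit, assuming it depends on safetensors with `default-features = false, features = ["alloc"]`:

```rust
// Hypothetical no_std consumer of safetensors built with the `alloc` feature.
#![no_std]
extern crate alloc;

use alloc::string::{String, ToString};
use alloc::vec::Vec;
use safetensors::{SafeTensorError, SafeTensors};

/// List tensor names from a safetensors blob that already lives in memory
/// (e.g. embedded in a firmware image), without touching any file APIs.
pub fn tensor_names(buffer: &[u8]) -> Result<Vec<String>, SafeTensorError> {
    let tensors = SafeTensors::deserialize(buffer)?;
    Ok(tensors.names().into_iter().map(|n| n.to_string()).collect())
}
```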
40 changes: 39 additions & 1 deletion safetensors/src/lib.rs
@@ -1,5 +1,43 @@
#![deny(missing_docs)]
#![doc = include_str!("../README.md")]
#![cfg_attr(not(feature = "std"), no_std)]
pub mod slice;
pub mod tensor;
pub use tensor::{serialize, serialize_to_file, Dtype, SafeTensorError, SafeTensors, View};
/// serialize_to_file is only available with the `std` feature
#[cfg(feature = "std")]
pub use tensor::serialize_to_file;
pub use tensor::{serialize, Dtype, SafeTensorError, SafeTensors, View};

#[cfg(feature = "alloc")]
#[macro_use]
extern crate alloc;

#[cfg(all(feature = "std", feature = "alloc"))]
compile_error!("must choose either the `std` or `alloc` feature, but not both.");
#[cfg(all(not(feature = "std"), not(feature = "alloc")))]
compile_error!("must choose either the `std` or `alloc` feature");

/// A facade around all the types we need from the `std`, `core`, and `alloc`
/// crates. This avoids elaborate import wrangling having to happen in every
/// module.
mod lib {
#[cfg(not(feature = "std"))]
mod no_stds {
pub use alloc::borrow::Cow;
pub use alloc::string::{String, ToString};
pub use alloc::vec::Vec;
pub use hashbrown::HashMap;
}
#[cfg(feature = "std")]
mod stds {
pub use std::borrow::Cow;
pub use std::collections::HashMap;
pub use std::string::{String, ToString};
pub use std::vec::Vec;
}
/// Choose the std or no_std exports based on the feature flag
#[cfg(not(feature = "std"))]
pub use no_stds::*;
#[cfg(feature = "std")]
pub use stds::*;
}
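
The `lib` facade above mirrors a pattern used by other no_std-capable crates; a self-contained sketch of the same idea with made-up names (it assumes its own `std` feature and a `hashbrown` dependency, like the `Cargo.toml` change above):

```rust
#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(not(feature = "std"))]
extern crate alloc;

/// Facade in the style of the `lib` module above: one place decides where
/// collection types come from, the rest of the crate stays agnostic.
mod facade {
    #[cfg(feature = "std")]
    pub use std::{collections::HashMap, string::String, vec::Vec};
    #[cfg(not(feature = "std"))]
    pub use alloc::{string::String, vec::Vec};
    #[cfg(not(feature = "std"))]
    pub use hashbrown::HashMap;
}

use crate::facade::{HashMap, String, Vec};

/// The same body compiles whether `HashMap` is std's or hashbrown's.
fn index_by_name(names: Vec<String>) -> HashMap<String, usize> {
    names.into_iter().enumerate().map(|(i, n)| (n, i)).collect()
}
```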
10 changes: 5 additions & 5 deletions safetensors/src/slice.rs
@@ -1,7 +1,7 @@
//! Module handling lazy loading via iterating on slices on the original buffer.
use crate::lib::{String, ToString, Vec};
use crate::tensor::TensorView;
use std::fmt;
use std::ops::{
use core::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};

@@ -40,8 +40,8 @@ fn display_bound(bound: &Bound<usize>) -> String {
}

/// Intended for Python users mostly or at least for its conventions
impl fmt::Display for TensorIndexer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl core::fmt::Display for TensorIndexer {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
TensorIndexer::Select(n) => {
write!(f, "{n}")
@@ -77,7 +77,7 @@ macro_rules! impl_from_range {
($range_type:ty) => {
impl From<$range_type> for TensorIndexer {
fn from(range: $range_type) -> Self {
use std::ops::Bound::*;
use core::ops::Bound::*;

let start = match range.start_bound() {
Included(idx) => Included(*idx),
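
The `std::fmt`/`std::ops` to `core::fmt`/`core::ops` swaps above are mechanical, since `std` re-exports these modules from `core`; a tiny self-contained illustration with a made-up type:

```rust
use core::fmt;

/// Made-up type: `Display` implemented through `core::fmt` behaves exactly
/// like the `std::fmt` version, but also compiles without the standard library.
struct Dim(usize);

impl fmt::Display for Dim {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "dim={}", self.0)
    }
}
```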
45 changes: 26 additions & 19 deletions safetensors/src/tensor.rs
@@ -1,11 +1,9 @@
//! Module Containing the most important structures
use crate::lib::{Cow, HashMap, String, ToString, Vec};
use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Cow;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
#[cfg(feature = "std")]
use std::io::Write;

const MAX_HEADER_SIZE: usize = 100_000_000;

@@ -32,6 +30,7 @@ pub enum SafeTensorError {
/// The offsets declared for tensor with name `String` in the header are invalid
InvalidOffset(String),
/// IoError
#[cfg(feature = "std")]
IoError(std::io::Error),
/// JSON error
JsonError(serde_json::Error),
@@ -46,6 +45,7 @@ pub enum SafeTensorError {
ValidationOverflow,
}

#[cfg(feature = "std")]
impl From<std::io::Error> for SafeTensorError {
fn from(error: std::io::Error) -> SafeTensorError {
SafeTensorError::IoError(error)
@@ -58,13 +58,13 @@ impl From<serde_json::Error> for SafeTensorError {
}
}

impl std::fmt::Display for SafeTensorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl core::fmt::Display for SafeTensorError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{self:?}")
}
}

impl std::error::Error for SafeTensorError {}
impl core::error::Error for SafeTensorError {}

struct PreparedData {
n: u64,
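
A sketch of what the `cfg`-gated `IoError` variant means for downstream code — a hypothetical caller that forwards its own `std` feature to `safetensors/std`; the I/O arm simply does not exist in a no_std build:

```rust
use safetensors::SafeTensorError;

/// Hypothetical helper mapping errors to static messages; the IoError arm is
/// only compiled when the `std` feature is enabled.
fn describe(err: &SafeTensorError) -> &'static str {
    match err {
        #[cfg(feature = "std")]
        SafeTensorError::IoError(_) => "I/O failure",
        SafeTensorError::JsonError(_) => "invalid JSON in header",
        SafeTensorError::InvalidHeader => "malformed header",
        _ => "other safetensors error",
    }
}
```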
@@ -164,7 +164,7 @@ pub trait View {
fn data_len(&self) -> usize;
}

fn prepare<S: AsRef<str> + Ord + std::fmt::Display, V: View, I: IntoIterator<Item = (S, V)>>(
fn prepare<S: AsRef<str> + Ord + core::fmt::Display, V: View, I: IntoIterator<Item = (S, V)>>(
data: I,
data_info: &Option<HashMap<String, String>>,
// ) -> Result<(Metadata, Vec<&'hash TensorView<'data>>, usize), SafeTensorError> {
@@ -212,7 +212,7 @@ fn prepare<S: AsRef<str> + Ord + std::fmt::Display, V: View, I: IntoIterator<Ite

/// Serialize to an owned byte buffer the dictionary of tensors.
pub fn serialize<
S: AsRef<str> + Ord + std::fmt::Display,
S: AsRef<str> + Ord + core::fmt::Display,
V: View,
I: IntoIterator<Item = (S, V)>,
>(
@@ -240,22 +240,23 @@ pub fn serialize<
/// Serialize to a regular file the dictionary of tensors.
/// Writing directly to file reduces the need to allocate the whole amount to
/// memory.
#[cfg(feature = "std")]
pub fn serialize_to_file<
S: AsRef<str> + Ord + std::fmt::Display,
S: AsRef<str> + Ord + core::fmt::Display,
V: View,
I: IntoIterator<Item = (S, V)>,
>(
data: I,
data_info: &Option<HashMap<String, String>>,
filename: &Path,
filename: &std::path::Path,
) -> Result<(), SafeTensorError> {
let (
PreparedData {
n, header_bytes, ..
},
tensors,
) = prepare(data, data_info)?;
let mut f = BufWriter::new(File::create(filename)?);
let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?);
f.write_all(n.to_le_bytes().as_ref())?;
f.write_all(&header_bytes)?;
for tensor in tensors {
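
A rough sketch contrasting the two entry points after this change (tensor contents and file path are made up; shown for a build with default features): `serialize` builds the byte buffer in memory and works under `std` or `alloc`, while `serialize_to_file` is only compiled with `std`.

```rust
use safetensors::tensor::TensorView;
use safetensors::{serialize, Dtype, SafeTensorError};

fn save() -> Result<(), SafeTensorError> {
    // A 2x2 f32 tensor as little-endian bytes.
    let data: Vec<u8> = [0.0f32, 1.0, 2.0, 3.0]
        .iter()
        .flat_map(|x| x.to_le_bytes())
        .collect();
    let view = TensorView::new(Dtype::F32, vec![2, 2], &data)?;

    // Available with `std` or `alloc`: serialize into an owned buffer.
    let _bytes: Vec<u8> = serialize([("weight", &view)], &None)?;

    // Only with the `std` feature: stream straight to disk, avoiding the
    // intermediate allocation.
    #[cfg(feature = "std")]
    safetensors::serialize_to_file(
        [("weight", &view)],
        &None,
        std::path::Path::new("model.safetensors"),
    )?;

    Ok(())
}
```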
@@ -303,7 +304,7 @@ impl<'data> SafeTensors<'data> {
return Err(SafeTensorError::InvalidHeaderLength);
}
let string =
std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
core::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?;
// Assert the string starts with {
// NOTE: Add when we move to 0.4.0
// if !string.starts_with('{') {
@@ -719,6 +720,9 @@ mod tests {
use super::*;
use crate::slice::IndexOp;
use proptest::prelude::*;
#[cfg(not(feature = "std"))]
extern crate std;
use std::io::Write;

const MAX_DIMENSION: usize = 8;
const MAX_SIZE: usize = 8;
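
The test-only `extern crate std;` above is the standard trick for keeping the test suite on full std even when the library itself is built without it; an illustrative stand-alone pattern (not this crate's code):

```rust
#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(test)]
mod tests {
    // Tests always run on a host where std exists, so a no_std library can
    // opt back in explicitly just for its test module.
    #[cfg(not(feature = "std"))]
    extern crate std;

    #[test]
    fn std_is_available_to_tests() {
        let path = std::path::PathBuf::from("only-used-in-tests.tmp");
        assert!(path.to_str().is_some());
    }
}
```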
@@ -1021,10 +1025,13 @@ mod tests {
std::fs::remove_file(&filename).unwrap();

// File api
serialize_to_file(&metadata, &None, Path::new(&filename)).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
#[cfg(feature = "std")]
{
serialize_to_file(&metadata, &None, std::path::Path::new(&filename)).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
}
}

#[test]
@@ -1097,7 +1104,7 @@ mod tests {
let n = serialized.len();

let filename = "out.safetensors";
let mut f = BufWriter::new(File::create(filename).unwrap());
let mut f = std::io::BufWriter::new(std::fs::File::create(filename).unwrap());
f.write_all(n.to_le_bytes().as_ref()).unwrap();
f.write_all(serialized).unwrap();
f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();
