Skip to content

Commit

Permalink
Adding support for integer indexing [0, :2, -1].
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Feb 14, 2024
1 parent 08db340 commit ed5db3c
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 36 deletions.
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ memmap2 = "0.5"
serde_json = "1.0"

[dependencies.safetensors]
version = "0.4.2-dev.0"
version = "0.4.3-dev.0"
path = "../../safetensors"
81 changes: 53 additions & 28 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,41 @@ fn deserialize(py: Python, bytes: &[u8]) -> PyResult<Vec<(String, HashMap<String
Ok(items)
}

fn slice_to_indexer(slice: &PySlice) -> Result<TensorIndexer, PyErr> {
let py_start = slice.getattr(intern!(slice.py(), "start"))?;
let start: Option<usize> = py_start.extract()?;
let start = if let Some(start) = start {
Bound::Included(start)
} else {
Bound::Unbounded
};

let py_stop = slice.getattr(intern!(slice.py(), "stop"))?;
let stop: Option<usize> = py_stop.extract()?;
let stop = if let Some(stop) = stop {
Bound::Excluded(stop)
} else {
Bound::Unbounded
};

Ok(TensorIndexer::Narrow(start, stop))
/// Converts one Python-side index (a `slice` object or a plain integer) into
/// the [`TensorIndexer`] understood by the core slicing code.
///
/// The argument is the item produced by `slices.zip(shape).enumerate()`:
/// * `dim_idx` - position of the dimension being indexed (only used in error messages),
/// * `slice_index` - the Python index for that dimension,
/// * `dim` - size of that dimension.
///
/// Slices: only `start` and `stop` are read (`step` is never consulted), and both
/// are extracted as `Option<usize>`, so a bound that does not fit in a `usize`
/// makes the extraction fail and the resulting `PyErr` is propagated.
///
/// Integers: a negative value counts from the end of the dimension
/// (`-1` is the last element). Any index that does not land inside `0..dim`
/// yields a `SafetensorError` of the form
/// `Invalid index {idx} for dimension {dim_idx} of size {dim}`.
fn slice_to_indexer(
    (dim_idx, (slice_index, dim)): (usize, (SliceIndex, usize)),
) -> Result<TensorIndexer, PyErr> {
    match slice_index {
        SliceIndex::Slice(slice) => {
            let py_start = slice.getattr(intern!(slice.py(), "start"))?;
            let start = py_start
                .extract::<Option<usize>>()?
                .map_or(Bound::Unbounded, Bound::Included);

            let py_stop = slice.getattr(intern!(slice.py(), "stop"))?;
            let stop = py_stop
                .extract::<Option<usize>>()?
                .map_or(Bound::Unbounded, Bound::Excluded);

            Ok(TensorIndexer::Narrow(start, stop))
        }
        SliceIndex::Index(idx) => {
            let resolved = if idx < 0 {
                // Widening i32 -> isize; `checked_add_signed` returns `None`
                // when the negative offset reaches before the start of the dimension.
                dim.checked_add_signed(idx as isize)
            } else {
                // `idx` is non-negative here, so the cast cannot wrap.
                Some(idx as usize)
            };
            resolved
                // Reject positions past the end of the dimension as well.
                .filter(|&position| position < dim)
                .map(TensorIndexer::Select)
                // Build the error message lazily, only on the failure path.
                .ok_or_else(|| {
                    SafetensorError::new_err(format!(
                        "Invalid index {idx} for dimension {dim_idx} of size {dim}"
                    ))
                })
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -730,10 +747,15 @@ struct PySafeSlice {
}

#[derive(FromPyObject)]
enum Slice<'a> {
// Index(usize),
enum SliceIndex<'a> {
Slice(&'a PySlice),
Slices(Vec<&'a PySlice>),
Index(i32),
}

#[derive(FromPyObject)]
/// Argument accepted by `PySafeSlice::__getitem__`, extracted from the raw
/// Python object: either one indexer or a sequence of them.
// NOTE(review): the `FromPyObject` derive presumably tries the variants in
// declaration order, so a lone slice/integer matches `Slice` before `Slices`
// is attempted — confirm against the PyO3 documentation.
enum Slice<'a> {
/// A single indexer (e.g. `t[:2]`); `__getitem__` wraps it in a one-element vector.
Slice(SliceIndex<'a>),
/// Several indexers (e.g. `t[0, :2]`), paired in order with the tensor's dimensions.
Slices(Vec<SliceIndex<'a>>),
}

#[pymethods]
Expand Down Expand Up @@ -780,23 +802,27 @@ impl PySafeSlice {
Ok(dtype)
}

pub fn __getitem__(&self, slices: Slice) -> PyResult<PyObject> {
let slices: Vec<&PySlice> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
};

pub fn __getitem__(&self, slices: &PyAny) -> PyResult<PyObject> {
match &self.storage.as_ref() {
Storage::Mmap(mmap) => {
let slices: Slice = slices.extract()?;
let slices: Vec<SliceIndex> = match slices {
Slice::Slice(slice) => vec![slice],
Slice::Slices(slices) => slices,
};
let data = &mmap[self.info.data_offsets.0 + self.offset
..self.info.data_offsets.1 + self.offset];

let shape = self.info.shape.clone();

let tensor = TensorView::new(self.info.dtype, self.info.shape.clone(), data)
.map_err(|e| {
SafetensorError::new_err(format!("Error preparing tensor view: {e:?}"))
})?;
let slices: Vec<TensorIndexer> = slices
.into_iter()
.zip(shape)
.enumerate()
.map(slice_to_indexer)
.collect::<Result<_, _>>()?;

Expand All @@ -810,7 +836,6 @@ impl PySafeSlice {

let mut offset = 0;
let length = iterator.remaining_byte_len();

Python::with_gil(|py| {
let array: PyObject =
PyByteArray::new_with(py, length, |bytes: &mut [u8]| {
Expand Down
84 changes: 84 additions & 0 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,87 @@ def test_exception(self):

with self.assertRaises(SafetensorError):
serialize(flattened)

def test_torch_slice(self):
    """Save a (10, 5) torch tensor and check lazy slices against eager indexing."""
    A = torch.randn((10, 5))
    save_file_pt({"a": A}, "./slice.safetensors")

    # Each selector is applied both to the lazy slice object and to the
    # in-memory reference tensor; the second item is the expected shape.
    selectors = [
        (lambda t: t[:], [10, 5]),
        (lambda t: t[:2], [2, 5]),
        (lambda t: t[:, :2], [10, 2]),
        (lambda t: t[0, :2], [2]),
        (lambda t: t[2:, 0], [8]),
        (lambda t: t[2:, 1], [8]),
        (lambda t: t[2:, -1], [8]),
    ]

    # Now loading
    with safe_open("./slice.safetensors", framework="pt", device="cpu") as f:
        slice_ = f.get_slice("a")
        for select, expected_shape in selectors:
            tensor = select(slice_)
            self.assertEqual(list(tensor.shape), expected_shape)
            torch.testing.assert_close(tensor, select(A))

def test_numpy_slice(self):
    """Save a (10, 5) numpy array and check lazy slices against eager indexing."""
    A = np.random.rand(10, 5)
    save_file({"a": A}, "./slice.safetensors")

    # Each selector is applied both to the lazy slice object and to the
    # in-memory reference array; the second item is the expected shape.
    selectors = [
        (lambda t: t[:], [10, 5]),
        (lambda t: t[:2], [2, 5]),
        (lambda t: t[:, :2], [10, 2]),
        (lambda t: t[0, :2], [2]),
        (lambda t: t[2:, 0], [8]),
        (lambda t: t[2:, 1], [8]),
        (lambda t: t[2:, -1], [8]),
        (lambda t: t[2:, -5], [8]),
    ]

    # Now loading
    with safe_open("./slice.safetensors", framework="np", device="cpu") as f:
        slice_ = f.get_slice("a")
        for select, expected_shape in selectors:
            tensor = select(slice_)
            self.assertEqual(list(tensor.shape), expected_shape)
            self.assertTrue(np.allclose(tensor, select(A)))

        # One step past the most negative valid index must be rejected.
        with self.assertRaises(SafetensorError) as cm:
            tensor = slice_[2:, -6]
        self.assertEqual(str(cm.exception), "Invalid index -6 for dimension 1 of size 5")
61 changes: 54 additions & 7 deletions safetensors/src/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ pub enum InvalidSlice {
#[derive(Debug, Clone)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// Picks the single position given along one dimension. Unlike `Narrow`,
/// the indexed dimension is not added to the resulting shape, so the
/// output has one dimension fewer.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor.
/// The two bounds are the start and the stop of the range; the dimension
/// is kept in the resulting shape with length `stop - start`.
Narrow(Bound<usize>, Bound<usize>),
//IndexSelect(Tensor),
}

// impl From<usize> for TensorIndexer {
// fn from(index: usize) -> Self {
// TensorIndexer::Select(index)
// }
// }
impl From<usize> for TensorIndexer {
fn from(index: usize) -> Self {
TensorIndexer::Select(index)
}
}

// impl From<&[usize]> for TensorIndexer {
// fn from(index: &[usize]) -> Self {
Expand Down Expand Up @@ -249,8 +250,11 @@ impl<'data> SliceIterator<'data> {
TensorIndexer::Narrow(Bound::Excluded(s), Bound::Included(stop)) => {
(*s + 1, *stop + 1)
}
TensorIndexer::Select(s) => (*s, *s + 1),
};
newshape.push(stop - start);
if let TensorIndexer::Narrow(..) = slice {
newshape.push(stop - start);
}
if indices.is_empty() {
if start == 0 && stop == shape {
// We haven't started to slice yet, just increase the span
Expand Down Expand Up @@ -487,4 +491,47 @@ mod tests {
assert_eq!(iterator.next(), Some(&data[16..24]));
assert_eq!(iterator.next(), None);
}

/// Exercises `TensorIndexer::Select` on a 2x3 `f32` tensor, both before and
/// after a `Narrow`, and checks the exact byte range yielded by the iterator.
#[test]
fn test_slice_select() {
    let data: Vec<u8> = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
        .iter()
        .flat_map(|value| value.to_le_bytes())
        .collect();

    let attn_0 = TensorView::new(Dtype::F32, vec![2, 3], &data).unwrap();

    // (indexers for the two dimensions, byte range of the single expected chunk)
    let cases = vec![
        (
            vec![
                TensorIndexer::Select(1),
                TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)),
            ],
            16..24,
        ),
        (
            vec![
                TensorIndexer::Select(0),
                TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(3)),
            ],
            4..12,
        ),
        (
            vec![
                TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(2)),
                TensorIndexer::Select(0),
            ],
            12..16,
        ),
    ];

    for (indexers, expected) in cases {
        let mut iterator = SliceIterator::new(&attn_0, &indexers[..]).unwrap();
        assert_eq!(iterator.next(), Some(&data[expected]));
        assert_eq!(iterator.next(), None);
    }
}
}

0 comments on commit ed5db3c

Please sign in to comment.