diff --git a/src/lib.rs b/src/lib.rs index 10189a41e..25959e011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1928,9 +1928,9 @@ pub trait Itertools: Iterator { /// /// assert_eq!(Some([1, 2]), iter.next_array()); /// ``` - fn next_array(&mut self) -> Option<[T; N]> + fn next_array(&mut self) -> Option<[Self::Item; N]> where - Self: Sized + Iterator, + Self: Sized, { next_array::next_array(self) } @@ -1952,9 +1952,9 @@ pub trait Itertools: Iterator { /// panic!("Expected two elements") /// } /// ``` - fn collect_array(mut self) -> Option<[T; N]> + fn collect_array(mut self) -> Option<[Self::Item; N]> where - Self: Sized + Iterator, + Self: Sized, { self.next_array().filter(|_| self.next().is_none()) } diff --git a/src/next_array.rs b/src/next_array.rs index e9747e52c..5e3a30928 100644 --- a/src/next_array.rs +++ b/src/next_array.rs @@ -1,5 +1,4 @@ use core::mem::{self, MaybeUninit}; -use core::ptr; /// An array of at most `N` elements. struct ArrayBuilder { @@ -17,7 +16,7 @@ struct ArrayBuilder { impl ArrayBuilder { /// Initializes a new, empty `ArrayBuilder`. pub fn new() -> Self { - // SAFETY: the validity invariant trivially hold for a zero-length array. + // SAFETY: The safety invariant of `arr` trivially holds for `len = 0`. Self { arr: [(); N].map(|_| MaybeUninit::uninit()), len: 0, @@ -28,50 +27,101 @@ impl ArrayBuilder { /// /// # Panics /// - /// This panics if `self.len() >= N`. + /// This panics if `self.len >= N`. + #[inline(always)] pub fn push(&mut self, value: T) { - // SAFETY: we maintain the invariant here that arr[..len] is valid. - // Indexing with self.len also ensures self.len < N, and thus <= N after - // the increment. - self.arr[self.len] = MaybeUninit::new(value); + // PANICS: This will panic if `self.len >= N`. + let place = &mut self.arr[self.len]; + // SAFETY: The safety invariant of `self.arr` applies to elements at + // indices `0..self.len` — not to the element at `self.len`. Writing to + // the element at index `self.len` therefore does not violate the safety + // invariant of `self.arr`. Even if this line panics, we have not + // created any intermediate invalid state. + *place = MaybeUninit::new(value); + // PANICS: This cannot panic, since `self.len < N <= usize::MAX`. + // `0..self.len` are valid. Due to the above write, the element at + // `self.len` is now also valid. Consequently, all elements at indicies + // `0..(self.len + 1)` are valid, and `self.len` can be safely + // incremented without violating `self.arr`'s invariant. It is fine if + // this increment panics, as we have not created any intermediate + // invalid state. self.len += 1; } - /// Consumes the elements in the `ArrayBuilder` and returns them as an array `[T; N]`. + /// Consumes the elements in the `ArrayBuilder` and returns them as an array + /// `[T; N]`. /// /// If `self.len() < N`, this returns `None`. pub fn take(&mut self) -> Option<[T; N]> { if self.len == N { - // Take the array, resetting our length back to zero. + // SAFETY: Decreasing the value of `self.len` cannot violate the + // safety invariant on `self.arr`. self.len = 0; + + // SAFETY: Since `self.len` is 0, `self.arr` may safely contain + // uninitialized elements. let arr = mem::replace(&mut self.arr, [(); N].map(|_| MaybeUninit::uninit())); - // SAFETY: we had len == N, so all elements in arr are valid. - Some(unsafe { arr.map(|v| v.assume_init()) }) + Some(arr.map(|v| { + // SAFETY: We know that all elements of `arr` are valid because + // we checked that `len == N`. + unsafe { v.assume_init() } + })) } else { None } } } +impl AsMut<[T]> for ArrayBuilder { + fn as_mut(&mut self) -> &mut [T] { + let valid = &mut self.arr[..self.len]; + // SAFETY: By invariant on `self.arr`, the elements of `self.arr` at + // indices `0..self.len` are in a valid state. Since `valid` references + // only these elements, the safety precondition of + // `slice_assume_init_mut` is satisfied. + unsafe { slice_assume_init_mut(valid) } + } +} + impl Drop for ArrayBuilder { + // We provide a non-trivial `Drop` impl, because the trivial impl would be a + // no-op; `MaybeUninit` has no innate awareness of its own validity, and + // so it can only forget its contents. By leveraging the safety invariant of + // `self.arr`, we do know which elements of `self.arr` are valid, and can + // selectively run their destructors. fn drop(&mut self) { - unsafe { - // SAFETY: arr[..len] is valid, so must be dropped. First we create - // a pointer to this valid slice, then drop that slice in-place. - // The cast from *mut MaybeUninit to *mut T is always sound by - // the layout guarantees of MaybeUninit. - let ptr_to_first: *mut MaybeUninit = self.arr.as_mut_ptr(); - let ptr_to_slice = ptr::slice_from_raw_parts_mut(ptr_to_first.cast::(), self.len); - ptr::drop_in_place(ptr_to_slice); - } + // SAFETY: + // - by invariant on `&[T]`, `self.as_mut()` is: + // - valid for reads and writes + // - properly aligned + // - non-null + // - the dropped `T` are valid for dropping; they do not have any + // additional library invariants that we've violated + // - no other pointers to `valid` exist (since we're in the context of + // `drop`) + unsafe { core::ptr::drop_in_place(self.as_mut()) } } } +/// Assuming all the elements are initialized, get a mutable slice to them. +/// +/// # Safety +/// +/// The caller guarantees that the elements `T` referenced by `slice` are in a +/// valid state. +unsafe fn slice_assume_init_mut(slice: &mut [MaybeUninit]) -> &mut [T] { + // SAFETY: Casting `&mut [MaybeUninit]` to `&mut [T]` is sound, because + // `MaybeUninit` is guaranteed to have the same size, alignment and ABI + // as `T`, and because the caller has guaranteed that `slice` is in the + // valid state. + unsafe { &mut *(slice as *mut [MaybeUninit] as *mut [T]) } +} + /// Equivalent to `it.next_array()`. -pub fn next_array(it: &mut I) -> Option<[T; N]> +pub(crate) fn next_array(it: &mut I) -> Option<[I::Item; N]> where - I: Iterator, + I: Iterator, { let mut builder = ArrayBuilder::new(); for _ in 0..N { @@ -79,3 +129,126 @@ where } builder.take() } + +#[cfg(test)] +mod test { + use super::ArrayBuilder; + + #[test] + fn zero_len_take() { + let mut builder = ArrayBuilder::<(), 0>::new(); + let taken = builder.take(); + assert_eq!(taken, Some([(); 0])); + } + + #[test] + #[should_panic] + fn zero_len_push() { + let mut builder = ArrayBuilder::<(), 0>::new(); + builder.push(()); + } + + #[test] + fn push_4() { + let mut builder = ArrayBuilder::<(), 4>::new(); + assert_eq!(builder.take(), None); + + builder.push(()); + assert_eq!(builder.take(), None); + + builder.push(()); + assert_eq!(builder.take(), None); + + builder.push(()); + assert_eq!(builder.take(), None); + + builder.push(()); + assert_eq!(builder.take(), Some([(); 4])); + } + + #[test] + fn tracked_drop() { + use std::panic::{catch_unwind, AssertUnwindSafe}; + use std::sync::atomic::{AtomicU16, Ordering}; + + static DROPPED: AtomicU16 = AtomicU16::new(0); + + #[derive(Debug, PartialEq)] + struct TrackedDrop; + + impl Drop for TrackedDrop { + fn drop(&mut self) { + DROPPED.fetch_add(1, Ordering::Relaxed); + } + } + + { + let builder = ArrayBuilder::::new(); + assert_eq!(DROPPED.load(Ordering::Relaxed), 0); + drop(builder); + assert_eq!(DROPPED.load(Ordering::Relaxed), 0); + } + + { + let mut builder = ArrayBuilder::::new(); + builder.push(TrackedDrop); + assert_eq!(builder.take(), None); + assert_eq!(DROPPED.load(Ordering::Relaxed), 0); + drop(builder); + assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 1); + } + + { + let mut builder = ArrayBuilder::::new(); + builder.push(TrackedDrop); + builder.push(TrackedDrop); + assert!(matches!(builder.take(), Some(_))); + assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 2); + drop(builder); + assert_eq!(DROPPED.load(Ordering::Relaxed), 0); + } + + { + let mut builder = ArrayBuilder::::new(); + + builder.push(TrackedDrop); + builder.push(TrackedDrop); + + assert!(catch_unwind(AssertUnwindSafe(|| { + builder.push(TrackedDrop); + })) + .is_err()); + + assert_eq!(DROPPED.load(Ordering::Relaxed), 1); + + drop(builder); + + assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 3); + } + + { + let mut builder = ArrayBuilder::::new(); + + builder.push(TrackedDrop); + builder.push(TrackedDrop); + + assert!(catch_unwind(AssertUnwindSafe(|| { + builder.push(TrackedDrop); + })) + .is_err()); + + assert_eq!(DROPPED.load(Ordering::Relaxed), 1); + + assert!(matches!(builder.take(), Some(_))); + + assert_eq!(DROPPED.load(Ordering::Relaxed), 3); + + builder.push(TrackedDrop); + builder.push(TrackedDrop); + + assert!(matches!(builder.take(), Some(_))); + + assert_eq!(DROPPED.swap(0, Ordering::Relaxed), 5); + } + } +} diff --git a/tests/test_core.rs b/tests/test_core.rs index f98790b71..493616085 100644 --- a/tests/test_core.rs +++ b/tests/test_core.rs @@ -380,7 +380,7 @@ fn next_array() { assert_eq!(iter.next_array(), Some([])); assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([1, 2])); assert_eq!(iter.next_array().map(|[&x, &y]| [x, y]), Some([3, 4])); - assert_eq!(iter.next_array::<_, 2>(), None); + assert_eq!(iter.next_array::<2>(), None); } #[test] @@ -391,9 +391,9 @@ fn collect_array() { let v = [1]; let iter = v.iter().cloned(); - assert_eq!(iter.collect_array::<_, 2>(), None); + assert_eq!(iter.collect_array::<2>(), None); let v = [1, 2, 3]; let iter = v.iter().cloned(); - assert_eq!(iter.collect_array::<_, 2>(), None); + assert_eq!(iter.collect_array::<2>(), None); }