From 340cb0b0bc3600b741bc265105db128eed3a7648 Mon Sep 17 00:00:00 2001 From: Timo Betcke Date: Wed, 11 Dec 2024 09:26:32 +0000 Subject: [PATCH] Added new references function --- src/dense/array.rs | 1 + src/dense/array/reference.rs | 333 +++++++++++++++++++++++++++++++++++ src/dense/array/views.rs | 2 + 3 files changed, 336 insertions(+) create mode 100644 src/dense/array/reference.rs diff --git a/src/dense/array.rs b/src/dense/array.rs index e4935f1..74964d0 100644 --- a/src/dense/array.rs +++ b/src/dense/array.rs @@ -24,6 +24,7 @@ pub mod operations; pub mod operators; pub mod random; pub mod rank1_array; +pub mod reference; pub mod slice; pub mod views; diff --git a/src/dense/array/reference.rs b/src/dense/array/reference.rs new file mode 100644 index 0000000..dc69e40 --- /dev/null +++ b/src/dense/array/reference.rs @@ -0,0 +1,333 @@ +//! Reference to an array. +//! +//! A reference is an owned struct that holds a reference to an array. It is used to +//! pass arrays to functions without transferring ownership. + +use crate::dense::types::RlstBase; + +use crate::dense::array::Array; + +use crate::dense::traits::{ + ChunkedAccess, RawAccess, RawAccessMut, ResizeInPlace, Shape, Stride, UnsafeRandomAccessByRef, + UnsafeRandomAccessByValue, UnsafeRandomAccessMut, +}; + +/// Basic structure for a `View` +pub struct ArrayRef< + 'a, + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, +> { + arr: &'a Array, +} + +impl< + 'a, + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, + > ArrayRef<'a, Item, ArrayImpl, NDIM> +{ + /// Create new view + pub fn new(arr: &'a Array) -> Self { + Self { arr } + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, + > Shape for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + fn shape(&self) -> [usize; NDIM] { + self.arr.shape() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape + Stride, + const NDIM: usize, + > Stride for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + fn stride(&self) -> [usize; NDIM] { + self.arr.stride() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + RawAccess + + Stride, + const NDIM: usize, + > RawAccess for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + + fn data(&self) -> &[Self::Item] { + self.arr.data() + } + + fn buff_ptr(&self) -> *const Self::Item { + self.arr.buff_ptr() + } + + fn offset(&self) -> usize { + self.arr.offset() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, + > UnsafeRandomAccessByValue for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + #[inline] + unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item { + self.arr.get_value_unchecked(multi_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessByRef, + const NDIM: usize, + > UnsafeRandomAccessByRef for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + #[inline] + unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item { + self.arr.get_unchecked(multi_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape + ChunkedAccess, + const NDIM: usize, + const N: usize, + > ChunkedAccess for ArrayRef<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + fn get_chunk( + &self, + chunk_index: usize, + ) -> Option> { + self.arr.get_chunk(chunk_index) + } +} + +/////////// ArrayRefMut + +/// Mutable array view +pub struct ArrayRefMut< + 'a, + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, +> { + arr: &'a mut Array, +} + +impl< + 'a, + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, + > ArrayRefMut<'a, Item, ArrayImpl, NDIM> +{ + /// Create new mutable view + pub fn new(arr: &'a mut Array) -> Self { + Self { arr } + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, + > Shape for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + fn shape(&self) -> [usize; NDIM] { + self.arr.shape() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + Stride + + UnsafeRandomAccessMut, + const NDIM: usize, + > Stride for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + fn stride(&self) -> [usize; NDIM] { + self.arr.stride() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + RawAccess + + Stride + + UnsafeRandomAccessMut, + const NDIM: usize, + > RawAccess for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + + fn data(&self) -> &[Self::Item] { + self.arr.data() + } + + fn buff_ptr(&self) -> *const Self::Item { + self.arr.buff_ptr() + } + + fn offset(&self) -> usize { + self.arr.offset() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + RawAccess + + Stride + + UnsafeRandomAccessMut + + RawAccessMut, + const NDIM: usize, + > RawAccessMut for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + fn data_mut(&mut self) -> &mut [Self::Item] { + self.arr.data_mut() + } + + fn buff_ptr_mut(&mut self) -> *mut Self::Item { + self.arr.buff_ptr_mut() + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, + > UnsafeRandomAccessByValue for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + #[inline] + unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item { + self.arr.get_value_unchecked(multi_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessByRef + + UnsafeRandomAccessMut, + const NDIM: usize, + > UnsafeRandomAccessByRef for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + #[inline] + unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item { + self.arr.get_unchecked(multi_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + ChunkedAccess + + UnsafeRandomAccessMut, + const NDIM: usize, + const N: usize, + > ChunkedAccess for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + fn get_chunk( + &self, + chunk_index: usize, + ) -> Option> { + self.arr.get_chunk(chunk_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, + > UnsafeRandomAccessMut for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + type Item = Item; + #[inline] + unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item { + self.arr.get_unchecked_mut(multi_index) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessByRef + + UnsafeRandomAccessMut + + ResizeInPlace, + const NDIM: usize, + > ResizeInPlace for ArrayRefMut<'_, Item, ArrayImpl, NDIM> +{ + #[inline] + fn resize_in_place(&mut self, shape: [usize; NDIM]) { + self.arr.resize_in_place(shape) + } +} + +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + Shape, + const NDIM: usize, + > Array +{ + /// Return a reference to an array. + pub fn r(&self) -> Array, NDIM> { + Array::new(ArrayRef::new(self)) + } +} +impl< + Item: RlstBase, + ArrayImpl: UnsafeRandomAccessByValue + + Shape + + UnsafeRandomAccessMut, + const NDIM: usize, + > Array +{ + /// Return a mutable view onto the array. + pub fn r_mut(&mut self) -> Array, NDIM> { + Array::new(ArrayRefMut::new(self)) + } +} diff --git a/src/dense/array/views.rs b/src/dense/array/views.rs index 1954e80..bcc2c68 100644 --- a/src/dense/array/views.rs +++ b/src/dense/array/views.rs @@ -41,6 +41,7 @@ impl< > Array { /// Return a view onto the array. + #[deprecated(note = "Please use arr.r() instead.")] pub fn view(&self) -> Array, NDIM> { Array::new(ArrayView::new(self)) } @@ -60,6 +61,7 @@ impl< > Array { /// Return a mutable view onto the array. + #[deprecated(note = "Please use arr.r_mut() instead.")] pub fn view_mut(&mut self) -> Array, NDIM> { Array::new(ArrayViewMut::new(self)) }