diff --git a/src/matrix/base/mod.rs b/src/matrix/base/mod.rs index 6c19fb8..591ecae 100644 --- a/src/matrix/base/mod.rs +++ b/src/matrix/base/mod.rs @@ -75,11 +75,40 @@ pub trait BaseMatrix: Sized { } } - /// Get a reference to a point in the matrix without bounds checking. + /// Get a reference to an element in the matrix without bounds checking. unsafe fn get_unchecked(&self, index: [usize; 2]) -> &T { &*(self.as_ptr().offset((index[0] * self.row_stride() + index[1]) as isize)) } + /// Get a reference to an element in the matrix. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrix}; + /// + /// let mat = matrix![0, 1; + /// 3, 4; + /// 6, 7]; + /// + /// assert_eq!(mat.get([0, 2]), None); + /// assert_eq!(mat.get([3, 0]), None); + /// + /// assert_eq!( *mat.get([0, 0]).unwrap(), 0) + /// # } + /// ``` + fn get(&self, index: [usize; 2]) -> Option<&T> { + let row_ind = index[0]; + let col_ind = index[1]; + + if row_ind >= self.rows() || col_ind >= self.cols() { + None + } else { + unsafe { Some(self.get_unchecked(index)) } + } + } + /// Returns the column of a matrix at the given index. /// `None` if the index is out of bounds. /// @@ -1139,11 +1168,42 @@ pub trait BaseMatrixMut: BaseMatrix { } } - /// Get a mutable reference to a point in the matrix without bounds checks. + /// Get a mutable reference to an element in the matrix without bounds checks. unsafe fn get_unchecked_mut(&mut self, index: [usize; 2]) -> &mut T { &mut *(self.as_mut_ptr().offset((index[0] * self.row_stride() + index[1]) as isize)) } + /// Get a mutable reference to an element in the matrix. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrix, BaseMatrixMut}; + /// + /// let mut mat = matrix![0, 1; + /// 3, 4; + /// 6, 7]; + /// + /// assert_eq!(mat.get_mut([0, 2]), None); + /// assert_eq!(mat.get_mut([3, 0]), None); + /// + /// assert_eq!(*mat.get_mut([0, 0]).unwrap(), 0); + /// *mat.get_mut([0,0]).unwrap() = 2; + /// assert_eq!(*mat.get_mut([0, 0]).unwrap(), 2); + /// # } + /// ``` + fn get_mut(&mut self, index: [usize; 2]) -> Option<&mut T> { + let row_ind = index[0]; + let col_ind = index[1]; + + if row_ind >= self.rows() || col_ind >= self.cols() { + None + } else { + unsafe { Some(self.get_unchecked_mut(index)) } + } + } + /// Returns a mutable iterator over the matrix. /// /// # Examples