diff --git a/src/matrix/base/mod.rs b/src/matrix/base/mod.rs index 721fab8..6c19fb8 100644 --- a/src/matrix/base/mod.rs +++ b/src/matrix/base/mod.rs @@ -20,7 +20,7 @@ //! ``` use matrix::{Matrix, MatrixSlice, MatrixSliceMut}; -use matrix::{Row, RowMut, Column, ColumnMut, Rows, RowsMut, Axes}; +use matrix::{Cols, ColsMut, Row, RowMut, Column, ColumnMut, Rows, RowsMut, Axes}; use matrix::{DiagOffset, Diagonal, DiagonalMut}; use matrix::{back_substitution, forward_substitution}; use matrix::{SliceIter, SliceIterMut}; @@ -211,19 +211,54 @@ pub trait BaseMatrix: Sized { } } + /// Iterate over the columns of the matrix. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrix}; + /// + /// let a = matrix![0, 1; + /// 2, 3; + /// 4, 5]; + /// + /// let mut iter = a.col_iter(); + /// + /// assert_matrix_eq!(*iter.next().unwrap(), matrix![ 0; 2; 4 ]); + /// assert_matrix_eq!(*iter.next().unwrap(), matrix![ 1; 3; 5 ]); + /// assert!(iter.next().is_none()); + /// # } + /// ``` + fn col_iter(&self) -> Cols { + Cols { + _marker: PhantomData::<&T>, + col_pos: 0, + row_stride: self.row_stride() as isize, + slice_cols: self.cols(), + slice_rows: self.rows(), + slice_start: self.as_ptr(), + } + } + /// Iterate over the rows of the matrix. /// /// # Examples /// /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrix}; + /// let a = matrix![0, 1, 2; + /// 3, 4, 5; + /// 6, 7, 8]; /// - /// let a = Matrix::new(3, 2, (0..6).collect::>()); + /// let mut iter = a.row_iter(); /// - /// // Prints "2" three times. - /// for row in a.row_iter() { - /// println!("{}", row.cols()); - /// } + /// assert_matrix_eq!(*iter.next().unwrap(), matrix![ 0, 1, 2 ]); + /// assert_matrix_eq!(*iter.next().unwrap(), matrix![ 3, 4, 5 ]); + /// assert_matrix_eq!(*iter.next().unwrap(), matrix![ 6, 7, 8 ]); + /// assert!(iter.next().is_none()); + /// # } /// ``` fn row_iter(&self) -> Rows { Rows { @@ -393,7 +428,8 @@ pub trait BaseMatrix: Sized { /// # } /// ``` fn metric<'a, 'b, B, M>(&'a self, mat: &'b B, metric: M) -> T - where B: 'b + BaseMatrix, M: MatrixMetric<'a, 'b, T, Self, B> + where B: 'b + BaseMatrix, + M: MatrixMetric<'a, 'b, T, Self, B> { metric.metric(self, mat) } @@ -439,24 +475,26 @@ pub trait BaseMatrix: Sized { /// assert_eq!(rmin, vector![1.0, 2.0]); /// # } /// ``` - fn min(&self, axis: Axes) -> Vector where T: Copy + PartialOrd + fn min(&self, axis: Axes) -> Vector + where T: Copy + PartialOrd { match axis { Axes::Col => { let mut mins: Vec = Vec::with_capacity(self.rows()); for row in self.row_iter() { let min = row.iter() - .skip(1) - .fold(row[0], |m, &v| if v < m { v } else { m } ); + .skip(1) + .fold(row[0], |m, &v| if v < m { v } else { m }); mins.push(min); } Vector::new(mins) - }, + } Axes::Row => { let mut mins: Vec = self.row(0).raw_slice().into(); for row in self.row_iter().skip(1) { - utils::in_place_vec_bin_op(&mut mins, row.raw_slice(), - |min, &r| if r < *min { *min = r; }); + utils::in_place_vec_bin_op(&mut mins, row.raw_slice(), |min, &r| if r < *min { + *min = r; + }); } Vector::new(mins) } @@ -481,24 +519,26 @@ pub trait BaseMatrix: Sized { /// assert_eq!(rmax, vector![3.0, 4.0]); /// # } /// ``` - fn max(&self, axis: Axes) -> Vector where T: Copy + PartialOrd + fn max(&self, axis: Axes) -> Vector + where T: Copy + PartialOrd { match axis { Axes::Col => { let mut maxs: Vec = Vec::with_capacity(self.rows()); for row in self.row_iter() { let max = row.iter() - .skip(1) - .fold(row[0], |m, &v| if v > m { v } else { m } ); + .skip(1) + .fold(row[0], |m, &v| if v > m { v } else { m }); maxs.push(max); } Vector::new(maxs) - }, + } Axes::Row => { let mut maxs: Vec = self.row(0).raw_slice().into(); for row in self.row_iter().skip(1) { - utils::in_place_vec_bin_op(&mut maxs, row.raw_slice(), - |max, &r| if r > *max { *max = r; }); + utils::in_place_vec_bin_op(&mut maxs, row.raw_slice(), |max, &r| if r > *max { + *max = r; + }); } Vector::new(maxs) } @@ -1019,9 +1059,12 @@ pub trait BaseMatrix: Sized { mid, self.cols(), self.row_stride()); - slice_2 = MatrixSlice::from_raw_parts( - self.as_ptr().offset((mid * self.row_stride()) as isize), - self.rows() - mid, self.cols(), self.row_stride()); + slice_2 = MatrixSlice::from_raw_parts(self.as_ptr() + .offset((mid * self.row_stride()) as + isize), + self.rows() - mid, + self.cols(), + self.row_stride()); } } Axes::Col => { @@ -1062,8 +1105,12 @@ pub trait BaseMatrix: Sized { "View dimensions exceed matrix dimensions."); unsafe { - MatrixSlice::from_raw_parts(self.as_ptr().offset((start[0] * self.row_stride() + start[1]) as isize), - rows, cols, self.row_stride()) + MatrixSlice::from_raw_parts(self.as_ptr() + .offset((start[0] * self.row_stride() + start[1]) as + isize), + rows, + cols, + self.row_stride()) } } } @@ -1292,14 +1339,14 @@ pub trait BaseMatrixMut: BaseMatrix { if a != b { unsafe { - let row_a = - slice::from_raw_parts_mut(self.as_mut_ptr() - .offset((self.row_stride() * a) as isize), - self.cols()); - let row_b = - slice::from_raw_parts_mut(self.as_mut_ptr() - .offset((self.row_stride() * b) as isize), - self.cols()); + let row_a = slice::from_raw_parts_mut(self.as_mut_ptr() + .offset((self.row_stride() * a) as + isize), + self.cols()); + let row_b = slice::from_raw_parts_mut(self.as_mut_ptr() + .offset((self.row_stride() * b) as + isize), + self.cols()); for (x, y) in row_a.into_iter().zip(row_b.into_iter()) { mem::swap(x, y); @@ -1354,14 +1401,48 @@ pub trait BaseMatrixMut: BaseMatrix { } + /// Iterate over the mutable columns of the matrix. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { + /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; + /// + /// let mut a = matrix![0, 1; + /// 2, 3; + /// 4, 5]; + /// + /// for mut col in a.col_iter_mut() { + /// *col += 1; + /// } + /// + /// // Now contains the range 1..7 + /// println!("{}", a); + /// # } + /// ``` + fn col_iter_mut(&mut self) -> ColsMut { + ColsMut { + _marker: PhantomData::<&mut T>, + col_pos: 0, + row_stride: self.row_stride() as isize, + slice_cols: self.cols(), + slice_rows: self.rows(), + slice_start: self.as_mut_ptr(), + } + } + /// Iterate over the mutable rows of the matrix. /// /// # Examples /// /// ``` + /// # #[macro_use] extern crate rulinalg; fn main() { /// use rulinalg::matrix::{Matrix, BaseMatrixMut}; /// - /// let mut a = Matrix::new(3, 2, (0..6).collect::>()); + /// let mut a = matrix![0, 1; + /// 2, 3; + /// 4, 5]; /// /// for mut row in a.row_iter_mut() { /// *row += 1; @@ -1369,6 +1450,7 @@ pub trait BaseMatrixMut: BaseMatrix { /// /// // Now contains the range 1..7 /// println!("{}", a); + /// # } /// ``` fn row_iter_mut(&mut self) -> RowsMut { RowsMut { @@ -1526,9 +1608,13 @@ pub trait BaseMatrixMut: BaseMatrix { mid, self.cols(), self.row_stride()); - slice_2 = MatrixSliceMut::from_raw_parts( - self.as_mut_ptr().offset((mid * self.row_stride()) as isize), - self.rows() - mid, self.cols(), self.row_stride()); + slice_2 = MatrixSliceMut::from_raw_parts(self.as_mut_ptr() + .offset((mid * + self.row_stride()) as + isize), + self.rows() - mid, + self.cols(), + self.row_stride()); } } Axes::Col => { @@ -1574,8 +1660,12 @@ pub trait BaseMatrixMut: BaseMatrix { "View dimensions exceed matrix dimensions."); unsafe { - MatrixSliceMut::from_raw_parts(self.as_mut_ptr().offset((start[0] * self.row_stride() + start[1]) as isize), - rows, cols, self.row_stride()) + MatrixSliceMut::from_raw_parts(self.as_mut_ptr() + .offset((start[0] * self.row_stride() + start[1]) as + isize), + rows, + cols, + self.row_stride()) } } } diff --git a/src/matrix/iter.rs b/src/matrix/iter.rs index 9a02272..9e569ba 100644 --- a/src/matrix/iter.rs +++ b/src/matrix/iter.rs @@ -1,9 +1,8 @@ use std::iter::{ExactSizeIterator, FromIterator}; use std::mem; -// use std::slice; use super::{Matrix, MatrixSlice, MatrixSliceMut}; -use super::{Row, RowMut, Rows, RowsMut, Diagonal, DiagonalMut}; +use super::{Column, ColumnMut, Cols, ColsMut, Row, RowMut, Rows, RowsMut, Diagonal, DiagonalMut}; use super::{BaseMatrix, BaseMatrixMut, SliceIter, SliceIterMut}; macro_rules! impl_slice_iter ( @@ -107,6 +106,75 @@ impl<'a, T, M: $diag_base> ExactSizeIterator for $diag<'a, T, M> {} impl_diag_iter!(Diagonal, BaseMatrix, &'a T, as_ptr); impl_diag_iter!(DiagonalMut, BaseMatrixMut, &'a mut T, as_mut_ptr); +macro_rules! impl_col_iter ( + ($cols:ident, $col_type:ty, $col_base:ident, $slice_base:ident) => ( + +/// Iterates over the columns in the matrix. +impl<'a, T> Iterator for $cols<'a, T> { + type Item = $col_type; + + fn next(&mut self) -> Option { + if self.col_pos >= self.slice_cols { + return None; + } + + let column: $col_type; + unsafe { + let ptr = self.slice_start.offset(self.col_pos as isize); + column = $col_base { + col: $slice_base::from_raw_parts(ptr, self.slice_rows, 1, self.row_stride as usize) + }; + } + self.col_pos += 1; + Some(column) + } + + fn last(self) -> Option { + if self.col_pos >= self.slice_cols { + return None; + } + + unsafe { + let ptr = self.slice_start.offset((self.slice_cols - 1) as isize); + Some($col_base { + col: $slice_base::from_raw_parts(ptr, self.slice_rows, 1, self.row_stride as usize) + }) + } + } + + fn nth(&mut self, n: usize) -> Option { + if self.col_pos + n >= self.slice_cols { + return None; + } + + let column: $col_type; + unsafe { + let ptr = self.slice_start.offset((self.col_pos + n) as isize); + column = $col_base { + col: $slice_base::from_raw_parts(ptr, self.slice_rows, 1, self.row_stride as usize) + } + } + self.col_pos += n + 1; + Some(column) + } + + fn count(self) -> usize { + self.slice_cols - self.col_pos + } + + fn size_hint(&self) -> (usize, Option) { + (self.slice_cols - self.col_pos, Some(self.slice_cols - self.col_pos)) + } +} + ); +); + +impl_col_iter!(Cols, Column<'a, T>, Column, MatrixSlice); +impl_col_iter!(ColsMut, ColumnMut<'a, T>, ColumnMut, MatrixSliceMut); + +impl<'a, T> ExactSizeIterator for Cols<'a, T> {} +impl<'a, T> ExactSizeIterator for ColsMut<'a, T> {} + macro_rules! impl_row_iter ( ($rows:ident, $row_type:ty, $row_base:ident, $slice_base:ident) => ( @@ -580,12 +648,182 @@ mod tests { assert_eq!((0, Some(0)), diags_iter.size_hint()); } + #[test] + fn test_matrix_cols() { + let mut a = matrix![0, 1, 2, 3; + 4, 5, 6, 7; + 8, 9, 10, 11]; + let data = [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]]; + + for (i, col) in a.col_iter().enumerate() { + for (j, value) in col.iter().enumerate() { + assert_eq!(data[i][j], *value); + } + } + + for (i, mut col) in a.col_iter_mut().enumerate() { + for (j, value) in col.iter_mut().enumerate() { + assert_eq!(data[i][j], *value); + } + } + + for mut col in a.col_iter_mut() { + for r in col.iter_mut() { + *r = 0; + } + } + + assert_eq!(a.into_vec(), vec![0; 12]); + } + + #[test] + fn test_matrix_slice_cols() { + let a = matrix![0, 1, 2, 3; + 4, 5, 6, 7; + 8, 9, 10, 11]; + + let b = MatrixSlice::from_matrix(&a, [0, 0], 3, 2); + + let data = [[0, 4, 8], [1, 5, 9]]; + + for (i, col) in b.col_iter().enumerate() { + for (j, value) in col.iter().enumerate() { + assert_eq!(data[i][j], *value); + } + } + } + + #[test] + fn test_matrix_slice_mut_cols() { + let mut a = matrix![0, 1, 2, 3; + 4, 5, 6, 7; + 8, 9, 10, 11]; + + { + let mut b = MatrixSliceMut::from_matrix(&mut a, [0, 0], 3, 2); + + let data = [[0, 4, 8], [1, 5, 9]]; + + for (i, col) in b.col_iter().enumerate() { + for (j, value) in col.iter().enumerate() { + assert_eq!(data[i][j], *value); + } + } + + for (i, mut col) in b.col_iter_mut().enumerate() { + for (j, value) in col.iter_mut().enumerate() { + assert_eq!(data[i][j], *value); + } + } + + for mut col in b.col_iter_mut() { + for r in col.iter_mut() { + *r = 0; + } + } + } + + assert_eq!(a.into_vec(), vec![0, 0, 2, 3, 0, 0, 6, 7, 0, 0, 10, 11]); + } + + #[test] + fn test_matrix_cols_nth() { + let a = matrix![0, 1, 2, 3; + 4, 5, 6, 7; + 8, 9, 10, 11]; + + let mut col_iter = a.col_iter(); + + let mut nth0 = col_iter.nth(0).unwrap().into_iter(); + + assert_eq!(0, *nth0.next().unwrap()); + assert_eq!(4, *nth0.next().unwrap()); + assert_eq!(8, *nth0.next().unwrap()); + + let mut nth1 = col_iter.nth(2).unwrap().into_iter(); + + assert_eq!(3, *nth1.next().unwrap()); + assert_eq!(7, *nth1.next().unwrap()); + assert_eq!(11, *nth1.next().unwrap()); + + assert!(col_iter.next().is_none()); + } + + #[test] + fn test_matrix_cols_last() { + let a = matrix![0, 1, 2, 3; + 4, 5, 6, 7; + 8, 9, 10, 11]; + + let mut col_iter = a.col_iter().last().unwrap().into_iter(); + + assert_eq!(3, *col_iter.next().unwrap()); + assert_eq!(7, *col_iter.next().unwrap()); + assert_eq!(11, *col_iter.next().unwrap()); + + let mut col_iter = a.col_iter(); + + col_iter.next(); + + let mut last_col_iter = col_iter.last().unwrap().into_iter(); + + assert_eq!(3, *last_col_iter.next().unwrap()); + assert_eq!(7, *last_col_iter.next().unwrap()); + assert_eq!(11, *last_col_iter.next().unwrap()); + + let mut col_iter = a.col_iter(); + + col_iter.next(); + col_iter.next(); + col_iter.next(); + col_iter.next(); + + assert!(col_iter.last().is_none()); + } + + #[test] + fn test_matrix_cols_count() { + let a = matrix![0, 1, 2; + 3, 4, 5; + 6, 7, 8]; + + let col_iter = a.col_iter(); + + assert_eq!(3, col_iter.count()); + + let mut col_iter_2 = a.col_iter(); + col_iter_2.next(); + assert_eq!(2, col_iter_2.count()); + } + + #[test] + fn test_matrix_cols_size_hint() { + let a = matrix![0, 1, 2; + 3, 4, 5; + 6, 7, 8]; + + let mut col_iter = a.col_iter(); + + assert_eq!((3, Some(3)), col_iter.size_hint()); + + col_iter.next(); + + assert_eq!((2, Some(2)), col_iter.size_hint()); + col_iter.next(); + col_iter.next(); + + assert_eq!((0, Some(0)), col_iter.size_hint()); + + assert!(col_iter.next().is_none()); + assert_eq!((0, Some(0)), col_iter.size_hint()); + } #[test] fn test_matrix_rows() { let mut a = matrix![0, 1, 2; 3, 4, 5; 6, 7, 8]; + let data = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]; for (i, row) in a.row_iter().enumerate() { diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index fa08e05..e446584 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -15,14 +15,14 @@ use error::{Error, ErrorKind}; use utils; use vector::Vector; +mod base; mod decomposition; -mod impl_ops; +mod deref; mod impl_mat; -mod mat_mul; +mod impl_ops; mod iter; -mod deref; +mod mat_mul; mod slice; -mod base; mod permutation_matrix; mod impl_permutation_mul; @@ -133,29 +133,6 @@ pub struct RowMut<'a, T: 'a> { row: MatrixSliceMut<'a, T>, } - -// MAYBE WE SHOULD MOVE SOME OF THIS STUFF OUT -// - -impl<'a, T: 'a> Row<'a, T> { - /// Returns the row as a slice. - pub fn raw_slice(&self) -> &'a [T] { - unsafe { std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) } - } -} - -impl<'a, T: 'a> RowMut<'a, T> { - /// Returns the row as a slice. - pub fn raw_slice(&self) -> &'a [T] { - unsafe { std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) } - } - - /// Returns the row as a slice. - pub fn raw_slice_mut(&mut self) -> &'a mut [T] { - unsafe { std::slice::from_raw_parts_mut(self.row.as_mut_ptr(), self.row.cols()) } - } -} - /// Row iterator. #[derive(Debug)] pub struct Rows<'a, T: 'a> { @@ -178,6 +155,27 @@ pub struct RowsMut<'a, T: 'a> { _marker: PhantomData<&'a mut T>, } +// MAYBE WE SHOULD MOVE SOME OF THIS STUFF OUT + +impl<'a, T: 'a> Row<'a, T> { + /// Returns the row as a slice. + pub fn raw_slice(&self) -> &'a [T] { + unsafe { std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) } + } +} + +impl<'a, T: 'a> RowMut<'a, T> { + /// Returns the row as a slice. + pub fn raw_slice(&self) -> &'a [T] { + unsafe { std::slice::from_raw_parts(self.row.as_ptr(), self.row.cols()) } + } + + /// Returns the row as a slice. + pub fn raw_slice_mut(&mut self) -> &'a mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.row.as_mut_ptr(), self.row.cols()) } + } +} + /// Column of a matrix. /// /// This struct points to a `MatrixSlice` @@ -232,6 +230,28 @@ pub struct ColumnMut<'a, T: 'a> { col: MatrixSliceMut<'a, T>, } +/// Column iterator. +#[derive(Debug)] +pub struct Cols<'a, T: 'a> { + _marker: PhantomData<&'a T>, + col_pos: usize, + row_stride: isize, + slice_cols: usize, + slice_rows: usize, + slice_start: *const T, +} + +/// Mutable column iterator. +#[derive(Debug)] +pub struct ColsMut<'a, T: 'a> { + _marker: PhantomData<&'a mut T>, + col_pos: usize, + row_stride: isize, + slice_cols: usize, + slice_rows: usize, + slice_start: *mut T, +} + /// Diagonal offset (used by Diagonal iterator). #[derive(Debug, PartialEq)] pub enum DiagOffset {