Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow creating matrix iter with an owned view #1315

Merged
merged 1 commit into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/base/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorage<T, R, C>> IntoIterator
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
type Item = &'a T;
type IntoIter = MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIter::new_owned(self.data)
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
for &'a mut Matrix<T, R, C, S>
{
Expand All @@ -110,6 +122,18 @@ impl<'a, T: Scalar, R: Dim, C: Dim, S: RawStorageMut<T, R, C>> IntoIterator
}
}

impl<'a, T: Scalar, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoIterator
for Matrix<T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
type Item = &'a mut T;
type IntoIter = MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>;

#[inline]
fn into_iter(self) -> Self::IntoIter {
MatrixIterMut::new_owned_mut(self.data)
}
}

impl<T: Scalar, const D: usize> From<[T; D]> for SVector<T, D> {
#[inline]
fn from(arr: [T; D]) -> Self {
Expand Down
152 changes: 121 additions & 31 deletions src/base/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,29 @@ use std::mem;

use crate::base::dimension::{Dim, U1};
use crate::base::storage::{RawStorage, RawStorageMut};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar};
use crate::base::{Matrix, MatrixView, MatrixViewMut, Scalar, ViewStorage, ViewStorageMut};

#[derive(Clone, Debug)]
struct RawIter<Ptr, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> {
ptr: Ptr,
inner_ptr: Ptr,
inner_end: Ptr,
size: usize,
strides: (RStride, CStride),
_phantoms: PhantomData<(fn() -> T, R, C)>,
}

macro_rules! iterator {
(struct $Name:ident for $Storage:ident.$ptr: ident -> $Ptr:ty, $Ref:ty, $SRef: ty, $($derives:ident),* $(,)?) => {
/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
ptr: $Ptr,
inner_ptr: $Ptr,
inner_end: $Ptr,
size: usize, // We can't use an end pointer here because a stride might be zero.
strides: (S::RStride, S::CStride),
_phantoms: PhantomData<($Ref, R, C, S)>,
}

// TODO: we need to specialize for the case where the matrix storage is owned (in which
// case the iterator is trivial because it does not have any stride).
impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
RawIter<$Ptr, T, R, C, RStride, CStride>
{
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> $Name<'a, T, R, C, S> {
fn new<'a, S: $Storage<T, R, C, RStride = RStride, CStride = CStride>>(
storage: $SRef,
) -> Self {
let shape = storage.shape();
let strides = storage.strides();
let inner_offset = shape.0.value() * strides.0.value();
Expand All @@ -55,7 +58,7 @@ macro_rules! iterator {
unsafe { ptr.add(inner_offset) }
};

$Name {
RawIter {
ptr,
inner_ptr: ptr,
inner_end,
Expand All @@ -66,11 +69,13 @@ macro_rules! iterator {
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> Iterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
type Item = $Ptr;

#[inline]
fn next(&mut self) -> Option<$Ref> {
fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
Expand Down Expand Up @@ -102,10 +107,7 @@ macro_rules! iterator {
self.ptr = self.ptr.add(stride);
}

// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(old))
Some(old)
}
}
}
Expand All @@ -121,11 +123,11 @@ macro_rules! iterator {
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> DoubleEndedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn next_back(&mut self) -> Option<$Ref> {
fn next_back(&mut self) -> Option<Self::Item> {
unsafe {
if self.size == 0 {
None
Expand All @@ -152,24 +154,88 @@ macro_rules! iterator {
.ptr
.add((outer_remaining * outer_stride + inner_remaining * inner_stride));

// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
Some(mem::transmute(last))
Some(last)
}
}
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
for $Name<'a, T, R, C, S>
impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> ExactSizeIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
#[inline]
fn len(&self) -> usize {
self.size
}
}

impl<T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> FusedIterator
for RawIter<$Ptr, T, R, C, RStride, CStride>
{
}

/// An iterator through a dense matrix with arbitrary strides matrix.
#[derive($($derives),*)]
pub struct $Name<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> {
inner: RawIter<$Ptr, T, R, C, S::RStride, S::CStride>,
_marker: PhantomData<$Ref>,
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> $Name<'a, T, R, C, S> {
/// Creates a new iterator for the given matrix storage.
pub fn new(storage: $SRef) -> Self {
Self {
inner: RawIter::<$Ptr, T, R, C, S::RStride, S::CStride>::new(storage),
_marker: PhantomData,
}
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> Iterator for $Name<'a, T, R, C, S> {
type Item = $Ref;

#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner.next().map(|ptr| unsafe { mem::transmute(ptr) })
}

#[inline(always)]
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}

#[inline(always)]
fn count(self) -> usize {
self.inner.count()
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> DoubleEndedIterator
for $Name<'a, T, R, C, S>
{
#[inline(always)]
fn next_back(&mut self) -> Option<Self::Item> {
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
#[allow(clippy::transmute_ptr_to_ref)]
self.inner
.next_back()
.map(|ptr| unsafe { mem::transmute(ptr) })
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> ExactSizeIterator
for $Name<'a, T, R, C, S>
{
#[inline(always)]
fn len(&self) -> usize {
self.inner.len()
}
}

impl<'a, T, R: Dim, C: Dim, S: 'a + $Storage<T, R, C>> FusedIterator
for $Name<'a, T, R, C, S>
{
Expand All @@ -180,6 +246,30 @@ macro_rules! iterator {
iterator!(struct MatrixIter for RawStorage.ptr -> *const T, &'a T, &'a S, Clone, Debug);
iterator!(struct MatrixIterMut for RawStorageMut.ptr_mut -> *mut T, &'a mut T, &'a mut S, Debug);

impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIter<'a, T, R, C, ViewStorage<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned(storage: ViewStorage<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*const T, T, R, C, RStride, CStride>::new(&storage),
_marker: PhantomData,
}
}
}

impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim>
MatrixIterMut<'a, T, R, C, ViewStorageMut<'a, T, R, C, RStride, CStride>>
{
/// Creates a new iterator for the given matrix storage view.
pub fn new_owned_mut(mut storage: ViewStorageMut<'a, T, R, C, RStride, CStride>) -> Self {
Self {
inner: RawIter::<*mut T, T, R, C, RStride, CStride>::new(&mut storage),
_marker: PhantomData,
}
}
}

/*
*
* Row iterators.
Expand Down
Loading
Loading