From dd6124acc2c42c8052324e364844694d841988e3 Mon Sep 17 00:00:00 2001 From: andrei-papou Date: Fri, 13 Nov 2020 11:53:09 +0300 Subject: [PATCH 1/2] Draft of potential masked array implementation. --- src/lib.rs | 1 + src/ma/mod.rs | 219 ++++++++++++++++++++++++++++++++++++++++++++++++++ tests/ma.rs | 44 ++++++++++ 3 files changed, 264 insertions(+) create mode 100644 src/ma/mod.rs create mode 100644 tests/ma.rs diff --git a/src/lib.rs b/src/lib.rs index 1b7590da1..6d106dd07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1608,6 +1608,7 @@ mod impl_raw_views; // Copy-on-write array methods mod impl_cow; +pub mod ma; /// A contiguous array shape of n dimensions. /// diff --git a/src/ma/mod.rs b/src/ma/mod.rs new file mode 100644 index 000000000..31e2ae143 --- /dev/null +++ b/src/ma/mod.rs @@ -0,0 +1,219 @@ +use std::cmp::{PartialEq}; +use std::ops::{Add, Index}; +use crate::{ArrayBase, Array1, Iter, RawData, Data, DataOwned, Dimension, NdIndex, Array, DataMut}; + +/// Enum that represents a value that can potentially be masked. +/// We could potentially use `Option` for that, but that produces +/// weird `Option>` return types in iterators. +/// This type can be converted to `Option` using `into` method. +/// There is also a `PartialEq` implementation just to be able to +/// use it in `assert_eq!` statements. +#[derive(Clone, Copy, Debug, Eq)] +pub enum Masked { + Value(T), + Empty, +} + +impl Masked<&T> { + fn cloned(&self) -> Masked + where + T: Clone + { + match self { + Masked::Value(v) => Masked::Value((*v).clone()), + Masked::Empty => Masked::Empty, + } + } +} + +impl PartialEq for Masked +where + T: PartialEq +{ + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Masked::Value(v1), Masked::Value(v2)) => v1.eq(v2), + (Masked::Empty, Masked::Empty) => true, + _ => false, + } + } +} + +impl From> for Option { + fn from(other: Masked) -> Option { + match other { + Masked::Value(v) => Some(v), + Masked::Empty => None, + } + } +} + +/// Every struct that can be used as a mask should implement this trait. +/// It has two generic parameters: +/// A - type of the values to be masked +/// D - dimension of the mask +/// The trait is implemented in such a way so that it could be implemented +/// by different types, not just variations of `ArrayBase`. For example, +/// we can implement a mask as a whitelist/blacklist of indices or as a +/// struct which treats some value or range of values as a mask. +pub trait Mask { + /// Return the dimension of the mask, used only by iterators so far. + fn get_dim(&self) -> &D; + + /// Given an index of the element and a reference to it, return masked + /// version of the reference. Accepting a pair allows masking by index, + /// value or both. + fn mask_ref<'a, I: NdIndex>(&self, pair: (I, &'a A)) -> Masked<&'a A>; + + // Probably we will need two more methods to be able to mask by value and + // by mutable reference: + + // fn mask>(&self, pair: (I, A)) -> Masked; + // fn mask_ref_mut<'a, I: NdIndex>(&self, pair: (I, &'a mut A)) -> Masked<&'a mut A>; + + fn mask_iter<'a, 'b: 'a, I>(&'b self, iter: I) -> MaskedIter<'a, A, Self, I, D> + where + I: Iterator, + D: Dimension, + { + MaskedIter::new(self, iter, self.get_dim().first_index()) + } +} + +/// Given two masks, generate their intersection. This may be required for any +/// binary operations with two masks. +pub trait JoinMask : Mask +where + M: Mask +{ + type Output: Mask; + + fn join(&self, other: &M) -> Self::Output; +} + +pub struct MaskedIter<'a, A: 'a, M, I, D> +where + I: Iterator, + D: Dimension, + M: ?Sized + Mask +{ + mask: &'a M, + iter: I, + idx: Option, +} + +impl<'a, A, M, I, D> MaskedIter<'a, A, M, I, D> +where + I: Iterator, + D: Dimension, + M: ?Sized + Mask +{ + fn new(mask: &'a M, iter: I, start_idx: Option) -> MaskedIter<'a, A, M, I, D> { + MaskedIter { mask, iter, idx: start_idx } + } +} + +impl<'a, A, M, I, D> Iterator for MaskedIter<'a, A, M, I, D> +where + I: Iterator, + D: Dimension, + M: Mask +{ + type Item = Masked; + + fn next(&mut self) -> Option { + let nex_val = self.iter.next()?; + let elem = Some(self.mask.mask_ref((self.idx.clone()?, nex_val))); + self.idx = self.mask.get_dim().next_for(self.idx.clone()?); + elem + } +} + +/// First implementation of the mask as a bool array of the same shape. +impl Mask for ArrayBase +where + D: Dimension, + S: Data, +{ + fn get_dim(&self) -> &D { + &self.dim + } + + fn mask_ref<'a, I: NdIndex>(&self, pair: (I, &'a A)) -> Masked<&'a A> { + if *self.index(pair.0) { Masked::Value(pair.1) } else { Masked::Empty } + } +} + +impl JoinMask> for ArrayBase +where + D: Dimension, + S1: Data, + S2: Data, +{ + type Output = Array; + + fn join(&self, other: &ArrayBase) -> Self::Output { + self & other + } +} + +/// Base type for masked array. `S` and `D` types are exactly the ones +/// of `ArrayBase`, `M` is a mask type. +pub struct MaskedArrayBase +where + S: RawData, + M: Mask, +{ + data: ArrayBase, + mask: M, +} + +impl MaskedArrayBase +where + S: RawData, + D: Dimension, + M: Mask, +{ + pub fn compressed(&self) -> Array1 + where + S::Elem: Clone, + S: Data, + { + self.iter() + .filter_map(|mv: Masked<&S::Elem>| mv.cloned().into()) + .collect() + } + + pub fn iter(&self) -> MaskedIter<'_, S::Elem, M, Iter<'_, S::Elem, D>, D> + where + S: Data + { + self.mask.mask_iter(self.data.iter()) + } +} + +impl Add> for MaskedArrayBase +where + A: Clone + Add, + S1: DataOwned + DataMut, + S2: Data, + D: Dimension, + M: Mask + JoinMask, +{ + type Output = MaskedArrayBase>::Output>; + + fn add(self, rhs: MaskedArrayBase) -> Self::Output { + MaskedArrayBase { + data: self.data + rhs.data, + mask: self.mask.join(&rhs.mask), + } + } +} + +pub fn array(data: ArrayBase, mask: M) -> MaskedArrayBase +where + S: RawData, + M: Mask, +{ + MaskedArrayBase { data, mask } +} diff --git a/tests/ma.rs b/tests/ma.rs new file mode 100644 index 000000000..273ad4f62 --- /dev/null +++ b/tests/ma.rs @@ -0,0 +1,44 @@ +use ndarray::{array}; +use ndarray::ma; + +#[cfg(test)] +mod test_array_mask { + use super::*; + + #[test] + fn test_iter() { + let data = array![1, 2, 3, 4]; + let mask = array![true, false, true, false]; + let arr = ma::array(data, mask); + let actual_vec: Vec<_> = arr.iter().collect(); + let expected_vec = vec![ + ma::Masked::Value(&1), + ma::Masked::Empty, + ma::Masked::Value(&3), + ma::Masked::Empty, + ]; + assert_eq!(actual_vec, expected_vec); + } + + #[test] + fn test_compressed() { + let arr = ma::array(array![1, 2, 3, 4], array![true, true, false, false]); + let res = arr.compressed(); + assert_eq!(res, array![1, 2]); + } + + #[test] + fn test_add() { + let arr1 = ma::array(array![1, 2, 3, 4], array![true, false, true, false]); + let arr2 = ma::array(array![4, 3, 2, 1], array![true, false, false, false]); + let res = arr1 + arr2; + let actual_vec: Vec<_> = res.iter().collect(); + let expected_vec = vec![ + ma::Masked::Value(&5), + ma::Masked::Empty, + ma::Masked::Empty, + ma::Masked::Empty, + ]; + assert_eq!(actual_vec, expected_vec); + } +} From 2eeb7b5bd954e7a20016490253ab5952663e6553 Mon Sep 17 00:00:00 2001 From: andrei-papou Date: Fri, 13 Nov 2020 17:11:25 +0300 Subject: [PATCH 2/2] Simplified Mask trait by getting rid of `get_dim` --- src/ma/mod.rs | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/ma/mod.rs b/src/ma/mod.rs index 31e2ae143..677625090 100644 --- a/src/ma/mod.rs +++ b/src/ma/mod.rs @@ -1,6 +1,8 @@ -use std::cmp::{PartialEq}; +use std::cmp::PartialEq; +use std::marker::PhantomData; use std::ops::{Add, Index}; -use crate::{ArrayBase, Array1, Iter, RawData, Data, DataOwned, Dimension, NdIndex, Array, DataMut}; +use crate::{ArrayBase, Array1, RawData, Data, DataOwned, Dimension, NdIndex, Array, DataMut}; +use crate::iter::IndexedIter; /// Enum that represents a value that can potentially be masked. /// We could potentially use `Option` for that, but that produces @@ -57,9 +59,6 @@ impl From> for Option { /// we can implement a mask as a whitelist/blacklist of indices or as a /// struct which treats some value or range of values as a mask. pub trait Mask { - /// Return the dimension of the mask, used only by iterators so far. - fn get_dim(&self) -> &D; - /// Given an index of the element and a reference to it, return masked /// version of the reference. Accepting a pair allows masking by index, /// value or both. @@ -73,10 +72,11 @@ pub trait Mask { fn mask_iter<'a, 'b: 'a, I>(&'b self, iter: I) -> MaskedIter<'a, A, Self, I, D> where - I: Iterator, + I: Iterator, D: Dimension, + D::Pattern: NdIndex, { - MaskedIter::new(self, iter, self.get_dim().first_index()) + MaskedIter::new(self, iter) } } @@ -93,39 +93,40 @@ where pub struct MaskedIter<'a, A: 'a, M, I, D> where - I: Iterator, + I: Iterator, D: Dimension, + D::Pattern: NdIndex, M: ?Sized + Mask { mask: &'a M, iter: I, - idx: Option, + _dim: PhantomData, } impl<'a, A, M, I, D> MaskedIter<'a, A, M, I, D> where - I: Iterator, + I: Iterator, D: Dimension, + D::Pattern: NdIndex, M: ?Sized + Mask { - fn new(mask: &'a M, iter: I, start_idx: Option) -> MaskedIter<'a, A, M, I, D> { - MaskedIter { mask, iter, idx: start_idx } + fn new(mask: &'a M, iter: I) -> MaskedIter<'a, A, M, I, D> { + MaskedIter { mask, iter, _dim: PhantomData } } } impl<'a, A, M, I, D> Iterator for MaskedIter<'a, A, M, I, D> where - I: Iterator, + I: Iterator, D: Dimension, + D::Pattern: NdIndex, M: Mask { - type Item = Masked; + type Item = Masked<&'a A>; fn next(&mut self) -> Option { let nex_val = self.iter.next()?; - let elem = Some(self.mask.mask_ref((self.idx.clone()?, nex_val))); - self.idx = self.mask.get_dim().next_for(self.idx.clone()?); - elem + Some(self.mask.mask_ref(nex_val)) } } @@ -135,10 +136,6 @@ where D: Dimension, S: Data, { - fn get_dim(&self) -> &D { - &self.dim - } - fn mask_ref<'a, I: NdIndex>(&self, pair: (I, &'a A)) -> Masked<&'a A> { if *self.index(pair.0) { Masked::Value(pair.1) } else { Masked::Empty } } @@ -178,17 +175,19 @@ where where S::Elem: Clone, S: Data, + D::Pattern: NdIndex, { self.iter() .filter_map(|mv: Masked<&S::Elem>| mv.cloned().into()) .collect() } - pub fn iter(&self) -> MaskedIter<'_, S::Elem, M, Iter<'_, S::Elem, D>, D> + pub fn iter(&self) -> MaskedIter<'_, S::Elem, M, IndexedIter<'_, S::Elem, D>, D> where - S: Data + S: Data, + D::Pattern: NdIndex, { - self.mask.mask_iter(self.data.iter()) + self.mask.mask_iter(self.data.indexed_iter()) } }