diff --git a/src/lib.rs b/src/lib.rs
index 1b7590da1..6d106dd07 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1608,6 +1608,7 @@ mod impl_raw_views;
 
 // Copy-on-write array methods
 mod impl_cow;
+pub mod ma;
 
 /// A contiguous array shape of n dimensions.
 ///
diff --git a/src/ma/mod.rs b/src/ma/mod.rs
new file mode 100644
index 000000000..677625090
--- /dev/null
+++ b/src/ma/mod.rs
@@ -0,0 +1,249 @@
+use std::cmp::PartialEq;
+use std::marker::PhantomData;
+use std::ops::{Add, Index};
+use crate::{ArrayBase, Array1, RawData, Data, DataOwned, Dimension, NdIndex, Array, DataMut};
+use crate::iter::IndexedIter;
+
+/// Enum that represents a value that can potentially be masked.
+///
+/// We could potentially use `Option<T>` for that, but that produces
+/// weird `Option<Option<T>>` return types in iterators.
+/// This type can be converted to `Option<T>` using the `into` method.
+/// There is also a `PartialEq` implementation just to be able to
+/// use it in `assert_eq!` statements.
+#[derive(Clone, Copy, Debug, Eq)]
+pub enum Masked<T> {
+    /// The element is visible and carries its value.
+    Value(T),
+    /// The element is hidden by the mask.
+    Empty,
+}
+
+impl<T> Masked<&T> {
+    /// Turns a masked reference into a masked owned value by cloning the
+    /// referenced element; `Empty` stays `Empty`.
+    fn cloned(&self) -> Masked<T>
+    where
+        T: Clone
+    {
+        match self {
+            Masked::Value(v) => Masked::Value((*v).clone()),
+            Masked::Empty => Masked::Empty,
+        }
+    }
+}
+
+/// Two masked values are equal when both are `Empty`, or both are `Value`
+/// and the wrapped values compare equal.
+impl<T> PartialEq for Masked<T>
+where
+    T: PartialEq
+{
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Masked::Value(v1), Masked::Value(v2)) => v1.eq(v2),
+            (Masked::Empty, Masked::Empty) => true,
+            _ => false,
+        }
+    }
+}
+
+/// `Value(v)` becomes `Some(v)` and `Empty` becomes `None`.
+impl<T> From<Masked<T>> for Option<T> {
+    fn from(other: Masked<T>) -> Option<T> {
+        match other {
+            Masked::Value(v) => Some(v),
+            Masked::Empty => None,
+        }
+    }
+}
+
+/// Every struct that can be used as a mask should implement this trait.
+/// It has two generic parameters:
+/// - `A`: type of the values to be masked
+/// - `D`: dimension of the mask
+/// The trait is implemented in such a way so that it could be implemented
+/// by different types, not just variations of `ArrayBase`. For example,
+/// we can implement a mask as a whitelist/blacklist of indices or as a
+/// struct which treats some value or range of values as a mask.
+pub trait Mask<A, D> {
+    /// Given an index of the element and a reference to it, return masked
+    /// version of the reference. Accepting a pair allows masking by index,
+    /// value or both.
+    fn mask_ref<'a, I: NdIndex<D>>(&self, pair: (I, &'a A)) -> Masked<&'a A>;
+
+    // Probably we will need two more methods to be able to mask by value and
+    // by mutable reference:
+
+    // fn mask<I: NdIndex<D>>(&self, pair: (I, A)) -> Masked<A>;
+    // fn mask_ref_mut<'a, I: NdIndex<D>>(&self, pair: (I, &'a mut A)) -> Masked<&'a mut A>;
+
+    /// Wraps an iterator of `(index, &element)` pairs into a `MaskedIter`
+    /// that passes every pair through `mask_ref`.
+    fn mask_iter<'a, 'b: 'a, I>(&'b self, iter: I) -> MaskedIter<'a, A, Self, I, D>
+    where
+        I: Iterator<Item = (D::Pattern, &'a A)>,
+        D: Dimension,
+        D::Pattern: NdIndex<D>,
+    {
+        MaskedIter::new(self, iter)
+    }
+}
+
+/// Given two masks, generate their intersection. This may be required for any
+/// binary operations with two masks.
+pub trait JoinMask<A, D, M> : Mask<A, D>
+where
+    M: Mask<A, D>
+{
+    /// Mask type produced by joining `Self` with `M`.
+    type Output: Mask<A, D>;
+
+    /// Combines `self` with `other` into a single mask.
+    fn join(&self, other: &M) -> Self::Output;
+}
+
+/// Iterator adaptor that applies a mask to every `(index, &element)` pair
+/// produced by the wrapped iterator, yielding `Masked<&A>` items.
+pub struct MaskedIter<'a, A: 'a, M, I, D>
+where
+    I: Iterator<Item = (D::Pattern, &'a A)>,
+    D: Dimension,
+    D::Pattern: NdIndex<D>,
+    M: ?Sized + Mask<A, D>
+{
+    mask: &'a M,
+    iter: I,
+    // Zero-sized marker: no field stores a `D`, so `PhantomData` carries it.
+    _dim: PhantomData<D>,
+}
+
+impl<'a, A, M, I, D> MaskedIter<'a, A, M, I, D>
+where
+    I: Iterator<Item = (D::Pattern, &'a A)>,
+    D: Dimension,
+    D::Pattern: NdIndex<D>,
+    M: ?Sized + Mask<A, D>
+{
+    /// Pairs `mask` with the iterator whose items it will be applied to.
+    fn new(mask: &'a M, iter: I) -> MaskedIter<'a, A, M, I, D> {
+        MaskedIter { mask, iter, _dim: PhantomData }
+    }
+}
+
+impl<'a, A, M, I, D> Iterator for MaskedIter<'a, A, M, I, D>
+where
+    I: Iterator<Item = (D::Pattern, &'a A)>,
+    D: Dimension,
+    D::Pattern: NdIndex<D>,
+    M: Mask<A, D>
+{
+    type Item = Masked<&'a A>;
+
+    /// Pulls the next pair from the wrapped iterator and masks it; returns
+    /// `None` once the wrapped iterator is exhausted.
+    fn next(&mut self) -> Option<Self::Item> {
+        let next_pair = self.iter.next()?;
+        Some(self.mask.mask_ref(next_pair))
+    }
+}
+
+/// First implementation of the mask as a bool array of the same shape.
+/// `true` keeps the element visible, `false` masks it out.
+impl<A, S, D> Mask<A, D> for ArrayBase<S, D>
+where
+    D: Dimension,
+    S: Data<Elem = bool>,
+{
+    /// Looks up the flag stored at the element's index. Indexing goes
+    /// through `Index`, so a mask that does not cover the index panics.
+    fn mask_ref<'a, I: NdIndex<D>>(&self, pair: (I, &'a A)) -> Masked<&'a A> {
+        if *self.index(pair.0) { Masked::Value(pair.1) } else { Masked::Empty }
+    }
+}
+
+/// Joining two bool masks keeps an element visible only when both masks
+/// mark it as visible (element-wise `&`).
+impl<A, S1, S2, D> JoinMask<A, D, ArrayBase<S2, D>> for ArrayBase<S1, D>
+where
+    D: Dimension,
+    S1: Data<Elem = bool>,
+    S2: Data<Elem = bool>,
+{
+    type Output = Array<bool, D>;
+
+    fn join(&self, other: &ArrayBase<S2, D>) -> Self::Output {
+        self & other
+    }
+}
+
+/// Base type for masked array. `S` and `D` types are exactly the ones
+/// of `ArrayBase`, `M` is a mask type.
+pub struct MaskedArrayBase<S, D, M>
+where
+    S: RawData,
+    M: Mask<S::Elem, D>,
+{
+    data: ArrayBase<S, D>,
+    mask: M,
+}
+
+impl<S, D, M> MaskedArrayBase<S, D, M>
+where
+    S: RawData,
+    D: Dimension,
+    M: Mask<S::Elem, D>,
+{
+    /// Returns a new one-dimensional array holding clones of the elements
+    /// that are not masked out, in iteration order.
+    pub fn compressed(&self) -> Array1<S::Elem>
+    where
+        S::Elem: Clone,
+        S: Data,
+        D::Pattern: NdIndex<D>,
+    {
+        self.iter()
+            .filter_map(|mv: Masked<&S::Elem>| mv.cloned().into())
+            .collect()
+    }
+
+    /// Iterates over all elements, yielding `Masked::Value(&elem)` for
+    /// visible elements and `Masked::Empty` for masked ones.
+    pub fn iter(&self) -> MaskedIter<'_, S::Elem, M, IndexedIter<'_, S::Elem, D>, D>
+    where
+        S: Data,
+        D::Pattern: NdIndex<D>,
+    {
+        self.mask.mask_iter(self.data.indexed_iter())
+    }
+}
+
+/// Element-wise addition of two masked arrays: the data arrays are added
+/// and the masks are combined with `JoinMask::join`.
+impl<A, S1, S2, D, M> Add<MaskedArrayBase<S2, D, M>> for MaskedArrayBase<S1, D, M>
+where
+    A: Clone + Add<A, Output = A>,
+    S1: DataOwned<Elem = A> + DataMut,
+    S2: Data<Elem = A>,
+    D: Dimension,
+    M: Mask<A, D> + JoinMask<A, D, M>,
+{
+    type Output = MaskedArrayBase<S1, D, <M as JoinMask<A, D, M>>::Output>;
+
+    fn add(self, rhs: MaskedArrayBase<S2, D, M>) -> Self::Output {
+        MaskedArrayBase {
+            data: self.data + rhs.data,
+            mask: self.mask.join(&rhs.mask),
+        }
+    }
+}
+
+/// Builds a masked array from `data` and `mask`. No check is performed that
+/// the mask is compatible with the shape of `data`.
+pub fn array<S, D, M>(data: ArrayBase<S, D>, mask: M) -> MaskedArrayBase<S, D, M>
+where
+    S: RawData,
+    M: Mask<S::Elem, D>,
+{
+    MaskedArrayBase { data, mask }
+}
diff --git a/tests/ma.rs b/tests/ma.rs
new file mode 100644
index 000000000..273ad4f62
--- /dev/null
+++ b/tests/ma.rs
@@ -0,0 +1,47 @@
+use ndarray::{array};
+use ndarray::ma;
+
+#[cfg(test)]
+mod test_array_mask {
+    use super::*;
+
+    /// Iteration yields `Value` where the mask is `true`, `Empty` elsewhere.
+    #[test]
+    fn test_iter() {
+        let data = array![1, 2, 3, 4];
+        let mask = array![true, false, true, false];
+        let arr = ma::array(data, mask);
+        let actual_vec: Vec<_> = arr.iter().collect();
+        let expected_vec = vec![
+            ma::Masked::Value(&1),
+            ma::Masked::Empty,
+            ma::Masked::Value(&3),
+            ma::Masked::Empty,
+        ];
+        assert_eq!(actual_vec, expected_vec);
+    }
+
+    /// `compressed` keeps only the elements whose mask flag is `true`.
+    #[test]
+    fn test_compressed() {
+        let arr = ma::array(array![1, 2, 3, 4], array![true, true, false, false]);
+        let res = arr.compressed();
+        assert_eq!(res, array![1, 2]);
+    }
+
+    /// Adding two masked arrays adds the data and joins the masks.
+    #[test]
+    fn test_add() {
+        let arr1 = ma::array(array![1, 2, 3, 4], array![true, false, true, false]);
+        let arr2 = ma::array(array![4, 3, 2, 1], array![true, false, false, false]);
+        let res = arr1 + arr2;
+        let actual_vec: Vec<_> = res.iter().collect();
+        let expected_vec = vec![
+            ma::Masked::Value(&5),
+            ma::Masked::Empty,
+            ma::Masked::Empty,
+            ma::Masked::Empty,
+        ];
+        assert_eq!(actual_vec, expected_vec);
+    }
+}