From 918bf0fbe5fe2e98ed1becf50835cbdbfe8bf88c Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 6 Oct 2023 14:36:58 -0400 Subject: [PATCH] GroupMap: add fold_with This is a generalization of `fold` which takes a function rather than a value, which removes the need for a `Clone` bound. `fold` is implemented in terms of `fold_with`. --- src/grouping_map.rs | 51 ++++++++++++++++++++++++++++++++++++++++----- tests/quick.rs | 29 ++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/src/grouping_map.rs b/src/grouping_map.rs index affcb2121..8799c6ecb 100644 --- a/src/grouping_map.rs +++ b/src/grouping_map.rs @@ -115,6 +115,50 @@ where destination_map } + /// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements + /// of each group sequentially, passing the previously accumulated value, a reference to the key + /// and the current element as arguments, and stores the results in a new map. + /// + /// `init` is called to obtain the initial value of each accumulator. + /// + /// `operation` is a function that is invoked on each element with the following parameters: + /// - the current value of the accumulator of the group; + /// - a reference to the key of the group this element belongs to; + /// - the element from the source being accumulated. + /// + /// Return a `HashMap` associating the key of each group with the result of folding that group's elements. + /// + /// ``` + /// use itertools::Itertools; + /// + /// #[derive(Debug, Default)] + /// struct Accumulator { + /// acc: usize, + /// } + /// + /// let lookup = (1..=7) + /// .into_grouping_map_by(|&n| n % 3) + /// .fold_with(|_key| Default::default(), |Accumulator { acc }, _key, val| { + /// let acc = acc + val; + /// Accumulator { acc } + /// }); + /// + /// assert_eq!(lookup[&0].acc, 3 + 6); + /// assert_eq!(lookup[&1].acc, 1 + 4 + 7); + /// assert_eq!(lookup[&2].acc, 2 + 5); + /// assert_eq!(lookup.len(), 3); + /// ``` + pub fn fold_with(self, mut init: FI, mut operation: FO) -> HashMap + where + FI: FnMut(&K) -> R, + FO: FnMut(R, &K, V) -> R, + { + self.aggregate(|acc, key, val| { + let acc = acc.unwrap_or_else(|| init(key)); + Some(operation(acc, key, val)) + }) + } + /// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements /// of each group sequentially, passing the previously accumulated value, a reference to the key /// and the current element as arguments, and stores the results in a new map. @@ -140,15 +184,12 @@ where /// assert_eq!(lookup[&2], 2 + 5); /// assert_eq!(lookup.len(), 3); /// ``` - pub fn fold(self, init: R, mut operation: FO) -> HashMap + pub fn fold(self, init: R, operation: FO) -> HashMap where R: Clone, FO: FnMut(R, &K, V) -> R, { - self.aggregate(|acc, key, val| { - let acc = acc.unwrap_or_else(|| init.clone()); - Some(operation(acc, key, val)) - }) + self.fold_with(|_: &K| init.clone(), operation) } /// Groups elements from the `GroupingMap` source by key and applies `operation` to the elements diff --git a/tests/quick.rs b/tests/quick.rs index f0c8f2d9e..d65e274f3 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -1472,6 +1472,35 @@ quickcheck! { } } + fn correct_grouping_map_by_fold_with_modulo_key(a: Vec, modulo: u8) -> () { + #[derive(Debug, Default, PartialEq)] + struct Accumulator { + acc: u64, + } + + let modulo = if modulo == 0 { 1 } else { modulo } as u64; // Avoid `% 0` + let lookup = a.iter().map(|&b| b as u64) // Avoid overflows + .into_grouping_map_by(|i| i % modulo) + .fold_with(|_key| Default::default(), |Accumulator { acc }, &key, val| { + assert!(val % modulo == key); + let acc = acc + val; + Accumulator { acc } + }); + + let group_map_lookup = a.iter() + .map(|&b| b as u64) + .map(|i| (i % modulo, i)) + .into_group_map() + .into_iter() + .map(|(key, vals)| (key, vals.into_iter().sum())).map(|(key, acc)| (key,Accumulator { acc })) + .collect::>(); + assert_eq!(lookup, group_map_lookup); + + for (&key, &Accumulator { acc: sum }) in lookup.iter() { + assert_eq!(sum, a.iter().map(|&b| b as u64).filter(|&val| val % modulo == key).sum::()); + } + } + fn correct_grouping_map_by_fold_modulo_key(a: Vec, modulo: u8) -> () { let modulo = if modulo == 0 { 1 } else { modulo } as u64; // Avoid `% 0` let lookup = a.iter().map(|&b| b as u64) // Avoid overflows