diff --git a/src/stream/src/executor/top_n/topn_cache_state.rs b/src/stream/src/executor/top_n/topn_cache_state.rs index 064fc83ce2149..34f3719db44be 100644 --- a/src/stream/src/executor/top_n/topn_cache_state.rs +++ b/src/stream/src/executor/top_n/topn_cache_state.rs @@ -108,18 +108,25 @@ impl TopNCacheState { self.inner.range(range) } - pub fn extract_if(&mut self, pred: F) -> ExtractIf<'_, CacheKey, CompactedRow, F, Global> + pub fn extract_if<'a, F1>( + &'a mut self, + mut pred: F1, + ) -> TopNExtractIf<'a, impl FnMut(&CacheKey, &mut CompactedRow) -> bool> where - F: FnMut(&CacheKey, &mut CompactedRow) -> bool, + F1: 'a + FnMut(&CacheKey, &CompactedRow) -> bool, { - self.inner.extract_if(pred) + let pred_immut = move |key: &CacheKey, value: &mut CompactedRow| pred(key, value); + TopNExtractIf { + inner: self.inner.extract_if(pred_immut), + kv_heap_size: &mut self.kv_heap_size, + } } - pub fn retain(&mut self, f: F) + pub fn retain(&mut self, mut f: F) where - F: FnMut(&CacheKey, &mut CompactedRow) -> bool, + F: FnMut(&CacheKey, &CompactedRow) -> bool, { - self.inner.retain(f) + self.extract_if(|k, v| !f(k, v)).for_each(drop); } } @@ -153,3 +160,28 @@ impl fmt::Debug for TopNCacheState { self.inner.fmt(f) } } + +pub struct TopNExtractIf<'a, F> +where + F: FnMut(&CacheKey, &mut CompactedRow) -> bool, +{ + inner: ExtractIf<'a, CacheKey, CompactedRow, F, Global>, + kv_heap_size: &'a mut KvSize, +} + +impl<'a, F> Iterator for TopNExtractIf<'a, F> +where + F: 'a + FnMut(&CacheKey, &mut CompactedRow) -> bool, +{ + type Item = (CacheKey, CompactedRow); + + fn next(&mut self) -> Option { + self.inner + .next() + .inspect(|(k, v)| self.kv_heap_size.sub(k, v)) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +}