diff --git a/src/stream/src/executor/top_n/group_top_n.rs b/src/stream/src/executor/top_n/group_top_n.rs
index 204d93a2558ae..61a0de5c0a7f6 100644
--- a/src/stream/src/executor/top_n/group_top_n.rs
+++ b/src/stream/src/executor/top_n/group_top_n.rs
@@ -165,7 +165,7 @@ where
         let mut res_ops = Vec::with_capacity(self.limit);
         let mut res_rows = Vec::with_capacity(self.limit);
         let keys = K::build(&self.group_by, chunk.data_chunk())?;
-        let table_id_str = self.managed_state.state_table.table_id().to_string();
+        let table_id_str = self.managed_state.table().table_id().to_string();
         let actor_id_str = self.ctx.id.to_string();
         let fragment_id_str = self.ctx.fragment_id.to_string();
         for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
@@ -243,11 +243,7 @@ where
     }

     fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc<Bitmap>) {
-        let (_previous_vnode_bitmap, cache_may_stale) = self
-            .managed_state
-            .state_table
-            .update_vnode_bitmap(vnode_bitmap);
-
+        let cache_may_stale = self.managed_state.update_vnode_bitmap(vnode_bitmap);
         if cache_may_stale {
             self.caches.clear();
         }
@@ -258,14 +254,13 @@ where
     }

     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         Ok(())
     }

     async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
         if watermark.col_idx == self.group_by[0] {
             self.managed_state
-                .state_table
                 .update_watermark(watermark.val.clone(), false);
             Some(watermark)
         } else {
diff --git a/src/stream/src/executor/top_n/group_top_n_appendonly.rs b/src/stream/src/executor/top_n/group_top_n_appendonly.rs
index bf8cbdb0a6134..3f185a581ef53 100644
--- a/src/stream/src/executor/top_n/group_top_n_appendonly.rs
+++ b/src/stream/src/executor/top_n/group_top_n_appendonly.rs
@@ -163,7 +163,7 @@ where
         let data_types = self.info().schema.data_types();
         let row_deserializer = RowDeserializer::new(data_types.clone());

-        let table_id_str = self.managed_state.state_table.table_id().to_string();
+        let table_id_str = self.managed_state.table().table_id().to_string();
         let actor_id_str = self.ctx.id.to_string();
         let fragment_id_str = self.ctx.fragment_id.to_string();
         for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
@@ -223,11 +223,7 @@ where
     }

     fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc<Bitmap>) {
-        let (_previous_vnode_bitmap, cache_may_stale) = self
-            .managed_state
-            .state_table
-            .update_vnode_bitmap(vnode_bitmap);
-
+        let cache_may_stale = self.managed_state.update_vnode_bitmap(vnode_bitmap);
         if cache_may_stale {
             self.caches.clear();
         }
@@ -242,14 +238,13 @@ where
     }

     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         Ok(())
     }

     async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
         if watermark.col_idx == self.group_by[0] {
             self.managed_state
-                .state_table
                 .update_watermark(watermark.val.clone(), false);
             Some(watermark)
         } else {
diff --git a/src/stream/src/executor/top_n/top_n_appendonly.rs b/src/stream/src/executor/top_n/top_n_appendonly.rs
index 6392b0ac491fe..4aff28bb30bd2 100644
--- a/src/stream/src/executor/top_n/top_n_appendonly.rs
+++ b/src/stream/src/executor/top_n/top_n_appendonly.rs
@@ -140,7 +140,7 @@ where
     }

     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         self.managed_state
             .init_topn_cache(NO_GROUP_KEY, &mut self.cache)
             .await
diff --git a/src/stream/src/executor/top_n/top_n_cache.rs b/src/stream/src/executor/top_n/top_n_cache.rs
index aed23760c332f..01253469ece34 100644
--- a/src/stream/src/executor/top_n/top_n_cache.rs
+++ b/src/stream/src/executor/top_n/top_n_cache.rs
@@ -28,6 +28,7 @@ use super::{CacheKey, GroupKey, ManagedTopNState};
 use crate::executor::error::StreamExecutorResult;

 const TOPN_CACHE_HIGH_CAPACITY_FACTOR: usize = 2;
+const TOPN_CACHE_MIN_CAPACITY: usize = 10;

 /// Cache for [`ManagedTopNState`].
 ///
@@ -58,6 +59,11 @@ pub struct TopNCache<const WITH_TIES: bool> {
     /// Assumption: `limit != 0`
     pub limit: usize,

+    /// Number of rows corresponding to the current group.
+    /// This is a nice-to-have information. `None` means we don't know the row count,
+    /// but it doesn't prevent us from working correctly.
+    table_row_count: Option<usize>,
+
     /// Data types for the full row.
     ///
     /// For debug formatting only.
@@ -166,9 +172,11 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
             high_capacity: offset
                 .checked_add(limit)
                 .and_then(|v| v.checked_mul(TOPN_CACHE_HIGH_CAPACITY_FACTOR))
-                .unwrap_or(usize::MAX),
+                .unwrap_or(usize::MAX)
+                .max(TOPN_CACHE_MIN_CAPACITY),
             offset,
             limit,
+            table_row_count: None,
             data_types,
         }
     }
@@ -181,6 +189,21 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
         self.high.clear();
     }

+    /// Get total count of entries in the cache.
+    pub fn len(&self) -> usize {
+        self.low.len() + self.middle.len() + self.high.len()
+    }
+
+    pub(super) fn update_table_row_count(&mut self, table_row_count: usize) {
+        self.table_row_count = Some(table_row_count)
+    }
+
+    fn table_row_count_matched(&self) -> bool {
+        self.table_row_count
+            .map(|n| n == self.len())
+            .unwrap_or(false)
+    }
+
     pub fn is_low_cache_full(&self) -> bool {
         assert!(self.low.len() <= self.offset);
         let full = self.low.len() == self.offset;
@@ -219,6 +242,11 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
         self.high.len() >= self.high_capacity
     }

+    fn last_cache_key_before_high(&self) -> Option<&CacheKey> {
+        let middle_last_key = self.middle.last_key_value().map(|(k, _)| k);
+        middle_last_key.or_else(|| self.low.last_key_value().map(|(k, _)| k))
+    }
+
     /// Use this method instead of `self.high.insert` directly when possible.
     ///
     /// It only inserts into high cache if the key is smaller than the largest key in the high
@@ -256,6 +284,10 @@ impl TopNCacheTrait for TopNCache<false> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count += 1;
+        }
+
         if !self.is_low_cache_full() {
             self.low.insert(cache_key, (&row).into());
             return;
@@ -318,6 +350,10 @@ impl TopNCacheTrait for TopNCache<false> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) -> StreamExecutorResult<()> {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count -= 1;
+        }
+
         if self.is_middle_cache_full() && cache_key > *self.middle.last_key_value().unwrap().0 {
             // The row is in high
             self.high.remove(&cache_key);
@@ -325,22 +361,22 @@ impl TopNCacheTrait for TopNCache<false> {
             && (self.offset == 0 || cache_key > *self.low.last_key_value().unwrap().0)
         {
             // The row is in mid
+            self.middle.remove(&cache_key);
+            res_ops.push(Op::Delete);
+            res_rows.push((&row).into());
+
             // Try to fill the high cache if it is empty
-            if self.high.is_empty() {
+            if self.high.is_empty() && !self.table_row_count_matched() {
                 managed_state
                     .fill_high_cache(
                         group_key,
                         self,
-                        Some(self.middle.last_key_value().unwrap().0.clone()),
+                        self.last_cache_key_before_high().cloned(),
                         self.high_capacity,
                     )
                     .await?;
             }

-            self.middle.remove(&cache_key);
-            res_ops.push(Op::Delete);
-            res_rows.push((&row).into());
-
             // Bring one element, if any, from high cache to middle cache
             if !self.high.is_empty() {
                 let high_first = self.high.pop_first().unwrap();
@@ -348,6 +384,8 @@
                 res_rows.push(high_first.1.clone());
                 self.middle.insert(high_first.0, high_first.1);
             }
+
+            assert!(self.high.is_empty() || self.middle.len() == self.limit);
         } else {
             // The row is in low
             self.low.remove(&cache_key);
@@ -360,12 +398,12 @@
             self.low.insert(middle_first.0, middle_first.1);

             // Try to fill the high cache if it is empty
-            if self.high.is_empty() {
+            if self.high.is_empty() && !self.table_row_count_matched() {
                 managed_state
                     .fill_high_cache(
                         group_key,
                         self,
-                        Some(self.middle.last_key_value().unwrap().0.clone()),
+                        self.last_cache_key_before_high().cloned(),
                         self.high_capacity,
                     )
                     .await?;
@@ -393,6 +431,10 @@ impl TopNCacheTrait for TopNCache<true> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count += 1;
+        }
+
         assert!(
             self.low.is_empty(),
             "Offset is not supported yet for WITH TIES, so low cache should be empty"
@@ -482,8 +524,11 @@ impl TopNCacheTrait for TopNCache<true> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) -> StreamExecutorResult<()> {
-        // Since low cache is always empty for WITH_TIES, this unwrap is safe.
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count -= 1;
+        }
+
+        // Since low cache is always empty for WITH_TIES, this unwrap is safe.
         let middle_last = self.middle.last_key_value().unwrap();
         let middle_last_order_by = middle_last.0 .0.clone();
@@ -502,14 +547,12 @@
         }

         // Try to fill the high cache if it is empty
-        if self.high.is_empty() {
+        if self.high.is_empty() && !self.table_row_count_matched() {
             managed_state
                 .fill_high_cache(
                     group_key,
                     self,
-                    self.middle
-                        .last_key_value()
-                        .map(|(key, _value)| key.clone()),
+                    self.last_cache_key_before_high().cloned(),
                     self.high_capacity,
                 )
                 .await?;
diff --git a/src/stream/src/executor/top_n/top_n_plain.rs b/src/stream/src/executor/top_n/top_n_plain.rs
index fe1a078e2d5a1..01fa6722c02df 100644
--- a/src/stream/src/executor/top_n/top_n_plain.rs
+++ b/src/stream/src/executor/top_n/top_n_plain.rs
@@ -174,7 +174,7 @@ where
     }

     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         self.managed_state
             .init_topn_cache(NO_GROUP_KEY, &mut self.cache)
             .await
diff --git a/src/stream/src/executor/top_n/top_n_state.rs b/src/stream/src/executor/top_n/top_n_state.rs
index 6885701e39179..ad51ed19b3fb5 100644
--- a/src/stream/src/executor/top_n/top_n_state.rs
+++ b/src/stream/src/executor/top_n/top_n_state.rs
@@ -13,9 +13,12 @@
 // limitations under the License.

 use std::ops::Bound;
+use std::sync::Arc;

 use futures::{pin_mut, StreamExt};
+use risingwave_common::buffer::Bitmap;
 use risingwave_common::row::{OwnedRow, Row, RowExt};
+use risingwave_common::types::ScalarImpl;
 use risingwave_common::util::epoch::EpochPair;
 use risingwave_storage::store::PrefetchOptions;
 use risingwave_storage::StateStore;
@@ -31,7 +34,7 @@ use crate::executor::error::StreamExecutorResult;
 /// `group_key` is not included.
 pub struct ManagedTopNState<S: StateStore> {
     /// Relational table.
-    pub(crate) state_table: StateTable<S>,
+    state_table: StateTable<S>,

     /// Used for serializing pk into CacheKey.
     cache_key_serde: CacheKeySerde,
@@ -57,6 +60,26 @@ impl<S: StateStore> ManagedTopNState<S> {
         }
     }

+    /// Get the immutable reference of managed state table.
+    pub fn table(&self) -> &StateTable<S> {
+        &self.state_table
+    }
+
+    /// Init epoch for the managed state table.
+    pub fn init_epoch(&mut self, epoch: EpochPair) {
+        self.state_table.init_epoch(epoch)
+    }
+
+    /// Update vnode bitmap of state table, returning `cache_may_stale`.
+    pub fn update_vnode_bitmap(&mut self, new_vnodes: Arc<Bitmap>) -> bool {
+        self.state_table.update_vnode_bitmap(new_vnodes).1
+    }
+
+    /// Update watermark for the managed state table.
+    pub fn update_watermark(&mut self, watermark: ScalarImpl, eager_cleaning: bool) {
+        self.state_table.update_watermark(watermark, eager_cleaning)
+    }
+
     pub fn insert(&mut self, value: impl Row) {
         self.state_table.insert(value);
     }
@@ -121,6 +144,7 @@ impl<S: StateStore> ManagedTopNState<S> {
         cache_size_limit: usize,
     ) -> StreamExecutorResult<()> {
         let cache = &mut topn_cache.high;
+        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
         let state_table_iter = self
             .state_table
@@ -133,7 +157,12 @@
             )
             .await?;
         pin_mut!(state_table_iter);
+
+        let mut group_row_count = 0;
+
         while let Some(item) = state_table_iter.next().await {
+            group_row_count += 1;
+
             // Note(bugen): should first compare with start key before constructing TopNStateRow.
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             if let Some(start_key) = start_key.as_ref()
@@ -141,15 +170,17 @@
             {
                 continue;
             }
-            // let row= &topn_row.row;
             cache.insert(topn_row.cache_key, (&topn_row.row).into());
             if cache.len() == cache_size_limit {
                 break;
             }
         }
+
         if WITH_TIES && topn_cache.is_high_cache_full() {
             let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
+
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == high_last_sort_key {
                     topn_cache
@@ -161,6 +192,11 @@
             }
         }

+        if state_table_iter.next().await.is_none() {
+            // We can only update the row count when we have seen all rows of the group in the table.
+            topn_cache.update_table_row_count(group_row_count);
+        }
+
         Ok(())
     }
@@ -184,8 +220,12 @@ impl<S: StateStore> ManagedTopNState<S> {
             )
             .await?;
         pin_mut!(state_table_iter);
+
+        let mut group_row_count = 0;
+
         if topn_cache.offset > 0 {
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 topn_cache
                     .low
@@ -198,6 +238,7 @@

         assert!(topn_cache.limit > 0, "topn cache limit should always > 0");
         while let Some(item) = state_table_iter.next().await {
+            group_row_count += 1;
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             topn_cache
                 .middle
@@ -209,6 +250,7 @@
         if WITH_TIES && topn_cache.is_middle_cache_full() {
             let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == middle_last_sort_key {
                     topn_cache
@@ -230,6 +272,7 @@
         while !topn_cache.is_high_cache_full()
             && let Some(item) = state_table_iter.next().await
         {
+            group_row_count += 1;
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             topn_cache
                 .high
@@ -238,6 +281,7 @@
         if WITH_TIES && topn_cache.is_high_cache_full() {
             let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == high_last_sort_key {
                     topn_cache
@@ -249,6 +293,12 @@
             }
         }

+        if state_table_iter.next().await.is_none() {
+            // After trying to initially fill in the cache, all table entries are in the cache,
+            // we then get the precise table row count.
+            topn_cache.update_table_row_count(group_row_count);
+        }
+
         Ok(())
     }