diff --git a/e2e_test/batch/functions/to_char.slt.part b/e2e_test/batch/functions/to_char.slt.part
index b9322a6003143..b4d10ec34d49c 100644
--- a/e2e_test/batch/functions/to_char.slt.part
+++ b/e2e_test/batch/functions/to_char.slt.part
@@ -1,3 +1,6 @@
+statement ok
+SET RW_IMPLICIT_FLUSH TO true;
+
 query T
 SELECT to_char(timestamp '2002-04-20 17:31:12.66', 'HH12:MI:SS')
 ----
@@ -66,3 +69,88 @@ select to_char(tsz, 'YYYY-MM-DD HH24:MI:SS TZH:TZM') from t order by tsz;
 
 statement ok
 drop table t;
+
+
+query T
+select to_char('-20459year -256 days -120hours 866seconds'::interval, 'YYYY IYYY YY IY MM DD PM pm HH HH12 HH24 MI SS');
+----
+-20459 -20460 -59 -60 00 -256 AM am -11 -11 -119 -45 -34
+
+query T
+select to_char('0year -256 days -120hours'::interval, 'YYYY IYYY YY IY MM DD PM pm HH HH12 HH24 MI SS');
+----
+0000 -001 00 -1 00 -256 AM am 012 012 -120 00 00
+
+query T
+select to_char('0year 0 days 0hours'::interval, 'YYYY IYYY YY IY MM DD PM pm HH12 HH24 MI SS');
+----
+0000 -001 00 -1 00 00 AM am 12 00 00 00
+
+query T
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'YYYY IYYY YY IY MM DD PM pm HH12 HH24 MI SS MS US');
+----
+0001 0001 01 01 01 01 AM am 01 01 01 01 000 000000
+
+query T
+select to_char('-1year -1month -1day -1hours -1minute -1second'::interval, 'YYYY IYYY YY IY MM DD PM pm HH12 HH24 MI SS MS US');
+----
+-0001 -0002 -01 -02 -01 -1 AM am -01 -01 -01 -01 000 000000
+
+query T
+select to_char('23:22:57.124562'::interval, 'HH12 MI SS MS US');
+----
+11 22 57 124 124562
+
+query T
+select to_char('-23:22:57.124562'::interval, 'HH12 MI SS MS US');
+----
+-11 -22 -57 -124 -124562
+
+query error
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'IY MM DD AM HH12 MM SS tzhtzm');
+----
+db error: ERROR: Failed to run the query
+
+Caused by these errors (recent errors listed first):
+  1: Expr error
+  2: Invalid parameter pattern: invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.
+
+
+query error
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'IY MM DD AM HH12 MI SS TZH:TZM');
+----
+db error: ERROR: Failed to run the query
+
+Caused by these errors (recent errors listed first):
+  1: Expr error
+  2: Invalid parameter pattern: invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.
+
+
+query error
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'IY MM DD AM HH12 MI SS TZH');
+----
+db error: ERROR: Failed to run the query
+
+Caused by these errors (recent errors listed first):
+  1: Expr error
+  2: Invalid parameter pattern: invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.
+
+
+query error
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'IY MM DD AM HH12 MI SS Month');
+----
+db error: ERROR: Failed to run the query
+
+Caused by these errors (recent errors listed first):
+  1: Expr error
+  2: Invalid parameter pattern: invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.
+
+
+query error
+select to_char('1year 1month 1day 1hours 1minute 1second'::interval, 'IY MM DD AM HH12 MI SS Mon');
+----
+db error: ERROR: Failed to run the query
+
+Caused by these errors (recent errors listed first):
+  1: Expr error
+  2: Invalid parameter pattern: invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.
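
For readers checking the expectations above: the field values follow Postgres-style truncating division on the interval's time part. For example, `-120hours 866seconds` is -431134 seconds in total, which yields HH24/MI/SS fields of -119/-45/-34. A minimal sketch of that arithmetic with plain integers (not RisingWave's `Interval` API):

```rust
fn main() {
    // '-120hours 866seconds' = -120 * 3600 + 866 = -431134 seconds in the
    // time part of the interval.
    let total_secs: i64 = -120 * 3600 + 866;

    // Rust's `/` and `%` truncate toward zero, matching how the HH24/MI/SS
    // fields of an interval are derived from its sub-day part.
    let hours = total_secs / 3600; // -119
    let minutes = (total_secs / 60) % 60; // -45
    let seconds = total_secs % 60; // -34

    assert_eq!((hours, minutes, seconds), (-119, -45, -34));
    println!("{hours} {minutes} {seconds}");
}
```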
diff --git a/src/expr/impl/src/scalar/to_char.rs b/src/expr/impl/src/scalar/to_char.rs
index 4d4edb2d390ba..2fd488d76ae9d 100644
--- a/src/expr/impl/src/scalar/to_char.rs
+++ b/src/expr/impl/src/scalar/to_char.rs
@@ -16,14 +16,24 @@ use std::fmt::{Debug, Write};
 use std::sync::LazyLock;
 
 use aho_corasick::{AhoCorasick, AhoCorasickBuilder};
-use chrono::format::StrftimeItems;
-use risingwave_common::types::{Timestamp, Timestamptz};
+use chrono::format::{Item, StrftimeItems};
+use chrono::{Datelike, NaiveDate};
+use risingwave_common::types::{Interval, Timestamp, Timestamptz};
 use risingwave_expr::{function, ExprError, Result};
 
 use super::timestamptz::time_zone_err;
+use crate::scalar::arithmetic_op::timestamp_interval_add;
 
 type Pattern<'a> = Vec<chrono::format::Item<'a>>;
 
+#[inline(always)]
+fn invalid_pattern_err() -> ExprError {
+    ExprError::InvalidParam {
+        name: "pattern",
+        reason: "invalid format specification for an interval value, HINT: Intervals are not tied to specific calendar dates.".into(),
+    }
+}
+
 self_cell::self_cell! {
     pub struct ChronoPattern {
         owner: String,
@@ -97,10 +107,73 @@ impl ChronoPattern {
                 .expect("failed to build an Aho-Corasick automaton")
         });
 
+        ChronoPattern::compile_inner(tmpl, PATTERNS, &AC)
+    }
+
+    pub fn compile_for_interval(tmpl: &str) -> ChronoPattern {
+        // mapping from pg pattern to chrono pattern
+        // pg pattern: https://www.postgresql.org/docs/current/functions-formatting.html
+        // chrono pattern: https://docs.rs/chrono/latest/chrono/format/strftime/index.html
+        const PATTERNS: &[(&str, &str)] = &[
+            ("HH24", "%H"),
+            ("hh24", "%H"),
+            ("HH12", "%I"),
+            ("hh12", "%I"),
+            ("HH", "%I"),
+            ("hh", "%I"),
+            ("AM", "%p"),
+            ("PM", "%p"),
+            ("am", "%P"),
+            ("pm", "%P"),
+            ("MI", "%M"),
+            ("mi", "%M"),
+            ("SS", "%S"),
+            ("ss", "%S"),
+            ("YYYY", "%Y"),
+            ("yyyy", "%Y"),
+            ("YY", "%y"),
+            ("yy", "%y"),
+            ("IYYY", "%G"),
+            ("iyyy", "%G"),
+            ("IY", "%g"),
+            ("iy", "%g"),
+            ("MM", "%m"),
+            ("mm", "%m"),
+            ("Month", "%B"),
+            ("Mon", "%b"),
+            ("DD", "%d"),
+            ("dd", "%d"),
+            ("US", "%.6f"), /* "%6f" and "%3f" are converted to private data structures in
+                             * chrono, so we use "%.6f" and "%.3f" instead. */
+            ("us", "%.6f"),
+            ("MS", "%.3f"),
+            ("ms", "%.3f"),
+            ("TZH:TZM", "%:z"),
+            ("tzh:tzm", "%:z"),
+            ("TZHTZM", "%z"),
+            ("tzhtzm", "%z"),
+            ("TZH", "%#z"),
+            ("tzh", "%#z"),
+        ];
+        // build an Aho-Corasick automaton for fast matching
+        static AC: LazyLock<AhoCorasick> = LazyLock::new(|| {
+            AhoCorasickBuilder::new()
+                .ascii_case_insensitive(false)
+                .match_kind(aho_corasick::MatchKind::LeftmostLongest)
+                .build(PATTERNS.iter().map(|(k, _)| k))
+                .expect("failed to build an Aho-Corasick automaton")
+        });
+        ChronoPattern::compile_inner(tmpl, PATTERNS, &AC)
+    }
+
+    fn compile_inner(
+        tmpl: &str,
+        patterns: &[(&str, &str)],
+        ac: &LazyLock<AhoCorasick>,
+    ) -> ChronoPattern {
         // replace all pg patterns with chrono patterns
         let mut chrono_tmpl = String::new();
-        AC.replace_all_with(tmpl, &mut chrono_tmpl, |mat, _, dst| {
-            dst.push_str(PATTERNS[mat.pattern()].1);
+        ac.replace_all_with(tmpl, &mut chrono_tmpl, |mat, _, dst| {
+            dst.push_str(patterns[mat.pattern()].1);
             true
         });
         tracing::debug!(tmpl, chrono_tmpl, "compile_pattern_to_chrono");
@@ -138,3 +211,192 @@ fn timestamptz_to_char3(
     write!(writer, "{}", format).unwrap();
     Ok(())
 }
+
+#[function(
+    "to_char(interval, varchar) -> varchar",
+    prebuild = "ChronoPattern::compile_for_interval($1)"
+)]
+fn interval_to_char(
+    interval: Interval,
+    pattern: &ChronoPattern,
+    writer: &mut impl Write,
+) -> Result<()> {
+    for iter in pattern.borrow_dependent() {
+        format_inner(writer, interval, iter)?;
+    }
+    Ok(())
+}
+
+fn adjust_to_iso_year(interval: Interval) -> Result<i32> {
+    let start = risingwave_common::types::Timestamp(
+        NaiveDate::from_ymd_opt(0, 1, 1)
+            .unwrap()
+            .and_hms_opt(0, 0, 0)
+            .unwrap(),
+    );
+    let interval = Interval::from_month_day_usec(interval.months(), interval.days(), 0);
+    let date = timestamp_interval_add(start, interval)?;
+    Ok(date.0.iso_week().year())
+}
+
+fn format_inner(w: &mut impl Write, interval: Interval, item: &Item<'_>) -> Result<()> {
+    match *item {
+        Item::Literal(s) | Item::Space(s) => {
+            w.write_str(s).unwrap();
+            Ok(())
+        }
+        Item::OwnedLiteral(ref s) | Item::OwnedSpace(ref s) => {
+            w.write_str(s).unwrap();
+            Ok(())
+        }
+        Item::Numeric(ref spec, _) => {
+            use chrono::format::Numeric::*;
+            match *spec {
+                Year => {
+                    let year = interval.years_field();
+                    if year < 0 {
+                        write!(w, "{:+05}", year).unwrap();
+                    } else {
+                        write!(w, "{:04}", year).unwrap();
+                    }
+                }
+                YearMod100 => {
+                    let year = interval.years_field();
+                    if year % 100 < 0 {
+                        let year = -((-year) % 100);
+                        write!(w, "{:+03}", year).unwrap();
+                    } else {
+                        let year = year % 100;
+                        write!(w, "{:02}", year).unwrap();
+                    }
+                }
+                IsoYear => {
+                    let iso_year = adjust_to_iso_year(interval)?;
+                    if interval.years_field() < 0 {
+                        write!(w, "{:+05}", iso_year).unwrap();
+                    } else {
+                        write!(w, "{:04}", iso_year).unwrap();
+                    }
+                }
+                IsoYearMod100 => {
+                    let iso_year = adjust_to_iso_year(interval)?;
+                    if interval.years_field() % 100 < 0 {
+                        let iso_year = -((-iso_year) % 100);
+                        write!(w, "{:+03}", iso_year).unwrap();
+                    } else {
+                        let iso_year = iso_year % 100;
+                        write!(w, "{:02}", iso_year).unwrap();
+                    }
+                }
+                Month => {
+                    let month = interval.months_field();
+                    if month < 0 {
+                        write!(w, "{:+03}", month).unwrap();
+                    } else {
+                        write!(w, "{:02}", month).unwrap();
+                    }
+                }
+                Day => {
+                    let day = interval.days_field();
+                    if day < 0 {
+                        write!(w, "{:+02}", day).unwrap();
+                    } else {
+                        write!(w, "{:02}", day).unwrap();
+                    }
+                }
+                Hour => {
+                    let hour = interval.hours_field();
+                    if hour < 0 {
+                        write!(w, "{:+03}", hour).unwrap();
+                    } else {
+                        write!(w, "{:02}", hour).unwrap();
+                    }
+                }
+                Hour12 => {
+                    let hour = interval.hours_field();
+                    if hour < 0 {
+                        // To align with Postgres, we format -0 as 012 here.
+                        let hour = -(-hour) % 12;
+                        if hour == 0 {
+                            w.write_str("012").unwrap();
+                        } else {
+                            write!(w, "{:+03}", hour).unwrap();
+                        }
+                    } else {
+                        let hour = if hour % 12 == 0 { 12 } else { hour % 12 };
+                        write!(w, "{:02}", hour).unwrap();
+                    }
+                }
+                Minute => {
+                    let minute = interval.usecs() / 1_000_000 / 60;
+                    if minute % 60 < 0 {
+                        let minute = -((-minute) % 60);
+                        write!(w, "{:+03}", minute).unwrap();
+                    } else {
+                        let minute = minute % 60;
+                        write!(w, "{:02}", minute).unwrap();
+                    }
+                }
+                Second => {
+                    let second = interval.usecs() / 1_000_000;
+                    if second % 60 < 0 {
+                        let second = -((-second) % 60);
+                        write!(w, "{:+03}", second).unwrap();
+                    } else {
+                        let second = second % 60;
+                        write!(w, "{:02}", second).unwrap();
+                    }
+                }
+                Nanosecond | Ordinal | WeekdayFromMon | NumDaysFromSun | IsoWeek | WeekFromSun
+                | WeekFromMon | IsoYearDiv100 | Timestamp | YearDiv100 | Internal(_) => {
+                    unreachable!()
+                }
+            }
+            Ok(())
+        }
+        Item::Fixed(ref spec) => {
+            use chrono::format::Fixed::*;
+            match *spec {
+                LowerAmPm => {
+                    if interval.hours_field() % 24 >= 12 {
+                        w.write_str("pm").unwrap();
+                    } else {
+                        w.write_str("am").unwrap();
+                    }
+                    Ok(())
+                }
+                UpperAmPm => {
+                    if interval.hours_field() % 24 >= 12 {
+                        w.write_str("PM").unwrap();
+                    } else {
+                        w.write_str("AM").unwrap();
+                    }
+                    Ok(())
+                }
+                Nanosecond3 => {
+                    let usec = interval.usecs() % 1_000_000;
+                    write!(w, "{:03}", usec / 1000).unwrap();
+                    Ok(())
+                }
+                Nanosecond6 => {
+                    let usec = interval.usecs() % 1_000_000;
+                    write!(w, "{:06}", usec).unwrap();
+                    Ok(())
+                }
+                Internal(_) | ShortMonthName | LongMonthName | TimezoneOffset | TimezoneOffsetZ
+                | TimezoneOffsetColon => Err(invalid_pattern_err()),
+                ShortWeekdayName
+                | LongWeekdayName
+                | TimezoneName
+                | TimezoneOffsetDoubleColon
+                | TimezoneOffsetTripleColon
+                | TimezoneOffsetColonZ
+                | Nanosecond
+                | Nanosecond9
+                | RFC2822
+                | RFC3339 => unreachable!(),
+            }
+        }
+        Item::Error => Err(invalid_pattern_err()),
+    }
+}
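
Worth noting why `MatchKind::LeftmostLongest` matters in both automatons above: `HH12` must win over its prefix `HH`. A minimal standalone sketch of the same replacement step, using the aho-corasick 1.x API as in the diff; the pattern table here is a truncated illustration, not the full mapping:

```rust
use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    // A tiny excerpt of the pg -> chrono mapping, for illustration only.
    let patterns = ["HH12", "HH", "MI", "SS"];
    let replacements = ["%I", "%I", "%M", "%S"];

    // Leftmost-longest matching makes "HH12" take precedence over "HH".
    let ac = AhoCorasickBuilder::new()
        .match_kind(MatchKind::LeftmostLongest)
        .build(patterns)
        .expect("failed to build an Aho-Corasick automaton");

    let compiled = ac.replace_all("HH12:MI:SS", &replacements);
    assert_eq!(compiled, "%I:%M:%S");
}
```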
diff --git a/src/stream/src/executor/now.rs b/src/stream/src/executor/now.rs
index 8acf4806d0f6a..621c0e9706c3d 100644
--- a/src/stream/src/executor/now.rs
+++ b/src/stream/src/executor/now.rs
@@ -23,6 +23,7 @@ use risingwave_common::row::{self, OwnedRow};
 use risingwave_common::types::{DataType, Datum};
 use risingwave_storage::StateStore;
 use tokio::sync::mpsc::UnboundedReceiver;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 
 use super::{
     Barrier, BoxedMessageStream, Executor, ExecutorInfo, Message, Mutation, PkIndicesRef,
@@ -55,7 +56,7 @@ impl NowExecutor {
     #[try_stream(ok = Message, error = StreamExecutorError)]
     async fn into_stream(self) {
         let Self {
-            mut barrier_receiver,
+            barrier_receiver,
             mut state_table,
             info,
             ..
@@ -68,45 +69,60 @@ impl NowExecutor {
         // Whether the first barrier is handled and `last_timestamp` is initialized.
         let mut initialized = false;
 
-        while let Some(barrier) = barrier_receiver.recv().await {
-            if !initialized {
-                // Handle the first barrier.
-                state_table.init_epoch(barrier.epoch);
-                let state_row = {
-                    let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Unbounded, Unbounded);
-                    let data_iter = state_table
-                        .iter_with_prefix(row::empty(), sub_range, Default::default())
-                        .await?;
-                    pin_mut!(data_iter);
-                    if let Some(keyed_row) = data_iter.next().await {
-                        Some(keyed_row?)
-                    } else {
-                        None
-                    }
-                };
-                last_timestamp = state_row.and_then(|row| row[0].clone());
-                paused = barrier.is_pause_on_startup();
-                initialized = true;
-            } else if paused {
-                // Assert that no data is updated.
-                state_table.commit_no_data_expected(barrier.epoch);
-            } else {
-                state_table.commit(barrier.epoch).await?;
+        const MAX_MERGE_BARRIER_SIZE: usize = 64;
+
+        #[for_await]
+        for barriers in
+            UnboundedReceiverStream::new(barrier_receiver).ready_chunks(MAX_MERGE_BARRIER_SIZE)
+        {
+            let mut timestamp = None;
+            if barriers.len() > 1 {
+                warn!(
+                    "handle multiple barriers at once in now executor: {}",
+                    barriers.len()
+                );
             }
+            for barrier in barriers {
+                if !initialized {
+                    // Handle the first barrier.
+                    state_table.init_epoch(barrier.epoch);
+                    let state_row = {
+                        let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) =
+                            &(Unbounded, Unbounded);
+                        let data_iter = state_table
+                            .iter_with_prefix(row::empty(), sub_range, Default::default())
+                            .await?;
+                        pin_mut!(data_iter);
+                        if let Some(keyed_row) = data_iter.next().await {
+                            Some(keyed_row?)
+                        } else {
+                            None
+                        }
+                    };
+                    last_timestamp = state_row.and_then(|row| row[0].clone());
+                    paused = barrier.is_pause_on_startup();
+                    initialized = true;
+                } else if paused {
+                    // Assert that no data is updated.
+                    state_table.commit_no_data_expected(barrier.epoch);
+                } else {
+                    state_table.commit(barrier.epoch).await?;
+                }
 
-            // Extract timestamp from the current epoch.
-            let timestamp = Some(barrier.get_curr_epoch().as_scalar());
+                // Extract timestamp from the current epoch.
+                timestamp = Some(barrier.get_curr_epoch().as_scalar());
 
-            // Update paused state.
-            if let Some(mutation) = barrier.mutation.as_deref() {
-                match mutation {
-                    Mutation::Pause => paused = true,
-                    Mutation::Resume => paused = false,
-                    _ => {}
+                // Update paused state.
+                if let Some(mutation) = barrier.mutation.as_deref() {
+                    match mutation {
+                        Mutation::Pause => paused = true,
+                        Mutation::Resume => paused = false,
+                        _ => {}
+                    }
                 }
-            }
 
-            yield Message::Barrier(barrier.clone());
+                yield Message::Barrier(barrier);
+            }
 
             // Do not yield any messages if paused.
             if paused {
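
The core of the now-executor change is draining every barrier already buffered in the channel and handling them as one batch. A minimal sketch of that batching pattern, assuming only `tokio`, `tokio-stream`, and `futures`; the channel payload and names are illustrative:

```rust
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel::<u64>();

    // Simulate a backlog of already-buffered "barriers".
    for epoch in 0..5 {
        tx.send(epoch).unwrap();
    }
    drop(tx); // close the channel so the stream terminates

    // `ready_chunks(64)` yields all items that are already available,
    // up to 64 at a time: one chunk per wakeup instead of one item.
    let mut chunks = UnboundedReceiverStream::new(rx).ready_chunks(64);
    while let Some(batch) = chunks.next().await {
        println!("processing {} barriers in one go: {:?}", batch.len(), batch);
    }
}
```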
diff --git a/src/stream/src/executor/top_n/group_top_n.rs b/src/stream/src/executor/top_n/group_top_n.rs
index 204d93a2558ae..61a0de5c0a7f6 100644
--- a/src/stream/src/executor/top_n/group_top_n.rs
+++ b/src/stream/src/executor/top_n/group_top_n.rs
@@ -165,7 +165,7 @@ where
         let mut res_ops = Vec::with_capacity(self.limit);
         let mut res_rows = Vec::with_capacity(self.limit);
         let keys = K::build(&self.group_by, chunk.data_chunk())?;
-        let table_id_str = self.managed_state.state_table.table_id().to_string();
+        let table_id_str = self.managed_state.table().table_id().to_string();
         let actor_id_str = self.ctx.id.to_string();
         let fragment_id_str = self.ctx.fragment_id.to_string();
         for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
@@ -243,11 +243,7 @@ where
     }
 
     fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc<Bitmap>) {
-        let (_previous_vnode_bitmap, cache_may_stale) = self
-            .managed_state
-            .state_table
-            .update_vnode_bitmap(vnode_bitmap);
-
+        let cache_may_stale = self.managed_state.update_vnode_bitmap(vnode_bitmap);
         if cache_may_stale {
             self.caches.clear();
         }
@@ -258,14 +254,13 @@ where
     }
 
     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         Ok(())
     }
 
     async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
         if watermark.col_idx == self.group_by[0] {
             self.managed_state
-                .state_table
                 .update_watermark(watermark.val.clone(), false);
             Some(watermark)
         } else {
diff --git a/src/stream/src/executor/top_n/group_top_n_appendonly.rs b/src/stream/src/executor/top_n/group_top_n_appendonly.rs
index bf8cbdb0a6134..3f185a581ef53 100644
--- a/src/stream/src/executor/top_n/group_top_n_appendonly.rs
+++ b/src/stream/src/executor/top_n/group_top_n_appendonly.rs
@@ -163,7 +163,7 @@ where
 
         let data_types = self.info().schema.data_types();
         let row_deserializer = RowDeserializer::new(data_types.clone());
-        let table_id_str = self.managed_state.state_table.table_id().to_string();
+        let table_id_str = self.managed_state.table().table_id().to_string();
         let actor_id_str = self.ctx.id.to_string();
         let fragment_id_str = self.ctx.fragment_id.to_string();
         for (r, group_cache_key) in chunk.rows_with_holes().zip_eq_debug(keys.iter()) {
@@ -223,11 +223,7 @@ where
     }
 
     fn update_vnode_bitmap(&mut self, vnode_bitmap: Arc<Bitmap>) {
-        let (_previous_vnode_bitmap, cache_may_stale) = self
-            .managed_state
-            .state_table
-            .update_vnode_bitmap(vnode_bitmap);
-
+        let cache_may_stale = self.managed_state.update_vnode_bitmap(vnode_bitmap);
         if cache_may_stale {
             self.caches.clear();
         }
@@ -242,14 +238,13 @@ where
     }
 
     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         Ok(())
     }
 
     async fn handle_watermark(&mut self, watermark: Watermark) -> Option<Watermark> {
         if watermark.col_idx == self.group_by[0] {
             self.managed_state
-                .state_table
                 .update_watermark(watermark.val.clone(), false);
             Some(watermark)
         } else {
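
The executor diffs above and below are a pure encapsulation refactor: instead of reaching into `managed_state.state_table`, executors call dedicated methods on the managed state (added in `top_n_state.rs` further down). A minimal sketch of the shape of that refactor, with stand-in types rather than RisingWave's real signatures:

```rust
// Stand-in types, for illustration only.
struct StateTable;
struct Bitmap;

impl StateTable {
    fn table_id(&self) -> u32 {
        0
    }
    /// Returns (previous_vnode_bitmap, cache_may_stale), mirroring the tuple
    /// the executors used to destructure inline.
    fn update_vnode_bitmap(&mut self, _new_vnodes: Bitmap) -> (Bitmap, bool) {
        (Bitmap, false)
    }
}

struct ManagedTopNState {
    state_table: StateTable, // no longer `pub(crate)`
}

impl ManagedTopNState {
    /// Read-only access suffices for things like `table_id()`.
    fn table(&self) -> &StateTable {
        &self.state_table
    }

    /// Callers only ever need the `cache_may_stale` flag, so the wrapper
    /// narrows the return type and hides the state table entirely.
    fn update_vnode_bitmap(&mut self, new_vnodes: Bitmap) -> bool {
        self.state_table.update_vnode_bitmap(new_vnodes).1
    }
}

fn main() {
    let mut state = ManagedTopNState { state_table: StateTable };
    let _id = state.table().table_id();
    if state.update_vnode_bitmap(Bitmap) {
        // caches.clear();
    }
}
```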
diff --git a/src/stream/src/executor/top_n/top_n_appendonly.rs b/src/stream/src/executor/top_n/top_n_appendonly.rs
index 6392b0ac491fe..4aff28bb30bd2 100644
--- a/src/stream/src/executor/top_n/top_n_appendonly.rs
+++ b/src/stream/src/executor/top_n/top_n_appendonly.rs
@@ -140,7 +140,7 @@ where
     }
 
     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         self.managed_state
             .init_topn_cache(NO_GROUP_KEY, &mut self.cache)
             .await
diff --git a/src/stream/src/executor/top_n/top_n_cache.rs b/src/stream/src/executor/top_n/top_n_cache.rs
index aed23760c332f..01253469ece34 100644
--- a/src/stream/src/executor/top_n/top_n_cache.rs
+++ b/src/stream/src/executor/top_n/top_n_cache.rs
@@ -28,6 +28,7 @@ use super::{CacheKey, GroupKey, ManagedTopNState};
 use crate::executor::error::StreamExecutorResult;
 
 const TOPN_CACHE_HIGH_CAPACITY_FACTOR: usize = 2;
+const TOPN_CACHE_MIN_CAPACITY: usize = 10;
 
 /// Cache for [`ManagedTopNState`].
 ///
@@ -58,6 +59,11 @@ pub struct TopNCache<const WITH_TIES: bool> {
     /// Assumption: `limit != 0`
     pub limit: usize,
 
+    /// Number of rows corresponding to the current group.
+    /// This is nice-to-have information: `None` means the row count is unknown,
+    /// which doesn't prevent us from working correctly.
+    table_row_count: Option<usize>,
+
     /// Data types for the full row.
     ///
     /// For debug formatting only.
@@ -166,9 +172,11 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
             high_capacity: offset
                 .checked_add(limit)
                 .and_then(|v| v.checked_mul(TOPN_CACHE_HIGH_CAPACITY_FACTOR))
-                .unwrap_or(usize::MAX),
+                .unwrap_or(usize::MAX)
+                .max(TOPN_CACHE_MIN_CAPACITY),
             offset,
             limit,
+            table_row_count: None,
             data_types,
         }
     }
@@ -181,6 +189,21 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
         self.high.clear();
     }
 
+    /// Get the total count of entries in the cache.
+    pub fn len(&self) -> usize {
+        self.low.len() + self.middle.len() + self.high.len()
+    }
+
+    pub(super) fn update_table_row_count(&mut self, table_row_count: usize) {
+        self.table_row_count = Some(table_row_count)
+    }
+
+    fn table_row_count_matched(&self) -> bool {
+        self.table_row_count
+            .map(|n| n == self.len())
+            .unwrap_or(false)
+    }
+
     pub fn is_low_cache_full(&self) -> bool {
         assert!(self.low.len() <= self.offset);
         let full = self.low.len() == self.offset;
@@ -219,6 +242,11 @@ impl<const WITH_TIES: bool> TopNCache<WITH_TIES> {
         self.high.len() >= self.high_capacity
     }
 
+    fn last_cache_key_before_high(&self) -> Option<&CacheKey> {
+        let middle_last_key = self.middle.last_key_value().map(|(k, _)| k);
+        middle_last_key.or_else(|| self.low.last_key_value().map(|(k, _)| k))
+    }
+
     /// Use this method instead of `self.high.insert` directly when possible.
     ///
     /// It only inserts into high cache if the key is smaller than the largest key in the high
@@ -256,6 +284,10 @@ impl TopNCacheTrait for TopNCache<false> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count += 1;
+        }
+
         if !self.is_low_cache_full() {
             self.low.insert(cache_key, (&row).into());
             return;
@@ -318,6 +350,10 @@ impl TopNCacheTrait for TopNCache<false> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) -> StreamExecutorResult<()> {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count -= 1;
+        }
+
         if self.is_middle_cache_full() && cache_key > *self.middle.last_key_value().unwrap().0 {
             // The row is in high
             self.high.remove(&cache_key);
@@ -325,22 +361,22 @@ impl TopNCacheTrait for TopNCache<false> {
             && (self.offset == 0 || cache_key > *self.low.last_key_value().unwrap().0)
         {
             // The row is in mid
+            self.middle.remove(&cache_key);
+            res_ops.push(Op::Delete);
+            res_rows.push((&row).into());
+
             // Try to fill the high cache if it is empty
-            if self.high.is_empty() {
+            if self.high.is_empty() && !self.table_row_count_matched() {
                 managed_state
                     .fill_high_cache(
                         group_key,
                         self,
-                        Some(self.middle.last_key_value().unwrap().0.clone()),
+                        self.last_cache_key_before_high().cloned(),
                         self.high_capacity,
                     )
                     .await?;
             }
 
-            self.middle.remove(&cache_key);
-            res_ops.push(Op::Delete);
-            res_rows.push((&row).into());
-
             // Bring one element, if any, from high cache to middle cache
             if !self.high.is_empty() {
                 let high_first = self.high.pop_first().unwrap();
@@ -348,6 +384,8 @@ impl TopNCacheTrait for TopNCache<false> {
                 res_rows.push(high_first.1.clone());
                 self.middle.insert(high_first.0, high_first.1);
             }
+
+            assert!(self.high.is_empty() || self.middle.len() == self.limit);
         } else {
             // The row is in low
             self.low.remove(&cache_key);
@@ -360,12 +398,12 @@ impl TopNCacheTrait for TopNCache<false> {
             self.low.insert(middle_first.0, middle_first.1);
 
             // Try to fill the high cache if it is empty
-            if self.high.is_empty() {
+            if self.high.is_empty() && !self.table_row_count_matched() {
                 managed_state
                     .fill_high_cache(
                         group_key,
                         self,
-                        Some(self.middle.last_key_value().unwrap().0.clone()),
+                        self.last_cache_key_before_high().cloned(),
                         self.high_capacity,
                     )
                     .await?;
@@ -393,6 +431,10 @@ impl TopNCacheTrait for TopNCache<true> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) {
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count += 1;
+        }
+
         assert!(
             self.low.is_empty(),
             "Offset is not supported yet for WITH TIES, so low cache should be empty"
@@ -482,8 +524,11 @@ impl TopNCacheTrait for TopNCache<true> {
         res_ops: &mut Vec<Op>,
         res_rows: &mut Vec<CompactedRow>,
     ) -> StreamExecutorResult<()> {
-        // Since low cache is always empty for WITH_TIES, this unwrap is safe.
+        if let Some(row_count) = self.table_row_count.as_mut() {
+            *row_count -= 1;
+        }
+
+        // Since low cache is always empty for WITH_TIES, this unwrap is safe.
         let middle_last = self.middle.last_key_value().unwrap();
         let middle_last_order_by = middle_last.0 .0.clone();
@@ -502,14 +547,12 @@ impl TopNCacheTrait for TopNCache<true> {
         }
 
         // Try to fill the high cache if it is empty
-        if self.high.is_empty() {
+        if self.high.is_empty() && !self.table_row_count_matched() {
             managed_state
                 .fill_high_cache(
                     group_key,
                     self,
-                    self.middle
-                        .last_key_value()
-                        .map(|(key, _value)| key.clone()),
+                    self.last_cache_key_before_high().cloned(),
                     self.high_capacity,
                 )
                 .await?;
diff --git a/src/stream/src/executor/top_n/top_n_plain.rs b/src/stream/src/executor/top_n/top_n_plain.rs
index fe1a078e2d5a1..01fa6722c02df 100644
--- a/src/stream/src/executor/top_n/top_n_plain.rs
+++ b/src/stream/src/executor/top_n/top_n_plain.rs
@@ -174,7 +174,7 @@ where
     }
 
     async fn init(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
-        self.managed_state.state_table.init_epoch(epoch);
+        self.managed_state.init_epoch(epoch);
         self.managed_state
             .init_topn_cache(NO_GROUP_KEY, &mut self.cache)
             .await
diff --git a/src/stream/src/executor/top_n/top_n_state.rs b/src/stream/src/executor/top_n/top_n_state.rs
index 6885701e39179..ad51ed19b3fb5 100644
--- a/src/stream/src/executor/top_n/top_n_state.rs
+++ b/src/stream/src/executor/top_n/top_n_state.rs
@@ -13,9 +13,12 @@
 // limitations under the License.
 
 use std::ops::Bound;
+use std::sync::Arc;
 
 use futures::{pin_mut, StreamExt};
+use risingwave_common::buffer::Bitmap;
 use risingwave_common::row::{OwnedRow, Row, RowExt};
+use risingwave_common::types::ScalarImpl;
 use risingwave_common::util::epoch::EpochPair;
 use risingwave_storage::store::PrefetchOptions;
 use risingwave_storage::StateStore;
@@ -31,7 +34,7 @@ use crate::executor::error::StreamExecutorResult;
 /// `group_key` is not included.
 pub struct ManagedTopNState<S: StateStore> {
     /// Relational table.
-    pub(crate) state_table: StateTable<S>,
+    state_table: StateTable<S>,
 
     /// Used for serializing pk into CacheKey.
     cache_key_serde: CacheKeySerde,
@@ -57,6 +60,26 @@ impl<S: StateStore> ManagedTopNState<S> {
         }
     }
 
+    /// Get an immutable reference to the managed state table.
+    pub fn table(&self) -> &StateTable<S> {
+        &self.state_table
+    }
+
+    /// Init epoch for the managed state table.
+    pub fn init_epoch(&mut self, epoch: EpochPair) {
+        self.state_table.init_epoch(epoch)
+    }
+
+    /// Update the vnode bitmap of the state table, returning `cache_may_stale`.
+    pub fn update_vnode_bitmap(&mut self, new_vnodes: Arc<Bitmap>) -> bool {
+        self.state_table.update_vnode_bitmap(new_vnodes).1
+    }
+
+    /// Update the watermark for the managed state table.
+    pub fn update_watermark(&mut self, watermark: ScalarImpl, eager_cleaning: bool) {
+        self.state_table.update_watermark(watermark, eager_cleaning)
+    }
+
     pub fn insert(&mut self, value: impl Row) {
         self.state_table.insert(value);
     }
@@ -121,6 +144,7 @@ impl<S: StateStore> ManagedTopNState<S> {
         cache_size_limit: usize,
     ) -> StreamExecutorResult<()> {
         let cache = &mut topn_cache.high;
+
         let sub_range: &(Bound<OwnedRow>, Bound<OwnedRow>) = &(Bound::Unbounded, Bound::Unbounded);
         let state_table_iter = self
             .state_table
@@ -133,7 +157,12 @@ impl<S: StateStore> ManagedTopNState<S> {
             )
             .await?;
         pin_mut!(state_table_iter);
+
+        let mut group_row_count = 0;
+
         while let Some(item) = state_table_iter.next().await {
+            group_row_count += 1;
+
             // Note(bugen): should first compare with start key before constructing TopNStateRow.
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             if let Some(start_key) = start_key.as_ref()
                 && &topn_row.cache_key <= start_key
             {
                 continue;
             }
-            // let row= &topn_row.row;
             cache.insert(topn_row.cache_key, (&topn_row.row).into());
             if cache.len() == cache_size_limit {
                 break;
             }
         }
+
         if WITH_TIES && topn_cache.is_high_cache_full() {
             let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
+
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == high_last_sort_key {
                     topn_cache
@@ -161,6 +192,11 @@ impl<S: StateStore> ManagedTopNState<S> {
             }
         }
 
+        if state_table_iter.next().await.is_none() {
+            // We can only update the row count when we have seen all rows of the group in the table.
+            topn_cache.update_table_row_count(group_row_count);
+        }
+
         Ok(())
     }
 
@@ -184,8 +220,12 @@ impl<S: StateStore> ManagedTopNState<S> {
             )
             .await?;
         pin_mut!(state_table_iter);
+
+        let mut group_row_count = 0;
+
         if topn_cache.offset > 0 {
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 topn_cache
                     .low
@@ -198,6 +238,7 @@ impl<S: StateStore> ManagedTopNState<S> {
 
         assert!(topn_cache.limit > 0, "topn cache limit should always > 0");
         while let Some(item) = state_table_iter.next().await {
+            group_row_count += 1;
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             topn_cache
                 .middle
@@ -209,6 +250,7 @@ impl<S: StateStore> ManagedTopNState<S> {
         if WITH_TIES && topn_cache.is_middle_cache_full() {
             let middle_last_sort_key = topn_cache.middle.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == middle_last_sort_key {
                     topn_cache
@@ -230,6 +272,7 @@ impl<S: StateStore> ManagedTopNState<S> {
         while !topn_cache.is_high_cache_full()
             && let Some(item) = state_table_iter.next().await
         {
+            group_row_count += 1;
             let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
             topn_cache
                 .high
@@ -238,6 +281,7 @@ impl<S: StateStore> ManagedTopNState<S> {
         if WITH_TIES && topn_cache.is_high_cache_full() {
             let high_last_sort_key = topn_cache.high.last_key_value().unwrap().0 .0.clone();
             while let Some(item) = state_table_iter.next().await {
+                group_row_count += 1;
                 let topn_row = self.get_topn_row(item?.into_owned_row(), group_key.len());
                 if topn_row.cache_key.0 == high_last_sort_key {
                     topn_cache
@@ -249,6 +293,12 @@ impl<S: StateStore> ManagedTopNState<S> {
             }
         }
 
+        if state_table_iter.next().await.is_none() {
+            // If the iterator is exhausted after this initial fill, all table entries of the
+            // group are in the cache, so we know the precise table row count.
+            topn_cache.update_table_row_count(group_row_count);
+        }
+
         Ok(())
     }
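
Taken together, the cache-fill paths count rows while scanning and trust the count only when the storage iterator is exhausted; later, the delete paths may skip the high-cache refill scan when that count proves the whole group is already cached. A minimal model of both halves, with illustrative types rather than the real `TopNCache`/`ManagedTopNState`:

```rust
/// Minimal model of the row-count optimization above (illustrative only).
struct Cache {
    rows: Vec<i32>,
    /// `None` means the total row count of the group is unknown.
    table_row_count: Option<usize>,
}

impl Cache {
    /// Fill from "storage", counting rows; the count is precise only if the
    /// iterator is exhausted, i.e. we have seen every row of the group.
    fn fill(limit: usize, storage: &[i32]) -> Self {
        let mut iter = storage.iter();
        let mut rows = Vec::new();
        let mut group_row_count = 0;
        for &row in iter.by_ref() {
            group_row_count += 1;
            rows.push(row);
            if rows.len() == limit {
                break;
            }
        }
        let table_row_count = iter.next().is_none().then_some(group_row_count);
        Cache { rows, table_row_count }
    }

    /// A storage scan cannot find anything new if the cache is known to hold
    /// every row of the group; this mirrors `table_row_count_matched`.
    fn needs_refill_scan(&self) -> bool {
        !self
            .table_row_count
            .map(|n| n == self.rows.len())
            .unwrap_or(false)
    }
}

fn main() {
    // All 3 rows fit: precise count, refill scans can be skipped.
    let c = Cache::fill(10, &[1, 2, 3]);
    assert_eq!(c.table_row_count, Some(3));
    assert!(!c.needs_refill_scan());

    // Stopped early: count unknown, must conservatively scan on refill.
    let c = Cache::fill(2, &[1, 2, 3]);
    assert_eq!(c.table_row_count, None);
    assert!(c.needs_refill_scan());
}
```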