diff --git a/src/common/estimate_size/src/collections/btreemap.rs b/src/common/estimate_size/src/collections/btreemap.rs
index f48a78715f692..af9ab3471acec 100644
--- a/src/common/estimate_size/src/collections/btreemap.rs
+++ b/src/common/estimate_size/src/collections/btreemap.rs
@@ -44,7 +44,7 @@ impl<K, V> EstimatedBTreeMap<K, V> {
         self.inner.is_empty()
     }
 
-    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
+    pub fn iter(&self) -> impl DoubleEndedIterator<Item = (&K, &V)> {
         self.inner.iter()
     }
 
diff --git a/src/expr/impl/src/aggregate/approx_percentile.rs b/src/expr/impl/src/aggregate/approx_percentile.rs
index 1c5df73bc5c0f..33e2a9969cdc9 100644
--- a/src/expr/impl/src/aggregate/approx_percentile.rs
+++ b/src/expr/impl/src/aggregate/approx_percentile.rs
@@ -91,9 +91,15 @@ impl ApproxPercentile {
         } else if non_neg {
             let count = state.pos_buckets.entry(bucket_id).or_insert(0);
             *count -= 1;
+            if *count == 0 {
+                state.pos_buckets.remove(&bucket_id);
+            }
         } else {
             let count = state.neg_buckets.entry(bucket_id).or_insert(0);
             *count -= 1;
+            if *count == 0 {
+                state.neg_buckets.remove(&bucket_id);
+            }
         }
         state.count -= 1;
     }
diff --git a/src/stream/src/executor/approx_percentile/global.rs b/src/stream/src/executor/approx_percentile/global.rs
index 9434ccf05d5a8..2ccff36c47390 100644
--- a/src/stream/src/executor/approx_percentile/global.rs
+++ b/src/stream/src/executor/approx_percentile/global.rs
@@ -12,13 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use core::ops::Bound;
-
-use risingwave_common::array::Op;
-use risingwave_common::row::RowExt;
-use risingwave_common::types::ToOwnedDatum;
-use risingwave_storage::store::PrefetchOptions;
-
+use super::global_state::GlobalApproxPercentileState;
 use crate::executor::prelude::*;
 
 pub struct GlobalApproxPercentileExecutor<S: StateStore> {
@@ -27,10 +21,7 @@ pub struct GlobalApproxPercentileExecutor<S: StateStore> {
     pub quantile: f64,
     pub base: f64,
     pub chunk_size: usize,
-    /// Used for the approx percentile buckets.
-    pub bucket_state_table: StateTable<S>,
-    /// Used for the approx percentile count.
-    pub count_state_table: StateTable<S>,
+    pub state: GlobalApproxPercentileState<S>,
 }
 
 impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
@@ -43,186 +34,45 @@ impl<S: StateStore> GlobalApproxPercentileExecutor<S> {
         bucket_state_table: StateTable<S>,
         count_state_table: StateTable<S>,
     ) -> Self {
+        let global_state =
+            GlobalApproxPercentileState::new(quantile, base, bucket_state_table, count_state_table);
         Self {
             _ctx,
             input,
             quantile,
             base,
             chunk_size,
-            bucket_state_table,
-            count_state_table,
+            state: global_state,
         }
     }
 
     /// TODO(kwannoel): Include cache later.
     #[try_stream(ok = Message, error = StreamExecutorError)]
     async fn execute_inner(self) {
-        let mut bucket_state_table = self.bucket_state_table;
-        let mut count_state_table = self.count_state_table;
+        // Initialize state
        let mut input_stream = self.input.execute();
-
-        // Initialize state tables.
         let first_barrier = expect_first_barrier(&mut input_stream).await?;
-        bucket_state_table.init_epoch(first_barrier.epoch);
-        count_state_table.init_epoch(first_barrier.epoch);
+        let mut state = self.state;
+        state.init(first_barrier.epoch).await?;
         yield Message::Barrier(first_barrier);
 
-        // Get row count state, and row_count.
-        let mut row_count_state = count_state_table.get_row(&[Datum::None; 0]).await?;
-        let mut row_count = if let Some(row) = row_count_state.as_ref() {
-            row.datum_at(0).unwrap().into_int64()
-        } else {
-            0
-        };
-
-        // Get prev output, based on the current state.
-        let mut prev_output = Self::get_output(
-            &bucket_state_table,
-            row_count as u64,
-            self.quantile,
-            self.base,
-        )
-        .await?;
-
         #[for_await]
         for message in input_stream {
             match message? {
                 Message::Chunk(chunk) => {
-                    for (_, row) in chunk.rows() {
-                        // Decoding
-                        let sign_datum = row.datum_at(0);
-                        let bucket_id_datum = row.datum_at(1);
-                        let delta_datum = row.datum_at(2);
-                        let delta: i32 = delta_datum.unwrap().into_int32();
-
-                        // Updates
-                        row_count = row_count.checked_add(delta as i64).unwrap();
-
-                        let pk = row.project(&[0, 1]);
-                        let old_row = bucket_state_table.get_row(pk).await?;
-                        let old_bucket_row_count: i64 = if let Some(row) = old_row.as_ref() {
-                            row.datum_at(2).unwrap().into_int64()
-                        } else {
-                            0
-                        };
-
-                        let new_value = old_bucket_row_count.checked_add(delta as i64).unwrap();
-                        let new_value_datum = Datum::from(ScalarImpl::Int64(new_value));
-                        let new_row = &[
-                            sign_datum.to_owned_datum(),
-                            bucket_id_datum.map(|d| d.into()),
-                            new_value_datum,
-                        ];
-
-                        if old_row.is_none() {
-                            bucket_state_table.insert(new_row);
-                        } else {
-                            bucket_state_table.update(old_row, new_row);
-                        }
-                    }
+                    state.apply_chunk(chunk)?;
                 }
                 Message::Barrier(barrier) => {
-                    // We maintain an invariant, iff row_count_state is none,
-                    // we haven't pushed any data to downstream.
-                    // Naturally, if row_count_state is some,
-                    // we have pushed data to downstream.
-                    let new_output = Self::get_output(
-                        &bucket_state_table,
-                        row_count as u64,
-                        self.quantile,
-                        self.base,
-                    )
-                    .await?;
-                    let percentile_chunk = if row_count_state.is_none() {
-                        StreamChunk::from_rows(
-                            &[(Op::Insert, &[new_output.clone()])],
-                            &[DataType::Float64],
-                        )
-                    } else {
-                        StreamChunk::from_rows(
-                            &[
-                                (Op::UpdateDelete, &[prev_output.clone()]),
-                                (Op::UpdateInsert, &[new_output.clone()]),
-                            ],
-                            &[DataType::Float64],
-                        )
-                    };
-                    prev_output = new_output;
-                    yield Message::Chunk(percentile_chunk);
-
-                    let new_row_count_state = &[Datum::from(ScalarImpl::Int64(row_count))];
-                    if let Some(row_count_state) = row_count_state {
-                        count_state_table.update(row_count_state, new_row_count_state);
-                    } else {
-                        count_state_table.insert(new_row_count_state);
-                    }
-                    row_count_state = Some(new_row_count_state.into_owned_row());
-                    count_state_table.commit(barrier.epoch).await?;
-
-                    bucket_state_table.commit(barrier.epoch).await?;
-
+                    let output = state.get_output();
+                    yield Message::Chunk(output);
+                    state.commit(barrier.epoch).await?;
                     yield Message::Barrier(barrier);
                 }
                 Message::Watermark(_) => {}
             }
         }
    }
-
-    /// We have these scenarios to consider, based on row count state.
-    /// 1. We have no row count state, this means it's the bootstrap init for this executor.
-    ///    Output NULL as an INSERT. Persist row count state=0.
-    /// 2. We have row count state.
-    ///    Output UPDATE (`old_state`, `new_state`) to downstream.
-    async fn get_output(
-        bucket_state_table: &StateTable<S>,
-        row_count: u64,
-        quantile: f64,
-        base: f64,
-    ) -> StreamExecutorResult<Datum> {
-        let quantile_count = (row_count as f64 * quantile).floor() as u64;
-        let mut acc_count = 0;
-        let neg_bounds: (Bound<OwnedRow>, Bound<OwnedRow>) = (
-            Bound::Unbounded,
-            Bound::Excluded([Datum::from(ScalarImpl::Int16(0))].to_owned_row()),
-        );
-        let non_neg_bounds: (Bound<OwnedRow>, Bound<OwnedRow>) = (
-            Bound::Included([Datum::from(ScalarImpl::Int16(0))].to_owned_row()),
-            Bound::Unbounded,
-        );
-        // Just iterate over the singleton vnode.
-        // TODO(kwannoel): Should we just use separate state tables for
-        // positive and negative counts?
-        // Reverse iterator is not as efficient.
-        #[for_await]
-        for keyed_row in bucket_state_table
-            .rev_iter_with_prefix(&[Datum::None; 0], &neg_bounds, PrefetchOptions::default())
-            .await?
-            .chain(
-                bucket_state_table
-                    .iter_with_prefix(
-                        &[Datum::None; 0],
-                        &non_neg_bounds,
-                        PrefetchOptions::default(),
-                    )
-                    .await?,
-            )
-        {
-            let row = keyed_row?.into_owned_row();
-            let count = row.datum_at(2).unwrap().into_int64();
-            acc_count += count as u64;
-            if acc_count > quantile_count {
-                let sign = row.datum_at(0).unwrap().into_int16();
-                if sign == 0 {
-                    return Ok(Datum::from(ScalarImpl::Float64(0.0.into())));
-                }
-                let bucket_id = row.datum_at(1).unwrap().into_int32();
-                let percentile_value = sign as f64 * 2.0 * base.powi(bucket_id) / (base + 1.0);
-                let percentile_datum = Datum::from(ScalarImpl::Float64(percentile_value.into()));
-                return Ok(percentile_datum);
-            }
-        }
-        Ok(Datum::None)
-    }
 }
 
 impl<S: StateStore> Execute for GlobalApproxPercentileExecutor<S> {
diff --git a/src/stream/src/executor/approx_percentile/global_state.rs b/src/stream/src/executor/approx_percentile/global_state.rs
new file mode 100644
index 0000000000000..790d89699e781
--- /dev/null
+++ b/src/stream/src/executor/approx_percentile/global_state.rs
@@ -0,0 +1,328 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::{BTreeMap, Bound};
+use std::mem;
+
+use risingwave_common::array::Op;
+use risingwave_common::bail;
+use risingwave_common::row::Row;
+use risingwave_common::types::{Datum, ToOwnedDatum};
+use risingwave_common::util::epoch::EpochPair;
+use risingwave_storage::store::PrefetchOptions;
+use risingwave_storage::StateStore;
+
+use crate::executor::prelude::*;
+use crate::executor::StreamExecutorResult;
+
+/// The global approx percentile state.
+pub struct GlobalApproxPercentileState<S: StateStore> {
+    quantile: f64,
+    base: f64,
+    row_count: i64,
+    bucket_state_table: StateTable<S>,
+    count_state_table: StateTable<S>,
+    cache: BucketTableCache,
+    last_output: Option<Datum>,
+    output_changed: bool,
+}
+
+// Initialization
+impl<S: StateStore> GlobalApproxPercentileState<S> {
+    pub fn new(
+        quantile: f64,
+        base: f64,
+        bucket_state_table: StateTable<S>,
+        count_state_table: StateTable<S>,
+    ) -> Self {
+        Self {
+            quantile,
+            base,
+            row_count: 0,
+            bucket_state_table,
+            count_state_table,
+            cache: BucketTableCache::new(),
+            last_output: None,
+            output_changed: false,
+        }
+    }
+
+    pub async fn init(&mut self, init_epoch: EpochPair) -> StreamExecutorResult<()> {
+        // Init state tables.
+        self.count_state_table.init_epoch(init_epoch);
+        self.bucket_state_table.init_epoch(init_epoch);
+
+        // Refill row_count
+        let row_count_state = self.get_row_count_state().await?;
+        let row_count = Self::decode_row_count(&row_count_state)?;
+        self.row_count = row_count;
+        tracing::debug!(?row_count, "recovered row_count");
+
+        // Refill cache
+        self.refill_cache().await?;
+
+        // Recover the last output emitted downstream, so the next update can
+        // retract it.
+        let last_output = if row_count_state.is_none() {
+            None
+        } else {
+            Some(self.cache.get_output(row_count, self.quantile, self.base))
+        };
+        tracing::debug!(?last_output, "recovered last_output");
+        self.last_output = last_output;
+        Ok(())
+    }
+
+    async fn refill_cache(&mut self) -> StreamExecutorResult<()> {
+        let bounds: (Bound<OwnedRow>, Bound<OwnedRow>) = (Bound::Unbounded, Bound::Unbounded);
+        #[for_await]
+        for keyed_row in self
+            .bucket_state_table
+            .iter_with_prefix(&[Datum::None; 0], &bounds, PrefetchOptions::default())
+            .await?
+        {
+            let row = keyed_row?.into_owned_row();
+            let sign = row.datum_at(0).unwrap().into_int16();
+            let bucket_id = row.datum_at(1).unwrap().into_int32();
+            let count = row.datum_at(2).unwrap().into_int64();
+            match sign {
+                -1 => {
+                    self.cache.neg_buckets.insert(bucket_id, count);
+                }
+                0 => {
+                    self.cache.zeros = count;
+                }
+                1 => {
+                    self.cache.pos_buckets.insert(bucket_id, count);
+                }
+                _ => {
+                    bail!("Invalid sign: {}", sign);
+                }
+            }
+        }
+        Ok(())
+    }
+
+    async fn get_row_count_state(&self) -> StreamExecutorResult<Option<OwnedRow>> {
+        self.count_state_table.get_row(&[Datum::None; 0]).await
+    }
+
+    fn decode_row_count(row_count_state: &Option<OwnedRow>) -> StreamExecutorResult<i64> {
+        if let Some(row) = row_count_state.as_ref() {
+            let Some(datum) = row.datum_at(0) else {
+                bail!("Invalid row count state: {:?}", row)
+            };
+            Ok(datum.into_int64())
+        } else {
+            Ok(0)
+        }
+    }
+}
+
+// Update
+impl<S: StateStore> GlobalApproxPercentileState<S> {
+    pub fn apply_chunk(&mut self, chunk: StreamChunk) -> StreamExecutorResult<()> {
+        // Op is ignored here, because we only check the `delta` column inside the row.
+        // The sign of the `delta` column will tell us if we need to decrease or increase the
+        // count of the bucket.
+        for (_op, row) in chunk.rows() {
+            debug_assert_eq!(_op, Op::Insert);
+            self.apply_row(row)?;
+        }
+        Ok(())
+    }
+
+    pub fn apply_row(&mut self, row: impl Row) -> StreamExecutorResult<()> {
+        // Decoding
+        let sign_datum = row.datum_at(0);
+        let sign = sign_datum.unwrap().into_int16();
+        let sign_datum = sign_datum.to_owned_datum();
+        let bucket_id_datum = row.datum_at(1);
+        let bucket_id = bucket_id_datum.unwrap().into_int32();
+        let bucket_id_datum = bucket_id_datum.to_owned_datum();
+        let delta_datum = row.datum_at(2);
+        let delta: i32 = delta_datum.unwrap().into_int32();
+
+        if delta == 0 {
+            return Ok(());
+        }
+
+        self.output_changed = true;
+
+        // Updates
+        self.row_count = self.row_count.checked_add(delta as i64).unwrap();
+        tracing::debug!("updated row_count: {}", self.row_count);
+
+        let (is_new_entry, old_count, new_count) = match sign {
+            -1 => {
+                let count_entry = self.cache.neg_buckets.get(&bucket_id).copied();
+                let old_count = count_entry.unwrap_or(0);
+                let new_count = old_count.checked_add(delta as i64).unwrap();
+                let is_new_entry = count_entry.is_none();
+                if new_count != 0 {
+                    self.cache.neg_buckets.insert(bucket_id, new_count);
+                } else {
+                    self.cache.neg_buckets.remove(&bucket_id);
+                }
+                (is_new_entry, old_count, new_count)
+            }
+            0 => {
+                let old_count = self.cache.zeros;
+                let new_count = old_count.checked_add(delta as i64).unwrap();
+                let is_new_entry = old_count == 0;
+                // Write back unconditionally: if the zero bucket is fully
+                // retracted, the cached count must be reset to 0 as well,
+                // otherwise a stale value would linger here.
+                self.cache.zeros = new_count;
+                (is_new_entry, old_count, new_count)
+            }
+            1 => {
+                let count_entry = self.cache.pos_buckets.get(&bucket_id).copied();
+                let old_count = count_entry.unwrap_or(0);
+                let new_count = old_count.checked_add(delta as i64).unwrap();
+                let is_new_entry = count_entry.is_none();
+                if new_count != 0 {
+                    self.cache.pos_buckets.insert(bucket_id, new_count);
+                } else {
+                    self.cache.pos_buckets.remove(&bucket_id);
+                }
+                (is_new_entry, old_count, new_count)
+            }
+            _ => bail!("Invalid sign: {}", sign),
+        };
+
+        let old_row = &[
+            sign_datum.clone(),
+            bucket_id_datum.clone(),
+            Datum::from(ScalarImpl::Int64(old_count)),
+        ];
+        if new_count == 0 && !is_new_entry {
+            self.bucket_state_table.delete(old_row);
+        } else if new_count > 0 {
+            let new_row = &[
+                sign_datum,
+                bucket_id_datum,
+                Datum::from(ScalarImpl::Int64(new_count)),
+            ];
+            if is_new_entry {
+                self.bucket_state_table.insert(new_row);
+            } else {
+                self.bucket_state_table.update(old_row, new_row);
+            }
+        } else {
+            // Reachable only when more rows were retracted from a bucket than
+            // were ever inserted into it.
+            bail!("invalid state: negative count {} for bucket {}", new_count, bucket_id)
+        }
+
+        Ok(())
+    }
+
+    pub async fn commit(&mut self, epoch: EpochPair) -> StreamExecutorResult<()> {
+        // Commit row count state.
+        let row_count_datum = Datum::from(ScalarImpl::Int64(self.row_count));
+        let row_count_row = &[row_count_datum];
+        let last_row_count_state = self.count_state_table.get_row(&[Datum::None; 0]).await?;
+        match last_row_count_state {
+            None => self.count_state_table.insert(row_count_row),
+            Some(last_row_count_state) => self
+                .count_state_table
+                .update(last_row_count_state, row_count_row),
+        }
+        self.count_state_table.commit(epoch).await?;
+        self.bucket_state_table.commit(epoch).await?;
+        Ok(())
+    }
+}
+
+// Read
+impl<S: StateStore> GlobalApproxPercentileState<S> {
+    pub fn get_output(&mut self) -> StreamChunk {
+        let last_output = mem::take(&mut self.last_output);
+        let new_output = if !self.output_changed {
+            tracing::debug!("last_output: {:#?}", last_output);
+            last_output.clone().flatten()
+        } else {
+            self.cache
+                .get_output(self.row_count, self.quantile, self.base)
+        };
+        self.last_output = Some(new_output.clone());
+        let output_chunk = match last_output {
+            None => StreamChunk::from_rows(&[(Op::Insert, &[new_output])], &[DataType::Float64]),
+            // Output unchanged: re-assert the current value downstream.
+            Some(last_output) if !self.output_changed => StreamChunk::from_rows(
+                &[
+                    (Op::UpdateDelete, &[last_output.clone()]),
+                    (Op::UpdateInsert, &[last_output]),
+                ],
+                &[DataType::Float64],
+            ),
+            Some(last_output) => StreamChunk::from_rows(
+                &[
+                    (Op::UpdateDelete, &[last_output.clone()]),
+                    (Op::UpdateInsert, &[new_output.clone()]),
+                ],
+                &[DataType::Float64],
+            ),
+        };
+        tracing::debug!("get_output: {:#?}", output_chunk);
+        self.output_changed = false;
+        output_chunk
+    }
+}
+
+type Count = i64;
+type BucketId = i32;
+
+type BucketMap = BTreeMap<BucketId, Count>;
+
+/// Keeps the entire bucket state table contents in-memory.
+struct BucketTableCache {
+    neg_buckets: BucketMap,
+    zeros: Count, // If the count is 0, this bucket has never been inserted into.
+    pos_buckets: BucketMap,
+}
+
+impl BucketTableCache {
+    pub fn new() -> Self {
+        Self {
+            neg_buckets: BucketMap::new(),
+            zeros: 0,
+            pos_buckets: BucketMap::new(),
+        }
+    }
+
+    pub fn get_output(&self, row_count: i64, quantile: f64, base: f64) -> Datum {
+        let quantile_count = (row_count as f64 * quantile).floor() as i64;
+        let mut acc_count = 0;
+        for (bucket_id, count) in self.neg_buckets.iter().rev() {
+            acc_count += count;
+            if acc_count > quantile_count {
+                // approx value = -2 * y^i / (y + 1)
+                let approx_percentile = -2.0 * base.powi(*bucket_id) / (base + 1.0);
+                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
+                return Datum::from(approx_percentile);
+            }
+        }
+        acc_count += self.zeros;
+        if acc_count > quantile_count {
+            return Datum::from(ScalarImpl::Float64(0.0.into()));
+        }
+        for (bucket_id, count) in &self.pos_buckets {
+            acc_count += count;
+            if acc_count > quantile_count {
+                // approx value = 2 * y^i / (y + 1); see the worked sketch after this diff.
+                let approx_percentile = 2.0 * base.powi(*bucket_id) / (base + 1.0);
+                let approx_percentile = ScalarImpl::Float64(approx_percentile.into());
+                return Datum::from(approx_percentile);
+            }
+        }
+        Datum::None
+    }
+}
diff --git a/src/stream/src/executor/approx_percentile/mod.rs b/src/stream/src/executor/approx_percentile/mod.rs
index 29910d9032e15..8d2c5bdcf4544 100644
--- a/src/stream/src/executor/approx_percentile/mod.rs
+++ b/src/stream/src/executor/approx_percentile/mod.rs
@@ -13,4 +13,5 @@
 // limitations under the License.
 
 pub mod global;
+mod global_state;
 pub mod local;
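
For reference, a minimal standalone sketch of the bucket math `BucketTableCache::get_output` relies on. Only the reconstruction formula (`approx value = ±2 * base^i / (base + 1)`) appears in this diff; the forward mapping `i = ceil(log_base(|v|))`, the helper names, and the numbers below are illustrative assumptions about the local executor, which this diff does not touch:

```rust
// Sketch only: the forward mapping is an assumed DDSketch-style bucketing;
// the reconstruction mirrors the "approx value = 2 * y^i / (y + 1)" comment
// in BucketTableCache::get_output above.

/// Hypothetical forward mapping: value -> (sign, bucket_id).
/// Zeros are tracked in a dedicated counter, hence the (0, 0) encoding.
fn to_bucket(base: f64, v: f64) -> (i16, i32) {
    if v == 0.0 {
        (0, 0)
    } else {
        let sign = if v > 0.0 { 1 } else { -1 };
        (sign, v.abs().log(base).ceil() as i32)
    }
}

/// Representative value of a bucket, matching the reconstruction in get_output.
fn from_bucket(base: f64, sign: i16, bucket_id: i32) -> f64 {
    f64::from(sign) * 2.0 * base.powi(bucket_id) / (base + 1.0)
}

fn main() {
    let base = 2.0;
    let (sign, id) = to_bucket(base, 10.0);
    assert_eq!((sign, id), (1, 4)); // 10 lies in the bucket (2^3, 2^4]
    let approx = from_bucket(base, sign, id);
    assert!((approx - 32.0 / 3.0).abs() < 1e-9); // ~10.67, close to 10
}
```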