diff --git a/src/expr/core/src/window_function/call.rs b/src/expr/core/src/window_function/call.rs
index 43545cc2a107a..5f0fc5e7328bd 100644
--- a/src/expr/core/src/window_function/call.rs
+++ b/src/expr/core/src/window_function/call.rs
@@ -17,7 +17,10 @@ use std::fmt::Display;
 use enum_as_inner::EnumAsInner;
 use parse_display::Display;
 use risingwave_common::bail;
-use risingwave_common::types::DataType;
+use risingwave_common::types::{
+    DataType, Datum, ScalarImpl, ScalarRefImpl, Sentinelled, ToDatumRef, ToOwnedDatum, ToText,
+};
+use risingwave_common::util::sort_util::{Direction, OrderType};
 use risingwave_pb::expr::window_frame::{PbBound, PbExclusion};
 use risingwave_pb::expr::{PbWindowFrame, PbWindowFunction};
 use FrameBound::{CurrentRow, Following, Preceding, UnboundedFollowing, UnboundedPreceding};
@@ -107,34 +110,40 @@ impl Frame {
                 end: Some(end.to_protobuf()),
                 exclusion,
             },
+            FrameBounds::Range(RangeFrameBounds { .. }) => {
+                todo!() // TODO()
+            }
         }
     }
 }
 
-#[derive(Display, Debug, Clone, Eq, PartialEq, Hash)]
+#[derive(Display, Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)]
 #[display("{0}")]
 pub enum FrameBounds {
     Rows(RowsFrameBounds),
     // Groups(GroupsFrameBounds),
-    // Range(RangeFrameBounds),
+    Range(RangeFrameBounds),
 }
 
 impl FrameBounds {
     pub fn validate(&self) -> Result<()> {
         match self {
             Self::Rows(bounds) => bounds.validate(),
+            Self::Range(bounds) => bounds.validate(),
         }
     }
 
     pub fn start_is_unbounded(&self) -> bool {
         match self {
             Self::Rows(RowsFrameBounds { start, .. }) => start.is_unbounded_preceding(),
+            Self::Range(RangeFrameBounds { start, .. }) => start.is_unbounded_preceding(),
         }
     }
 
     pub fn end_is_unbounded(&self) -> bool {
         match self {
             Self::Rows(RowsFrameBounds { end, .. }) => end.is_unbounded_following(),
+            Self::Range(RangeFrameBounds { end, .. }) => end.is_unbounded_following(),
         }
     }
 
@@ -152,11 +161,150 @@ pub struct RowsFrameBounds {
 
 impl RowsFrameBounds {
     fn validate(&self) -> Result<()> {
-        FrameBound::validate_bounds(&self.start, &self.end)
+        FrameBound::validate_bounds(&self.start, &self.end, |_| Ok(()))
     }
 }
 
-#[derive(Display, Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)]
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct RangeFrameBounds {
+    pub start: FrameBound<ScalarImpl>,
+    pub end: FrameBound<ScalarImpl>,
+}
+
+impl Display for RangeFrameBounds {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "RANGE BETWEEN {} AND {}",
+            self.start.for_display(),
+            self.end.for_display()
+        )?;
+        Ok(())
+    }
+}
+
+impl RangeFrameBounds {
+    fn validate(&self) -> Result<()> {
+        FrameBound::validate_bounds(&self.start, &self.end, |offset| {
+            match offset.as_scalar_ref_impl() {
+                // TODO(): use decl macro to merge with the following
+                ScalarRefImpl::Int16(val) if val < 0 => {
+                    bail!("frame bound offset should be non-negative, but {} is given", val);
+                }
+                ScalarRefImpl::Int32(val) if val < 0 => {
+                    bail!("frame bound offset should be non-negative, but {} is given", val);
+                }
+                ScalarRefImpl::Int64(val) if val < 0 => {
+                    bail!("frame bound offset should be non-negative, but {} is given", val);
+                }
+                // TODO(): datetime types
+                _ => unreachable!("other order column data types are not supported and should be banned in frontend"),
+            }
+        })
+    }
+
+    /// Get the frame start for a given order column value.
+    ///
+    /// ## Examples
+    ///
+    /// For the following frames:
+    ///
+    /// ```sql
+    /// ORDER BY x ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    /// ORDER BY x DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+    /// ```
+    ///
+    /// For any CURRENT ROW with any order value, the frame start is always the first-most row, which is
+    /// represented by [`Sentinelled::Smallest`].
+    ///
+    /// For the following frame:
+    ///
+    /// ```sql
+    /// ORDER BY x ASC RANGE BETWEEN 10 PRECEDING AND CURRENT ROW
+    /// ```
+    ///
+    /// For CURRENT ROW with order value `100`, the frame start is the **FIRST** row with order value `90`.
+    ///
+    /// For the following frame:
+    ///
+    /// ```sql
+    /// ORDER BY x DESC RANGE BETWEEN 10 PRECEDING AND CURRENT ROW
+    /// ```
+    ///
+    /// For CURRENT ROW with order value `100`, the frame start is the **FIRST** row with order value `110`.
+    pub fn frame_start_of(
+        &self,
+        order_value: impl ToDatumRef,
+        order_type: OrderType,
+    ) -> Sentinelled<Datum> {
+        self.start.as_ref().bound_of(order_value, order_type)
+    }
+
+    /// Get the frame end for a given order column value. It's very similar to `frame_start_of`,
+    /// just with everything in the opposite direction.
+    pub fn frame_end_of(
+        &self,
+        order_value: impl ToDatumRef,
+        order_type: OrderType,
+    ) -> Sentinelled<Datum> {
+        self.end.as_ref().bound_of(order_value, order_type)
+    }
+
+    /// Get the order value of the CURRENT ROW of the first-most frame that includes the given order value.
+    ///
+    /// ## Examples
+    ///
+    /// For the following frames:
+    ///
+    /// ```sql
+    /// ORDER BY x ASC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
+    /// ORDER BY x DESC RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
+    /// ```
+    ///
+    /// For any given order value, the first CURRENT ROW is always the first-most row, which is
+    /// represented by [`Sentinelled::Smallest`].
+    ///
+    /// For the following frame:
+    ///
+    /// ```sql
+    /// ORDER BY x ASC RANGE BETWEEN CURRENT ROW AND 10 FOLLOWING
+    /// ```
+    ///
+    /// For a given order value `100`, the first CURRENT ROW should have order value `90`.
+    ///
+    /// For the following frame:
+    ///
+    /// ```sql
+    /// ORDER BY x DESC RANGE BETWEEN CURRENT ROW AND 10 FOLLOWING
+    /// ```
+    ///
+    /// For a given order value `100`, the first CURRENT ROW should have order value `110`.
+    pub fn first_curr_of(
+        &self,
+        order_value: impl ToDatumRef,
+        order_type: OrderType,
+    ) -> Sentinelled<Datum> {
+        self.end
+            .as_ref()
+            .reverse()
+            .bound_of(order_value, order_type)
+    }
+
+    /// Get the order value of the CURRENT ROW of the last-most frame that includes the given order value.
+    /// It's very similar to `first_curr_of`, just with everything in the opposite direction.
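+    ///
+    /// For example, with `ORDER BY x ASC RANGE BETWEEN 10 PRECEDING AND CURRENT ROW`,
+    /// for a given order value `100`, the last CURRENT ROW should have order value `110`,
+    /// because reversing the start bound `10 PRECEDING` yields `10 FOLLOWING`.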
+    pub fn last_curr_of(
+        &self,
+        order_value: impl ToDatumRef,
+        order_type: OrderType,
+    ) -> Sentinelled<Datum> {
+        self.start
+            .as_ref()
+            .reverse()
+            .bound_of(order_value, order_type)
+    }
+}
+
+#[derive(Display, Debug, Clone, Copy, Eq, PartialEq, Hash, EnumAsInner)]
 #[display(style = "TITLE CASE")]
 pub enum FrameBound<T> {
     UnboundedPreceding,
@@ -169,10 +317,23 @@ pub enum FrameBound<T> {
 }
 
 impl<T> FrameBound<T> {
-    fn validate_bounds(start: &Self, end: &Self) -> Result<()> {
+    fn offset_value(&self) -> Option<&T> {
+        match self {
+            UnboundedPreceding | UnboundedFollowing | CurrentRow => None,
+            Preceding(offset) | Following(offset) => Some(offset),
+        }
+    }
+
+    fn validate_bounds(
+        start: &Self,
+        end: &Self,
+        offset_checker: impl Fn(&T) -> Result<()>,
+    ) -> Result<()> {
         match (start, end) {
             (_, UnboundedPreceding) => bail!("frame end cannot be UNBOUNDED PRECEDING"),
-            (UnboundedFollowing, _) => bail!("frame start cannot be UNBOUNDED FOLLOWING"),
+            (UnboundedFollowing, _) => {
+                bail!("frame start cannot be UNBOUNDED FOLLOWING")
+            }
             (Following(_), CurrentRow) | (Following(_), Preceding(_)) => {
                 bail!("frame starting from following row cannot have preceding rows")
             }
@@ -181,10 +342,32 @@ impl<T> FrameBound<T> {
             }
             _ => {}
         }
+
+        for bound in [start, end] {
+            if let Some(offset) = bound.offset_value() {
+                offset_checker(offset)?;
+            }
+        }
+
         Ok(())
     }
 }
 
+impl<T> FrameBound<T>
+where
+    FrameBound<T>: Copy,
+{
+    fn reverse(self) -> FrameBound<T> {
+        match self {
+            UnboundedPreceding => UnboundedFollowing,
+            Preceding(offset) => Following(offset),
+            CurrentRow => CurrentRow,
+            Following(offset) => Preceding(offset),
+            UnboundedFollowing => UnboundedPreceding,
+        }
+    }
+}
+
 impl FrameBound<usize> {
     pub fn from_protobuf(bound: &PbBound) -> Result<Self> {
         use risingwave_pb::expr::window_frame::bound::PbOffset;
@@ -245,6 +428,85 @@ impl FrameBound<usize> {
     }
 }
 
+impl FrameBound<ScalarImpl> {
+    fn as_ref(&self) -> FrameBound<ScalarRefImpl<'_>> {
+        match self {
+            UnboundedPreceding => UnboundedPreceding,
+            Preceding(offset) => Preceding(offset.as_scalar_ref_impl()),
+            CurrentRow => CurrentRow,
+            Following(offset) => Following(offset.as_scalar_ref_impl()),
+            UnboundedFollowing => UnboundedFollowing,
+        }
+    }
+
+    fn for_display(&self) -> FrameBound<String> {
+        match self {
+            UnboundedPreceding => UnboundedPreceding,
+            Preceding(offset) => Preceding(offset.as_scalar_ref_impl().to_text()),
+            CurrentRow => CurrentRow,
+            Following(offset) => Following(offset.as_scalar_ref_impl().to_text()),
+            UnboundedFollowing => UnboundedFollowing,
+        }
+    }
+}
+
+impl FrameBound<ScalarRefImpl<'_>> {
+    fn bound_of(self, order_value: impl ToDatumRef, order_type: OrderType) -> Sentinelled<Datum> {
+        let order_value = order_value.to_datum_ref();
+        match (self, order_type.direction()) {
+            (UnboundedPreceding, _) => Sentinelled::Smallest,
+            (UnboundedFollowing, _) => Sentinelled::Largest,
+            (CurrentRow, _) => Sentinelled::Normal(order_value.to_owned_datum()),
+            (Preceding(offset), Direction::Ascending)
+            | (Following(offset), Direction::Descending) => {
+                // should SUBTRACT the offset
+                if let Some(value) = order_value {
+                    let res = match (value, offset) {
+                        // TODO(): use decl macro to merge with the following
+                        (ScalarRefImpl::Int16(val), ScalarRefImpl::Int16(off)) => {
+                            ScalarImpl::Int16(val - off)
+                        }
+                        (ScalarRefImpl::Int32(val), ScalarRefImpl::Int32(off)) => {
+                            ScalarImpl::Int32(val - off)
+                        }
+                        (ScalarRefImpl::Int64(val), ScalarRefImpl::Int64(off)) => {
+                            ScalarImpl::Int64(val - off)
+                        }
+                        // TODO(): datetime types
+                        _ => unreachable!("other order column data types are not supported and should be banned in frontend"),
+                    };
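+                    // e.g. with `ORDER BY x ASC`, `10 PRECEDING` maps order value `100`
+                    // to `90`; with `ORDER BY x DESC`, `10 FOLLOWING` likewise maps `100`
+                    // to `90`, because "following" moves towards smaller values under
+                    // descending order.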
+                    Sentinelled::Normal(Some(res))
+                } else {
+                    Sentinelled::Normal(None)
+                }
+            }
+            (Following(offset), Direction::Ascending)
+            | (Preceding(offset), Direction::Descending) => {
+                // should ADD the offset
+                if let Some(value) = order_value {
+                    let res = match (value, offset) {
+                        // TODO(): use decl macro to merge with the following
+                        (ScalarRefImpl::Int16(val), ScalarRefImpl::Int16(off)) => {
+                            ScalarImpl::Int16(val + off)
+                        }
+                        (ScalarRefImpl::Int32(val), ScalarRefImpl::Int32(off)) => {
+                            ScalarImpl::Int32(val + off)
+                        }
+                        (ScalarRefImpl::Int64(val), ScalarRefImpl::Int64(off)) => {
+                            ScalarImpl::Int64(val + off)
+                        }
+                        // TODO(): datetime types
+                        _ => unreachable!("other order column data types are not supported and should be banned in frontend"),
+                    };
+                    Sentinelled::Normal(Some(res))
+                } else {
+                    Sentinelled::Normal(None)
+                }
+            }
+        }
+    }
+}
+
 #[derive(Display, Debug, Copy, Clone, Eq, PartialEq, Hash, Default, EnumAsInner)]
 #[display("EXCLUDE {}", style = "TITLE CASE")]
 pub enum FrameExclusion {
diff --git a/src/expr/core/src/window_function/state/aggregate_range.rs b/src/expr/core/src/window_function/state/aggregate_range.rs
new file mode 100644
index 0000000000000..6f1ed262b09ee
--- /dev/null
+++ b/src/expr/core/src/window_function/state/aggregate_range.rs
@@ -0,0 +1,228 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
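+
+//! `WindowState` implementation for aggregate window functions over `RANGE` frames,
+//! backed by `RangeWindowBuffer`. This largely mirrors `aggregate.rs` (the `ROWS`
+//! counterpart); see the `TODO()` notes below on unifying the two.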
+
+use std::collections::BTreeSet;
+
+use futures_util::FutureExt;
+use risingwave_common::array::{DataChunk, Op, StreamChunk};
+use risingwave_common::estimate_size::{EstimateSize, KvSize};
+use risingwave_common::types::{DataType, Datum};
+use risingwave_common::util::iter_util::ZipEqFast;
+use risingwave_common::{bail, must_match};
+use smallvec::SmallVec;
+
+use super::buffer_range::RangeWindowBuffer;
+use super::{StateEvictHint, StateKey, StatePos, WindowState};
+use crate::aggregate::{
+    AggArgs, AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
+};
+use crate::sig::FUNCTION_REGISTRY;
+use crate::window_function::{WindowFuncCall, WindowFuncKind};
+use crate::Result;
+
+pub struct RangeAggregateState {
+    agg_func: BoxedAggregateFunction,
+    agg_impl: AggImpl,
+    arg_data_types: Vec<DataType>,
+    buffer: RangeWindowBuffer<SmallVec<[Datum; 2]>>,
+    buffer_heap_size: KvSize,
+}
+
+impl RangeAggregateState {
+    pub fn new(call: &WindowFuncCall) -> Result<Self> {
+        if call.frame.bounds.validate().is_err() {
+            bail!("the window frame must be valid");
+        }
+        let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
+        let arg_data_types = call.args.arg_types().to_vec();
+        let agg_call = AggCall {
+            kind: agg_kind,
+            args: match &call.args {
+                // convert args to [0] or [0, 1]
+                AggArgs::None => AggArgs::None,
+                AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
+                AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
+            },
+            return_type: call.return_type.clone(),
+            column_orders: Vec::new(), // the input is already sorted
+            // TODO(rc): support filter on window function call
+            filter: None,
+            // TODO(rc): support distinct on window function call? PG doesn't support it either.
+            distinct: false,
+            direct_args: vec![],
+        };
+        let agg_func_sig = FUNCTION_REGISTRY
+            .get(agg_kind, &arg_data_types, &call.return_type)
+            .expect("the agg func must exist");
+        let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
+        let (agg_impl, enable_delta) =
+            if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
+                let init_state = agg_func.create_state();
+                (AggImpl::Incremental(init_state), true)
+            } else {
+                (AggImpl::Full, false)
+            };
+        Ok(Self {
+            agg_func,
+            agg_impl,
+            arg_data_types,
+            buffer: RangeWindowBuffer::new(call.frame.clone(), enable_delta),
+            buffer_heap_size: KvSize::new(),
+        })
+    }
+
+    fn slide_inner(&mut self) -> StateEvictHint {
+        let removed_keys: BTreeSet<_> = self
+            .buffer
+            .slide()
+            .map(|(k, v)| {
+                v.iter().for_each(|arg| {
+                    self.buffer_heap_size.sub_val(arg);
+                });
+                self.buffer_heap_size.sub_val(&k);
+                k
+            })
+            .collect();
+        if removed_keys.is_empty() {
+            StateEvictHint::CannotEvict(
+                self.buffer
+                    .smallest_key()
+                    .expect("sliding without removing, must have some entry in the buffer")
+                    .clone(),
+            )
+        } else {
+            StateEvictHint::CanEvict(removed_keys)
+        }
+    }
+}
+
+impl WindowState for RangeAggregateState {
+    fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
+        args.iter().for_each(|arg| {
+            self.buffer_heap_size.add_val(arg);
+        });
+        self.buffer_heap_size.add_val(&key);
+        self.buffer.append(key, args);
+    }
+
+    fn curr_window(&self) -> StatePos<'_> {
+        let window = self.buffer.curr_window();
+        StatePos {
+            key: window.key,
+            is_ready: window.following_saturated,
+        }
+    }
+
+    fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
+        let wrapper = AggregatorWrapper {
+            agg_func: self.agg_func.as_ref(),
+            arg_data_types: &self.arg_data_types,
+        };
+        let output = match self.agg_impl {
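+            // `Full` re-aggregates all values in the current window from scratch, while
+            // `Incremental` feeds only the insert/delete delta into a retractable
+            // aggregate state (chosen in `new()` when the function is retractable and
+            // frame exclusion is NO OTHERS).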
+            AggImpl::Full => wrapper.aggregate(self.buffer.curr_window_values()),
+            AggImpl::Incremental(ref mut state) => {
+                wrapper.update(state, self.buffer.consume_curr_window_values_delta())
+            }
+        }?;
+        let evict_hint = self.slide_inner();
+        Ok((output, evict_hint))
+    }
+
+    fn slide_no_output(&mut self) -> Result<StateEvictHint> {
+        match self.agg_impl {
+            AggImpl::Full => {}
+            AggImpl::Incremental(ref mut state) => {
+                // for incremental agg, we need to update the state even if the caller doesn't need
+                // the output
+                let wrapper = AggregatorWrapper {
+                    agg_func: self.agg_func.as_ref(),
+                    arg_data_types: &self.arg_data_types,
+                };
+                wrapper.update(state, self.buffer.consume_curr_window_values_delta())?;
+            }
+        };
+        Ok(self.slide_inner())
+    }
+}
+
+impl EstimateSize for RangeAggregateState {
+    fn estimated_heap_size(&self) -> usize {
+        // estimate `VecDeque` of `StreamWindowBuffer` internal size
+        // https://github.com/risingwavelabs/risingwave/issues/9713
+        self.arg_data_types.estimated_heap_size() + self.buffer_heap_size.size()
+    }
+}
+
+// TODO(): the following is reusable
+
+enum AggImpl {
+    Incremental(AggImplState),
+    Full,
+}
+
+struct AggregatorWrapper<'a> {
+    agg_func: &'a dyn AggregateFunction,
+    arg_data_types: &'a [DataType],
+}
+
+impl AggregatorWrapper<'_> {
+    fn aggregate<V>(&self, values: impl IntoIterator<Item = V>) -> Result<Datum>
+    where
+        V: AsRef<[Datum]>,
+    {
+        let mut state = self.agg_func.create_state();
+        self.update(
+            &mut state,
+            values.into_iter().map(|args| (Op::Insert, args)),
+        )
+    }
+
+    fn update<V>(
+        &self,
+        state: &mut AggImplState,
+        delta: impl IntoIterator<Item = (Op, V)>,
+    ) -> Result<Datum>
+    where
+        V: AsRef<[Datum]>,
+    {
+        let mut args_builders = self
+            .arg_data_types
+            .iter()
+            .map(|data_type| data_type.create_array_builder(0 /* bad! */))
+            .collect::<Vec<_>>();
+        let mut ops = Vec::new();
+        let mut n_rows = 0;
+        for (op, value) in delta {
+            n_rows += 1;
+            ops.push(op);
+            for (builder, datum) in args_builders.iter_mut().zip_eq_fast(value.as_ref()) {
+                builder.append(datum);
+            }
+        }
+        let columns = args_builders
+            .into_iter()
+            .map(|builder| builder.finish().into())
+            .collect::<Vec<_>>();
+        let chunk = StreamChunk::from_parts(ops, DataChunk::new(columns, n_rows));
+
+        self.agg_func
+            .update(state, &chunk)
+            .now_or_never()
+            .expect("we don't support UDAF currently, so the function should return immediately")?;
+        self.agg_func
+            .get_result(state)
+            .now_or_never()
+            .expect("we don't support UDAF currently, so the function should return immediately")
+    }
+}
diff --git a/src/expr/core/src/window_function/state/buffer.rs b/src/expr/core/src/window_function/state/buffer.rs
index 3edb6d7adc164..30dfd161d4579 100644
--- a/src/expr/core/src/window_function/state/buffer.rs
+++ b/src/expr/core/src/window_function/state/buffer.rs
@@ -85,6 +85,9 @@ impl<V: Clone> WindowBuffer<V> {
                     false // unbounded frame start, never preceding-saturated
                 }
             }
+            FrameBounds::Range(..) => {
+                todo!() // TODO(): make this only handle ROWS
+            }
         }
     }
 
@@ -110,6 +113,9 @@ impl<V: Clone> WindowBuffer<V> {
                     false // unbounded frame end, never following-saturated
                 }
            }
+            FrameBounds::Range(..) => {
+                todo!() // TODO(): make this only handle ROWS
+            }
         }
     }
 
@@ -204,6 +210,9 @@ impl<V: Clone> WindowBuffer<V> {
                     self.right_excl_idx = self.buffer.len();
                 }
            }
+            FrameBounds::Range(..) => {
+                todo!() // TODO(): make this only handle ROWS
+            }
         }
     }
 
diff --git a/src/expr/core/src/window_function/state/buffer_range.rs b/src/expr/core/src/window_function/state/buffer_range.rs
new file mode 100644
index 0000000000000..1c56340b82de5
--- /dev/null
+++ b/src/expr/core/src/window_function/state/buffer_range.rs
@@ -0,0 +1,296 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::VecDeque;
+use std::ops::Range;
+
+use risingwave_common::array::Op;
+use risingwave_common::must_match;
+use risingwave_common::types::{DataType, Sentinelled};
+use risingwave_common::util::memcmp_encoding;
+use risingwave_common::util::sort_util::OrderType;
+
+use super::range_utils::range_except;
+use super::StateKey;
+use crate::window_function::state::range_utils::range_diff;
+use crate::window_function::{Frame, FrameBounds, FrameExclusion, RangeFrameBounds};
+
+// TODO(): seems reusable
+struct Entry<K: Ord, V> {
+    key: K,
+    value: V,
+}
+
+// TODO()
+fn test_order_col() -> (DataType, OrderType) {
+    (DataType::Int32, OrderType::ascending())
+}
+
+/// A sliding window buffer implementation for `RANGE` frames.
+pub struct RangeWindowBuffer<V: Clone> {
+    frame_bounds: RangeFrameBounds,
+    frame_exclusion: FrameExclusion,
+    buffer: VecDeque<Entry<StateKey, V>>, // TODO(): may store keys other than StateKey
+    curr_idx: usize,
+    left_idx: usize,       // inclusive, note this can be > `curr_idx`
+    right_excl_idx: usize, // exclusive, note this can be <= `curr_idx`
+    curr_delta: Option<Vec<(Op, V)>>,
+}
+
+// TODO(): seems reusable
+/// Note: A window frame can be pure preceding, pure following, or crossing the _current row_.
+pub struct CurrWindow<'a, K> {
+    pub key: Option<&'a K>,
+    pub preceding_saturated: bool,
+    pub following_saturated: bool,
+}
+
+impl<V: Clone> RangeWindowBuffer<V> {
+    pub fn new(frame: Frame, enable_delta: bool) -> Self {
+        assert!(frame.bounds.validate().is_ok());
+
+        let frame_bounds = must_match!(frame.bounds, FrameBounds::Range(bounds) => bounds);
+        let frame_exclusion = frame.exclusion;
+
+        if enable_delta {
+            // TODO(rc): currently only support `FrameExclusion::NoOthers` for delta
+            assert!(frame_exclusion.is_no_others());
+        }
+
+        Self {
+            frame_bounds,
+            frame_exclusion,
+            buffer: Default::default(),
+            curr_idx: 0,
+            left_idx: 0,
+            right_excl_idx: 0,
+            curr_delta: if enable_delta {
+                Some(Default::default())
+            } else {
+                None
+            },
+        }
+    }
+
+    // TODO(): seems reusable
+    /// Get the key part of the current row.
+    pub fn curr_key(&self) -> Option<&StateKey> {
+        self.buffer.get(self.curr_idx).map(|Entry { key, .. }| key)
+    }
+
+    // TODO(): seems reusable
+    /// Get the current window info.
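+    /// `following_saturated` implies that no future input row can change the current
+    /// window frame anymore, so the current row's output is ready to be emitted.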
+    pub fn curr_window(&self) -> CurrWindow<'_, StateKey> {
+        CurrWindow {
+            key: self.curr_key(),
+            preceding_saturated: self.preceding_saturated(),
+            following_saturated: self.following_saturated(),
+        }
+    }
+
+    // TODO(): seems reusable
+    fn curr_window_outer(&self) -> Range<usize> {
+        self.left_idx..self.right_excl_idx
+    }
+
+    // TODO(): seems reusable
+    fn curr_window_exclusion(&self) -> Range<usize> {
+        // TODO(rc): should intersect with `curr_window_outer` to be more accurate
+        match self.frame_exclusion {
+            FrameExclusion::CurrentRow => self.curr_idx..self.curr_idx + 1,
+            FrameExclusion::NoOthers => self.curr_idx..self.curr_idx,
+        }
+    }
+
+    // TODO(): seems reusable
+    fn curr_window_ranges(&self) -> (Range<usize>, Range<usize>) {
+        let selection = self.curr_window_outer();
+        let exclusion = self.curr_window_exclusion();
+        range_except(selection, exclusion)
+    }
+
+    // TODO(): seems reusable
+    /// Iterate over values in the current window.
+    pub fn curr_window_values(&self) -> impl Iterator<Item = &V> {
+        assert!(self.left_idx <= self.right_excl_idx);
+        assert!(self.right_excl_idx <= self.buffer.len());
+
+        let (left, right) = self.curr_window_ranges();
+        self.buffer
+            .range(left)
+            .chain(self.buffer.range(right))
+            .map(|Entry { value, .. }| value)
+    }
+
+    // TODO(): seems reusable
+    /// Consume the delta of values comparing the current window to the previous window.
+    /// The delta is not guaranteed to be sorted, especially when frame exclusion is not `NoOthers`.
+    pub fn consume_curr_window_values_delta(&mut self) -> impl Iterator<Item = (Op, V)> + '_ {
+        self.curr_delta
+            .as_mut()
+            .expect("delta mode should be enabled")
+            .drain(..)
+    }
+
+    // TODO(): seems reusable
+    fn maintain_delta(&mut self, old_outer: Range<usize>, new_outer: Range<usize>) {
+        debug_assert!(self.frame_exclusion.is_no_others());
+
+        let (outer_removed, outer_added) = range_diff(old_outer.clone(), new_outer.clone());
+        let delta = self.curr_delta.as_mut().unwrap();
+        for idx in outer_removed.iter().cloned().flatten() {
+            delta.push((Op::Delete, self.buffer[idx].value.clone()));
+        }
+        for idx in outer_added.iter().cloned().flatten() {
+            delta.push((Op::Insert, self.buffer[idx].value.clone()));
+        }
+    }
+
+    // TODO(): seems reusable
+    /// Append a key-value pair to the buffer.
+    pub fn append(&mut self, key: StateKey, value: V) {
+        let old_outer = self.curr_window_outer();
+
+        self.buffer.push_back(Entry { key, value });
+        self.recalculate_left_right();
+
+        if self.curr_delta.is_some() {
+            self.maintain_delta(old_outer, self.curr_window_outer());
+        }
+    }
+
+    // TODO(): seems reusable
+    /// Get the smallest key that is still kept in the buffer.
+    /// Returns `None` if there's nothing yet.
+    pub fn smallest_key(&self) -> Option<&StateKey> {
+        self.buffer.front().map(|Entry { key, .. }| key)
+    }
+
+    // TODO(): seems reusable
+    /// Slide the current window forward.
+    /// Returns the keys that are removed from the buffer.
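+    /// (Entries before `min(left_idx, curr_idx)` after recalculation can never be part
+    /// of any future window, so they are evicted and returned here.)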
+    pub fn slide(&mut self) -> impl Iterator<Item = (StateKey, V)> + '_ {
+        let old_outer = self.curr_window_outer();
+
+        self.curr_idx += 1;
+        self.recalculate_left_right();
+
+        if self.curr_delta.is_some() {
+            self.maintain_delta(old_outer, self.curr_window_outer());
+        }
+
+        let min_needed_idx = std::cmp::min(self.left_idx, self.curr_idx);
+        self.curr_idx -= min_needed_idx;
+        self.left_idx -= min_needed_idx;
+        self.right_excl_idx -= min_needed_idx;
+        self.buffer
+            .drain(0..min_needed_idx)
+            .map(|Entry { key, value }| (key, value))
+    }
+
+    fn preceding_saturated(&self) -> bool {
+        self.curr_key().is_some() && {
+            // TODO(rc): It seems that preceding saturation is not important, may remove later.
+            true
+        }
+    }
+
+    fn following_saturated(&self) -> bool {
+        self.curr_key().is_some()
+            && {
+                // Left OK? (note that `left_idx` can be greater than `right_idx`)
+                // The following line checks whether the left value is the last one in the buffer.
+                // Here we adopt a conservative approach, which means we assume the next future value
+                // is likely to be the same as the last value in the current window, in which case
+                // we can't say the current window is saturated.
+                self.left_idx < self.buffer.len() /* non-zero */ - 1
+            }
+            && {
+                // Right OK? Ditto.
+                self.right_excl_idx < self.buffer.len()
+            }
+    }
+
+    fn recalculate_left_right(&mut self) {
+        if self.buffer.is_empty() {
+            self.left_idx = 0;
+            self.right_excl_idx = 0;
+        }
+
+        let Some(curr_key) = self.curr_key() else {
+            // If the current index has been moved to a future position, we can't touch anything
+            // because the next coming key may be equal to the previous one, which means the left
+            // and right indices will be the same.
+            return;
+        };
+
+        let (data_type, order_type) = test_order_col(); // TODO()
+
+        let curr_order_value =
+            memcmp_encoding::decode_value(&data_type, &curr_key.order_key, order_type)
+                .expect("no reason to fail here because we just encoded it in memory");
+        println!("[rc] curr_order_value = {:?}", curr_order_value);
+
+        match self
+            .frame_bounds
+            .frame_start_of(&curr_order_value, order_type)
+        {
+            Sentinelled::Smallest => {
+                // unbounded frame start
+                assert_eq!(
+                    self.left_idx, 0,
+                    "for unbounded start, left index should always be 0"
+                );
+            }
+            Sentinelled::Normal(value) => {
+                // bounded, find the start position
+                // TODO(): move memcmp encoding to `frame_start_of`
+                let value_enc = memcmp_encoding::encode_value(value, order_type)
+                    .expect("no reason to fail here");
+                self.left_idx = self
+                    .buffer
+                    .partition_point(|elem| elem.key.order_key < value_enc);
+            }
+            Sentinelled::Largest => unreachable!("frame start can never be UNBOUNDED FOLLOWING"),
+        }
+
+        match self
+            .frame_bounds
+            .frame_end_of(&curr_order_value, order_type)
+        {
+            Sentinelled::Largest => {
+                // unbounded frame end
+                self.right_excl_idx = self.buffer.len();
+            }
+            Sentinelled::Normal(value) => {
+                // bounded, find the end position
+                let value_enc = memcmp_encoding::encode_value(value, order_type)
+                    .expect("no reason to fail here");
+                self.right_excl_idx = self
+                    .buffer
+                    .partition_point(|elem| elem.key.order_key <= value_enc);
+            }
+            Sentinelled::Smallest => unreachable!("frame end can never be UNBOUNDED PRECEDING"),
+        }
+
+        println!(
+            "[rc] buffer: {:?}",
+            self.buffer.iter().map(|elem| &elem.key).collect::<Vec<_>>()
+        );
+        println!(
+            "[rc] left = {}, right excl = {}, curr = {}",
+            self.left_idx, self.right_excl_idx, self.curr_idx
+        );
+    }
+}
diff --git a/src/expr/core/src/window_function/state/mod.rs b/src/expr/core/src/window_function/state/mod.rs
index 37ee086ca7ba4..e17aa97ffe30d 100644
--- a/src/expr/core/src/window_function/state/mod.rs
+++ b/src/expr/core/src/window_function/state/mod.rs
@@ -25,7 +25,9 @@ use super::{WindowFuncCall, WindowFuncKind};
 use crate::{ExprError, Result};
 
 mod aggregate;
+mod aggregate_range;
 mod buffer;
+mod buffer_range;
 mod range_utils;
 mod rank;
 
@@ -122,7 +124,16 @@ pub fn create_window_state(call: &WindowFuncCall) -> Result<BoxedWindowState> {
         RowNumber => Box::new(rank::RankState::<RowNumber>::new(call)),
         Rank => Box::new(rank::RankState::<Rank>::new(call)),
         DenseRank => Box::new(rank::RankState::<DenseRank>::new(call)),
-        Aggregate(_) => Box::new(aggregate::AggregateState::new(call)?),
+        Aggregate(_) => {
+            if call.frame.bounds.is_rows() {
+                Box::new(aggregate::AggregateState::new(call)?)
+            } else if call.frame.bounds.is_range() {
+                // TODO(): unify RangeAggregateState and AggregateState
+                Box::new(aggregate_range::RangeAggregateState::new(call)?)
+            } else {
+                unreachable!()
+            }
+        }
         kind => {
             return Err(ExprError::UnsupportedFunction(format!(
                 "{}({}) -> {}",
diff --git a/src/stream/src/executor/over_window/general.rs b/src/stream/src/executor/over_window/general.rs
index 0ba4808b93624..7025cef322679 100644
--- a/src/stream/src/executor/over_window/general.rs
+++ b/src/stream/src/executor/over_window/general.rs
@@ -25,12 +25,12 @@ use risingwave_common::array::stream_record::Record;
 use risingwave_common::array::{Op, RowRef, StreamChunk};
 use risingwave_common::row::{OwnedRow, Row, RowExt};
 use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
-use risingwave_common::types::{DataType, DefaultOrdered};
+use risingwave_common::types::{DataType, DefaultOrdered, ScalarImpl};
 use risingwave_common::util::iter_util::ZipEqFast;
 use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded};
 use risingwave_common::util::sort_util::OrderType;
 use risingwave_expr::window_function::{
-    create_window_state, StateKey, WindowFuncCall, WindowStates,
+    create_window_state, RangeFrameBounds, RowsFrameBounds, StateKey, WindowFuncCall, WindowStates,
 };
 use risingwave_storage::StateStore;
 
@@ -189,12 +189,73 @@ impl<S: StateStore> OverWindowExecutor<S> {
             &input_info.pk_indices,
         );
 
+        // TODO(): just for test
+        let calls = args.calls.into_iter().map(|mut call| {
+            let test_frame_bounds = match call.frame.bounds.clone() {
+                risingwave_expr::window_function::FrameBounds::Rows(RowsFrameBounds {
+                    start,
+                    end,
+                }) => {
+                    let start = match start {
+                        risingwave_expr::window_function::FrameBound::UnboundedPreceding => {
+                            risingwave_expr::window_function::FrameBound::UnboundedPreceding
+                        }
+                        risingwave_expr::window_function::FrameBound::Preceding(offset) => {
+                            risingwave_expr::window_function::FrameBound::Preceding(
+                                ScalarImpl::from(offset as i32),
+                            )
+                        }
+                        risingwave_expr::window_function::FrameBound::CurrentRow => {
+                            risingwave_expr::window_function::FrameBound::CurrentRow
+                        }
+                        risingwave_expr::window_function::FrameBound::Following(offset) => {
+                            risingwave_expr::window_function::FrameBound::Following(
+                                ScalarImpl::from(offset as i32),
+                            )
+                        }
+                        risingwave_expr::window_function::FrameBound::UnboundedFollowing => {
+                            risingwave_expr::window_function::FrameBound::UnboundedFollowing
+                        }
+                    };
+                    let end = match end {
+                        risingwave_expr::window_function::FrameBound::UnboundedPreceding => {
+                            risingwave_expr::window_function::FrameBound::UnboundedPreceding
+                        }
+                        risingwave_expr::window_function::FrameBound::Preceding(offset) => {
+                            risingwave_expr::window_function::FrameBound::Preceding(
+                                ScalarImpl::from(offset as i32),
+                            )
+                        }
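+                        // NB: as with `start` above, `ROWS` offsets (`usize`) are
+                        // re-encoded as `Int32` scalars here purely so the new `RANGE`
+                        // code path gets exercised during testing (see TODO above).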
+                        risingwave_expr::window_function::FrameBound::CurrentRow => {
+                            risingwave_expr::window_function::FrameBound::CurrentRow
+                        }
+                        risingwave_expr::window_function::FrameBound::Following(offset) => {
+                            risingwave_expr::window_function::FrameBound::Following(
+                                ScalarImpl::from(offset as i32),
+                            )
+                        }
+                        risingwave_expr::window_function::FrameBound::UnboundedFollowing => {
+                            risingwave_expr::window_function::FrameBound::UnboundedFollowing
+                        }
+                    };
+                    risingwave_expr::window_function::FrameBounds::Range(RangeFrameBounds {
+                        start,
+                        end,
+                    })
+                }
+                bounds @ risingwave_expr::window_function::FrameBounds::Range(..) => bounds,
+            };
+            call.frame.bounds = test_frame_bounds;
+            println!("[rc] new frame: {}", call.frame);
+            call
+        });
+
         Self {
             input: args.input,
             inner: ExecutorInner {
                 actor_ctx: args.actor_ctx,
                 info: args.info,
-                calls: args.calls,
+                calls: calls.collect(),
                 partition_key_indices: args.partition_key_indices,
                 order_key_indices: args.order_key_indices,
                 order_key_data_types,
diff --git a/src/stream/src/executor/over_window/over_partition.rs b/src/stream/src/executor/over_window/over_partition.rs
index 7a395821f6030..e86c090ffacf2 100644
--- a/src/stream/src/executor/over_window/over_partition.rs
+++ b/src/stream/src/executor/over_window/over_partition.rs
@@ -25,7 +25,9 @@ use risingwave_common::array::stream_record::Record;
 use risingwave_common::estimate_size::collections::EstimatedBTreeMap;
 use risingwave_common::row::{OwnedRow, Row};
 use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
-use risingwave_common::types::Sentinelled;
+use risingwave_common::types::{Datum, Sentinelled};
+use risingwave_common::util::memcmp_encoding;
+use risingwave_common::util::sort_util::cmp_datum;
 use risingwave_expr::window_function::{FrameBounds, StateKey, WindowFuncCall};
 use risingwave_storage::store::PrefetchOptions;
 use risingwave_storage::StateStore;
@@ -36,6 +38,17 @@ use crate::executor::StreamExecutorResult;
 
 pub(super) type CacheKey = Sentinelled<StateKey>;
 
+// TODO()
+fn test_order_col() -> (
+    risingwave_common::types::DataType,
+    risingwave_common::util::sort_util::OrderType,
+) {
+    (
+        risingwave_common::types::DataType::Int32,
+        risingwave_common::util::sort_util::OrderType::ascending(),
+    )
+}
+
 /// Range cache for one over window partition.
 /// The cache entries can be:
 ///
@@ -396,6 +409,8 @@ impl<'a, S: StateStore> OverPartition<'a, S> {
             self::find_affected_ranges(self.calls, DeltaBTreeMap::new(cache_inner, delta));
 
         self.stats.lookup_count += 1;
 
+        println!("[rc] ranges: {:?}", ranges);
+
         if ranges.is_empty() {
             // no ranges affected, we're done
             return Ok((DeltaBTreeMap::new(cache_inner, delta), ranges));
@@ -790,11 +805,6 @@ fn find_affected_ranges<'cache>(
     &'cache CacheKey,
     &'cache CacheKey,
 )> {
-    // XXX(rc): NOTE FOR DEVS
-    // Must carefully consider the sentinel keys in the cache when extending this function to
-    // support `RANGE` and `GROUPS` frames later. May introduce a return value variant to clearly
-    // tell the caller that there exists at least one affected range that touches the sentinel.
-
     if part_with_delta.first_key().is_none() {
         // all keys are deleted in the delta
         return vec![];
@@ -828,15 +838,21 @@ fn find_affected_ranges<'cache>(
     // `first_curr_key` which is the MINIMUM of all `first_curr_key`s of all frames of all window
     // function calls.
-    let first_curr_key = if end_is_unbounded || delta_first_key == first_key {
+    let (
+        first_curr_key,
+        // By *logical*, we mean the order value doesn't necessarily exist in `part_with_delta`;
+        // instead, it's just used as the base when calculating the *first frame start* of
+        // `RANGE` frames later.
+        logical_first_curr_order,
+    ) = if end_is_unbounded || delta_first_key == first_key {
         // If the frame end is unbounded, or, the first key is in delta, then the frame corresponding
         // to the first key is always affected.
-        first_key
+        (first_key, None)
     } else {
         let mut min_first_curr_key = &Sentinelled::Largest;
+        let mut min_logical_first_curr_order: Option<Datum> = None;
 
         for call in calls {
-            let key = match &call.frame.bounds {
+            let (key, logical_order_value) = match &call.frame.bounds {
                 FrameBounds::Rows(bounds) => {
                     let mut cursor = part_with_delta.lower_bound(Bound::Included(delta_first_key));
                     for _ in 0..bounds.end.n_following_rows().unwrap() {
                         cursor.move_next();
                         if cursor.key().is_none() {
                            break;
                        }
                    }
-                    cursor.key().unwrap_or(first_key)
+                    (
+                        cursor.key().unwrap_or(first_key),
+                        None, // `ROWS` frames don't have *logical* first curr order
+                    )
+                }
+                FrameBounds::Range(bounds) => {
+                    let (data_type, order_type) = test_order_col(); // TODO()
+
+                    let delta_first_order_value = memcmp_encoding::decode_value(
+                        &data_type,
+                        &delta_first_key.as_normal_expect().order_key,
+                        order_type,
+                    )
+                    .expect("no reason to fail here because we just encoded it in memory");
+
+                    let logical_curr_order_value =
+                        bounds.first_curr_of(&delta_first_order_value, order_type);
+
+                    match logical_curr_order_value {
+                        Sentinelled::Smallest => (first_key, None),
+                        Sentinelled::Normal(mut logical_curr_order_value) => {
+                            // In this case we can derive the logical order value of the first
+                            // CURRENT ROW corresponding to the given `delta_first_order_value`.
+
+                            logical_curr_order_value = std::cmp::min_by(
+                                logical_curr_order_value,
+                                delta_first_order_value,
+                                |x, y| cmp_datum(x, y, order_type),
+                            );
+
+                            // Now we try to search for the logical order value in `part_with_delta`.
+                            let order_value_enc = memcmp_encoding::encode_value(
+                                &logical_curr_order_value,
+                                order_type,
+                            )
+                            .expect("the data type is simple, should succeed");
+                            let search_key = Sentinelled::Normal(StateKey {
+                                order_key: order_value_enc,
+                                pk: OwnedRow::empty().into(), // empty row is minimal
+                            });
+
+                            let cursor = part_with_delta.lower_bound(Bound::Included(&search_key));
+                            let curr_key = if let Some((prev_key, _)) = cursor.peek_prev()
+                                && prev_key.is_smallest()
+                            {
+                                // If the found lower bound of search key is right behind a smallest sentinel,
+                                // we don't know if there's any other rows with the same order key in the state
+                                // table but not in cache. We should conservatively return the sentinel key as
+                                // the first curr key.
+                                prev_key
+                            } else {
+                                // If cursor is in ghost position, it simply means that the search key is
+                                // larger than any existing keys.
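+                                // (i.e. every cached key is smaller than the logical first
+                                // CURRENT ROW, so this call imposes no bound earlier than
+                                // the last key)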
+                                cursor.key().unwrap_or(last_key)
+                            };
+
+                            (curr_key, Some(logical_curr_order_value))
+                        }
+                        Sentinelled::Largest => {
+                            unreachable!("first curr key can never be the largest, which means UNBOUNDED FOLLOWING")
+                        }
+                    }
+                }
             };
+
+            if let Some(logical_order_value) = logical_order_value {
+                let (_data_type, order_type) = test_order_col(); // TODO()
+                if let Some(min_logical_first_curr_order) = min_logical_first_curr_order.as_mut() {
+                    *min_logical_first_curr_order = std::cmp::min_by(
+                        min_logical_first_curr_order.take(),
+                        logical_order_value,
+                        |x, y| cmp_datum(x, y, order_type),
+                    );
+                } else {
+                    min_logical_first_curr_order = Some(logical_order_value);
+                }
+            }
+
             min_first_curr_key = min_first_curr_key.min(key);
             if min_first_curr_key == first_key {
                 // if we already pushed the affected curr key to the first key, no more pushing is needed
                 break;
             }
         }
 
-        min_first_curr_key
+        (min_first_curr_key, min_logical_first_curr_order)
     };
 
     let first_frame_start = if start_is_unbounded || first_curr_key == first_key {
@@ -879,6 +969,59 @@ fn find_affected_ranges<'cache>(
                     }
                     cursor.key().unwrap_or(first_key)
                 }
+                FrameBounds::Range(bounds) => {
+                    let (_data_type, order_type) = test_order_col(); // TODO()
+
+                    let logical_curr_order_value = logical_first_curr_order
+                        .as_ref()
+                        .expect("otherwise should've gone `first_curr_key == first_key` branch")
+                        .clone();
+
+                    let logical_frame_start =
+                        bounds.frame_start_of(&logical_curr_order_value, order_type);
+
+                    match logical_frame_start {
+                        Sentinelled::Smallest => first_key,
+                        Sentinelled::Normal(mut logical_frame_start) => {
+                            // In this case we can derive the logical order value of the frame start
+                            // corresponding to the order value of the *logical* CURRENT ROW.
+
+                            logical_frame_start = std::cmp::min_by(
+                                logical_frame_start,
+                                logical_curr_order_value,
+                                |x, y| cmp_datum(x, y, order_type),
+                            );
+
+                            // Now we try to search for the logical order value in `part_with_delta`.
+                            let order_value_enc =
+                                memcmp_encoding::encode_value(&logical_frame_start, order_type)
+                                    .expect("the data type is simple, should succeed");
+                            let search_key = Sentinelled::Normal(StateKey {
+                                order_key: order_value_enc,
+                                pk: OwnedRow::empty().into(), // empty row is minimal
+                            });
+
+                            let cursor = part_with_delta.lower_bound(Bound::Included(&search_key));
+                            let frame_start = if let Some((prev_key, _)) = cursor.peek_prev()
+                                && prev_key.is_smallest()
+                            {
+                                // If the found lower bound of search key is right behind a smallest sentinel,
+                                // we don't know if there's any other rows with the same order key in the state
+                                // table but not in cache. We should conservatively return the sentinel key as
+                                // the frame start.
+                                prev_key
+                            } else {
+                                // If cursor is in ghost position, it simply means that the search key is
+                                // larger than any existing keys.
+                                cursor.key().unwrap_or(last_key)
+                            };
+
+                            frame_start
+                        }
+                        Sentinelled::Largest => {
+                            unreachable!("first frame start can never be the largest, which means UNBOUNDED FOLLOWING")
+                        }
+                    }
+                }
             };
             min_frame_start = min_frame_start.min(key);
             if min_frame_start == first_key {
@@ -890,13 +1033,18 @@ fn find_affected_ranges<'cache>(
         min_frame_start
     };
 
-    let last_curr_key = if start_is_unbounded || delta_last_key == last_key {
-        last_key
+    let (
+        last_curr_key,
+        // similar to `logical_first_curr_order`
+        logical_last_curr_order,
+    ) = if start_is_unbounded || delta_last_key == last_key {
+        (last_key, None)
     } else {
         let mut max_last_curr_key = &Sentinelled::Smallest;
+        let mut max_logical_last_curr_order: Option<Datum> = None;
 
         for call in calls {
-            let key = match &call.frame.bounds {
+            let (key, logical_order_value) = match &call.frame.bounds {
                 FrameBounds::Rows(bounds) => {
                     let mut cursor = part_with_delta.upper_bound(Bound::Included(delta_last_key));
                     for _ in 0..bounds.start.n_preceding_rows().unwrap() {
                         cursor.move_prev();
                         if cursor.key().is_none() {
                            break;
                        }
                    }
-                    cursor.key().unwrap_or(last_key)
+                    (
+                        cursor.key().unwrap_or(last_key),
+                        None, // `ROWS` frames don't have *logical* last curr order
+                    )
+                }
+                FrameBounds::Range(bounds) => {
+                    let (data_type, order_type) = test_order_col(); // TODO()
+
+                    let delta_last_order_value = memcmp_encoding::decode_value(
+                        &data_type,
+                        &delta_last_key.as_normal_expect().order_key,
+                        order_type,
+                    )
+                    .expect("no reason to fail here because we just encoded it in memory");
+
+                    let logical_curr_order_value =
+                        bounds.last_curr_of(&delta_last_order_value, order_type);
+
+                    match logical_curr_order_value {
+                        Sentinelled::Largest => (last_key, None),
+                        Sentinelled::Normal(mut logical_curr_order_value) => {
+                            // In this case we can derive the logical order value of the last
+                            // CURRENT ROW corresponding to the given `delta_last_order_value`.
+
+                            logical_curr_order_value = std::cmp::max_by(
+                                logical_curr_order_value,
+                                delta_last_order_value,
+                                |x, y| cmp_datum(x, y, order_type),
+                            );
+
+                            // Now we try to search for the logical order value in `part_with_delta`.
+                            let order_value_enc = memcmp_encoding::encode_value(
+                                &logical_curr_order_value,
+                                order_type,
+                            )
+                            .expect("the data type is simple, should succeed");
+                            let search_key = Sentinelled::Normal(StateKey {
+                                order_key: order_value_enc,
+                                pk: OwnedRow::new(vec![
+                                    None; // all-NULL row is maximal
+                                    delta_last_key.as_normal_expect().pk.len()
+                                ])
+                                .into(),
+                            });
+
+                            let cursor = part_with_delta.upper_bound(Bound::Included(&search_key));
+                            let curr_key = if let Some((next_key, _)) = cursor.peek_next()
+                                && next_key.is_largest()
+                            {
+                                // If the found upper bound of search key is right before a largest sentinel,
+                                // we don't know if there's any other rows with the same order key in the state
+                                // table but not in cache. We should conservatively return the sentinel key as
+                                // the last curr key.
+                                next_key
+                            } else {
+                                // If cursor is in ghost position, it simply means that the search key is
+                                // smaller than any existing keys.
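+                                // (i.e. every cached key is larger than the logical last
+                                // CURRENT ROW, so this call imposes no bound later than
+                                // the first key)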
+                                cursor.key().unwrap_or(first_key)
+                            };
+
+                            (curr_key, Some(logical_curr_order_value))
+                        }
+                        Sentinelled::Smallest => {
+                            unreachable!("last curr key can never be the smallest, which means UNBOUNDED PRECEDING")
+                        }
+                    }
+                }
             };
+
+            if let Some(logical_order_value) = logical_order_value {
+                let (_data_type, order_type) = test_order_col(); // TODO()
+                if let Some(max_logical_last_curr_order) = max_logical_last_curr_order.as_mut() {
+                    *max_logical_last_curr_order = std::cmp::max_by(
+                        max_logical_last_curr_order.take(),
+                        logical_order_value,
+                        |x, y| cmp_datum(x, y, order_type),
+                    );
+                } else {
+                    max_logical_last_curr_order = Some(logical_order_value);
+                }
+            }
+
             max_last_curr_key = max_last_curr_key.max(key);
             if max_last_curr_key == last_key {
                 // if we already pushed the affected curr key to the last key, no more pushing is needed
                 break;
             }
         }
 
-        max_last_curr_key
+        (max_last_curr_key, max_logical_last_curr_order)
     };
 
     let last_frame_end = if end_is_unbounded || last_curr_key == last_key {
@@ -935,6 +1161,63 @@ fn find_affected_ranges<'cache>(
                     }
                     cursor.key().unwrap_or(last_key)
                 }
+                FrameBounds::Range(bounds) => {
+                    let (_data_type, order_type) = test_order_col(); // TODO()
+
+                    let logical_curr_order_value = logical_last_curr_order
+                        .as_ref()
+                        .expect("otherwise should've gone `last_curr_key == last_key` branch")
+                        .clone();
+
+                    let logical_frame_end =
+                        bounds.frame_end_of(&logical_curr_order_value, order_type);
+
+                    match logical_frame_end {
+                        Sentinelled::Largest => last_key,
+                        Sentinelled::Normal(mut logical_frame_end) => {
+                            // In this case we can derive the logical order value of the frame end
+                            // corresponding to the order value of the *logical* CURRENT ROW.
+
+                            logical_frame_end = std::cmp::max_by(
+                                logical_frame_end,
+                                logical_curr_order_value,
+                                |x, y| cmp_datum(x, y, order_type),
+                            );
+
+                            // Now we try to search for the logical order value in `part_with_delta`.
+                            let order_value_enc =
+                                memcmp_encoding::encode_value(&logical_frame_end, order_type)
+                                    .expect("the data type is simple, should succeed");
+                            let search_key = Sentinelled::Normal(StateKey {
+                                order_key: order_value_enc,
+                                pk: OwnedRow::new(vec![
+                                    None; // all-NULL row is maximal
+                                    delta_last_key.as_normal_expect().pk.len()
+                                ])
+                                .into(),
+                            });
+
+                            let cursor = part_with_delta.upper_bound(Bound::Included(&search_key));
+                            let frame_end = if let Some((next_key, _)) = cursor.peek_next()
+                                && next_key.is_largest()
+                            {
+                                // If the found upper bound of search key is right before a largest sentinel,
+                                // we don't know if there's any other rows with the same order key in the state
+                                // table but not in cache. We should conservatively return the sentinel key as
+                                // the last frame end.
+                                next_key
+                            } else {
+                                // If cursor is in ghost position, it simply means that the search key is
+                                // smaller than any existing keys.
+                                cursor.key().unwrap_or(first_key)
+                            };
+
+                            frame_end
+                        }
+                        Sentinelled::Smallest => {
+                            unreachable!("last frame end can never be the smallest, which means UNBOUNDED PRECEDING")
+                        }
+                    }
+                }
             };
             max_frame_end = max_frame_end.max(key);
             if max_frame_end == last_key {
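
To make the `RANGE` boundary arithmetic used throughout this change concrete, here is a minimal, self-contained sketch. It is illustrative only: plain `i64` stands in for `ScalarImpl`, and the enums below are simplified stand-ins for `FrameBound<T>` and `Sentinelled<T>` in `call.rs`, not the PR's actual API.

```rust
// Simplified mirror of the RANGE-frame boundary arithmetic (illustrative types).

#[derive(Debug, Clone, Copy, PartialEq)]
enum Bound {
    UnboundedPreceding,
    Preceding(i64),
    CurrentRow,
    Following(i64),
    UnboundedFollowing,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Sentinelled<T> {
    Smallest,
    Normal(T),
    Largest,
}

#[derive(Debug, Clone, Copy)]
enum Direction {
    Ascending,
    Descending,
}

impl Bound {
    // Mirrors `FrameBound::reverse`: flips preceding/following.
    fn reverse(self) -> Bound {
        match self {
            Bound::UnboundedPreceding => Bound::UnboundedFollowing,
            Bound::Preceding(o) => Bound::Following(o),
            Bound::CurrentRow => Bound::CurrentRow,
            Bound::Following(o) => Bound::Preceding(o),
            Bound::UnboundedFollowing => Bound::UnboundedPreceding,
        }
    }

    // Mirrors `bound_of`: maps an order value to the frame boundary order value.
    fn bound_of(self, order_value: i64, dir: Direction) -> Sentinelled<i64> {
        match (self, dir) {
            (Bound::UnboundedPreceding, _) => Sentinelled::Smallest,
            (Bound::UnboundedFollowing, _) => Sentinelled::Largest,
            (Bound::CurrentRow, _) => Sentinelled::Normal(order_value),
            // PRECEDING under ASC and FOLLOWING under DESC both move towards
            // smaller order values, so both subtract the offset.
            (Bound::Preceding(o), Direction::Ascending)
            | (Bound::Following(o), Direction::Descending) => Sentinelled::Normal(order_value - o),
            (Bound::Following(o), Direction::Ascending)
            | (Bound::Preceding(o), Direction::Descending) => Sentinelled::Normal(order_value + o),
        }
    }
}

fn main() {
    // ORDER BY x ASC RANGE BETWEEN 10 PRECEDING AND CURRENT ROW
    let (start, end) = (Bound::Preceding(10), Bound::CurrentRow);
    // frame_start_of(100): the frame of row 100 starts at order value 90.
    assert_eq!(start.bound_of(100, Direction::Ascending), Sentinelled::Normal(90));
    // first_curr_of(100) = end.reverse().bound_of(100): still 100 here.
    assert_eq!(end.reverse().bound_of(100, Direction::Ascending), Sentinelled::Normal(100));
    // last_curr_of(100) = start.reverse().bound_of(100): 110, i.e. the last
    // CURRENT ROW whose frame can still include order value 100.
    assert_eq!(start.reverse().bound_of(100, Direction::Ascending), Sentinelled::Normal(110));
}
```

The `first_curr_of`/`last_curr_of` pair is what `find_affected_ranges` above relies on: reversing a frame bound answers the inverse question "which CURRENT ROWs have a frame covering this value", which bounds the set of rows whose output must be recomputed.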