From 7341f03a9681345431066a334bdb15495ca0336f Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Thu, 18 Jan 2024 13:41:32 +0800 Subject: [PATCH] generalize window function `AggregateState` Signed-off-by: Richard Chien --- .../src/window_function/state/aggregate.rs | 110 +++++++++++------- .../core/src/window_function/state/mod.rs | 6 +- 2 files changed, 70 insertions(+), 46 deletions(-) diff --git a/src/expr/core/src/window_function/state/aggregate.rs b/src/expr/core/src/window_function/state/aggregate.rs index 09555eecf9201..d53b99782e5bf 100644 --- a/src/expr/core/src/window_function/state/aggregate.rs +++ b/src/expr/core/src/window_function/state/aggregate.rs @@ -22,66 +22,82 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::{bail, must_match}; use smallvec::SmallVec; -use super::buffer::WindowBuffer; -use super::{StateEvictHint, StateKey, StatePos, WindowState}; +use super::buffer::{RowsWindow, WindowBuffer, WindowImpl}; +use super::{BoxedWindowState, StateEvictHint, StateKey, StatePos, WindowState}; use crate::aggregate::{ AggArgs, AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction, }; use crate::sig::FUNCTION_REGISTRY; -use crate::window_function::{WindowFuncCall, WindowFuncKind}; +use crate::window_function::{FrameBounds, WindowFuncCall, WindowFuncKind}; use crate::Result; -pub struct AggregateState { +type StateValue = SmallVec<[Datum; 2]>; + +struct AggregateState +where + W: WindowImpl, +{ agg_func: BoxedAggregateFunction, agg_impl: AggImpl, arg_data_types: Vec, - buffer: WindowBuffer>, + buffer: WindowBuffer, buffer_heap_size: KvSize, } -impl AggregateState { - pub fn new(call: &WindowFuncCall) -> Result { - if call.frame.bounds.validate().is_err() { - bail!("the window frame must be valid"); - } - let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind); - let arg_data_types = call.args.arg_types().to_vec(); - let agg_call = AggCall { - kind: agg_kind, - args: match &call.args { - // convert args to [0] or [0, 1] - AggArgs::None => AggArgs::None, - AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0), - AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]), - }, - return_type: call.return_type.clone(), - column_orders: Vec::new(), // the input is already sorted - // TODO(rc): support filter on window function call - filter: None, - // TODO(rc): support distinct on window function call? PG doesn't support it either. - distinct: false, - direct_args: vec![], +pub(super) fn new(call: &WindowFuncCall) -> Result { + if call.frame.bounds.validate().is_err() { + bail!("the window frame must be valid"); + } + let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind); + let arg_data_types = call.args.arg_types().to_vec(); + let agg_call = AggCall { + kind: agg_kind, + args: match &call.args { + // convert args to [0] or [0, 1] + AggArgs::None => AggArgs::None, + AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0), + AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]), + }, + return_type: call.return_type.clone(), + column_orders: Vec::new(), // the input is already sorted + // TODO(rc): support filter on window function call + filter: None, + // TODO(rc): support distinct on window function call? PG doesn't support it either. + distinct: false, + direct_args: vec![], + }; + let agg_func_sig = FUNCTION_REGISTRY + .get(agg_kind, &arg_data_types, &call.return_type) + .expect("the agg func must exist"); + let agg_func = agg_func_sig.build_aggregate(&agg_call)?; + let (agg_impl, enable_delta) = + if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { + let init_state = agg_func.create_state(); + (AggImpl::Incremental(init_state), true) + } else { + (AggImpl::Full, false) }; - let agg_func_sig = FUNCTION_REGISTRY - .get(agg_kind, &arg_data_types, &call.return_type) - .expect("the agg func must exist"); - let agg_func = agg_func_sig.build_aggregate(&agg_call)?; - let (agg_impl, enable_delta) = - if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { - let init_state = agg_func.create_state(); - (AggImpl::Incremental(init_state), true) - } else { - (AggImpl::Full, false) - }; - Ok(Self { + + let this = match &call.frame.bounds { + FrameBounds::Rows(frame_bounds) => Box::new(AggregateState { agg_func, agg_impl, arg_data_types, - buffer: WindowBuffer::new(call.frame.clone(), enable_delta), + buffer: WindowBuffer::>::new( + frame_bounds.clone(), + call.frame.exclusion, + enable_delta, + ), buffer_heap_size: KvSize::new(), - }) - } + }) as BoxedWindowState, + }; + Ok(this) +} +impl AggregateState +where + W: WindowImpl, +{ fn slide_inner(&mut self) -> StateEvictHint { let removed_keys: BTreeSet<_> = self .buffer @@ -107,7 +123,10 @@ impl AggregateState { } } -impl WindowState for AggregateState { +impl WindowState for AggregateState +where + W: WindowImpl, +{ fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) { args.iter().for_each(|arg| { self.buffer_heap_size.add_val(arg); @@ -156,7 +175,10 @@ impl WindowState for AggregateState { } } -impl EstimateSize for AggregateState { +impl EstimateSize for AggregateState +where + W: WindowImpl, +{ fn estimated_heap_size(&self) -> usize { // estimate `VecDeque` of `StreamWindowBuffer` internal size // https://github.com/risingwavelabs/risingwave/issues/9713 diff --git a/src/expr/core/src/window_function/state/mod.rs b/src/expr/core/src/window_function/state/mod.rs index 37ee086ca7ba4..fbaec55a84c38 100644 --- a/src/expr/core/src/window_function/state/mod.rs +++ b/src/expr/core/src/window_function/state/mod.rs @@ -114,7 +114,9 @@ pub trait WindowState: EstimateSize { fn slide_no_output(&mut self) -> Result; } -pub fn create_window_state(call: &WindowFuncCall) -> Result> { +pub type BoxedWindowState = Box; + +pub fn create_window_state(call: &WindowFuncCall) -> Result { assert!(call.frame.bounds.validate().is_ok()); use WindowFuncKind::*; @@ -122,7 +124,7 @@ pub fn create_window_state(call: &WindowFuncCall) -> Result Box::new(rank::RankState::::new(call)), Rank => Box::new(rank::RankState::::new(call)), DenseRank => Box::new(rank::RankState::::new(call)), - Aggregate(_) => Box::new(aggregate::AggregateState::new(call)?), + Aggregate(_) => aggregate::new(call)?, kind => { return Err(ExprError::UnsupportedFunction(format!( "{}({}) -> {}",