Skip to content

Commit

Permalink
generalize window function AggregateState
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc committed Jan 24, 2024
1 parent eac4a04 commit f80da3f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 46 deletions.
110 changes: 66 additions & 44 deletions src/expr/core/src/window_function/state/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,66 +22,82 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::{bail, must_match};
use smallvec::SmallVec;

use super::buffer::WindowBuffer;
use super::{StateEvictHint, StateKey, StatePos, WindowState};
use super::buffer::{RowsWindow, WindowBuffer, WindowImpl};
use super::{BoxedWindowState, StateEvictHint, StateKey, StatePos, WindowState};
use crate::aggregate::{
AggArgs, AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::window_function::{WindowFuncCall, WindowFuncKind};
use crate::window_function::{FrameBounds, WindowFuncCall, WindowFuncKind};
use crate::Result;

pub struct AggregateState {
type StateValue = SmallVec<[Datum; 2]>;

struct AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
agg_func: BoxedAggregateFunction,
agg_impl: AggImpl,
arg_data_types: Vec<DataType>,
buffer: WindowBuffer<StateKey, SmallVec<[Datum; 2]>>,
buffer: WindowBuffer<W>,
buffer_heap_size: KvSize,
}

impl AggregateState {
pub fn new(call: &WindowFuncCall) -> Result<Self> {
if call.frame.bounds.validate().is_err() {
bail!("the window frame must be valid");
}
let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
let arg_data_types = call.args.arg_types().to_vec();
let agg_call = AggCall {
kind: agg_kind,
args: match &call.args {
// convert args to [0] or [0, 1]
AggArgs::None => AggArgs::None,
AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
},
return_type: call.return_type.clone(),
column_orders: Vec::new(), // the input is already sorted
// TODO(rc): support filter on window function call
filter: None,
// TODO(rc): support distinct on window function call? PG doesn't support it either.
distinct: false,
direct_args: vec![],
pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
if call.frame.bounds.validate().is_err() {
bail!("the window frame must be valid");
}
let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
let arg_data_types = call.args.arg_types().to_vec();
let agg_call = AggCall {
kind: agg_kind,
args: match &call.args {
// convert args to [0] or [0, 1]
AggArgs::None => AggArgs::None,
AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
},
return_type: call.return_type.clone(),
column_orders: Vec::new(), // the input is already sorted
// TODO(rc): support filter on window function call
filter: None,
// TODO(rc): support distinct on window function call? PG doesn't support it either.
distinct: false,
direct_args: vec![],
};
let agg_func_sig = FUNCTION_REGISTRY
.get(agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state();
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
let agg_func_sig = FUNCTION_REGISTRY
.get(agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state();
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
Ok(Self {

let this = match &call.frame.bounds {
FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
agg_func,
agg_impl,
arg_data_types,
buffer: WindowBuffer::new(call.frame.clone(), enable_delta),
buffer: WindowBuffer::<RowsWindow<StateKey, StateValue>>::new(
frame_bounds.clone(),
call.frame.exclusion,
enable_delta,
),
buffer_heap_size: KvSize::new(),
})
}
}) as BoxedWindowState,
};
Ok(this)
}

impl<W> AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn slide_inner(&mut self) -> StateEvictHint {
let removed_keys: BTreeSet<_> = self
.buffer
Expand All @@ -107,7 +123,10 @@ impl AggregateState {
}
}

impl WindowState for AggregateState {
impl<W> WindowState for AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
args.iter().for_each(|arg| {
self.buffer_heap_size.add_val(arg);
Expand Down Expand Up @@ -156,7 +175,10 @@ impl WindowState for AggregateState {
}
}

impl EstimateSize for AggregateState {
impl<W> EstimateSize for AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn estimated_heap_size(&self) -> usize {
// estimate `VecDeque` of `StreamWindowBuffer` internal size
// https://github.com/risingwavelabs/risingwave/issues/9713
Expand Down
6 changes: 4 additions & 2 deletions src/expr/core/src/window_function/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,17 @@ pub trait WindowState: EstimateSize {
fn slide_no_output(&mut self) -> Result<StateEvictHint>;
}

pub fn create_window_state(call: &WindowFuncCall) -> Result<Box<dyn WindowState + Send + Sync>> {
pub type BoxedWindowState = Box<dyn WindowState + Send + Sync>;

pub fn create_window_state(call: &WindowFuncCall) -> Result<BoxedWindowState> {
assert!(call.frame.bounds.validate().is_ok());

use WindowFuncKind::*;
Ok(match call.kind {
RowNumber => Box::new(rank::RankState::<rank::RowNumber>::new(call)),
Rank => Box::new(rank::RankState::<rank::Rank>::new(call)),
DenseRank => Box::new(rank::RankState::<rank::DenseRank>::new(call)),
Aggregate(_) => Box::new(aggregate::AggregateState::new(call)?),
Aggregate(_) => aggregate::new(call)?,
kind => {
return Err(ExprError::UnsupportedFunction(format!(
"{}({}) -> {}",
Expand Down

0 comments on commit f80da3f

Please sign in to comment.