Skip to content

Commit

Permalink
refactor(over window): generalize WindowBuffer and window function …
Browse files Browse the repository at this point in the history
…`AggregateState` (#14647)

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Feb 1, 2024
1 parent 752b5e0 commit 8481ea7
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 201 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ chrono = { version = "0.4", default-features = false, features = [
] }
downcast-rs = "1.2"
easy-ext = "1"
educe = "0.5"
either = "1"
enum-as-inner = "0.6"
futures-async-stream = { workspace = true }
Expand Down
14 changes: 10 additions & 4 deletions src/expr/core/src/window_function/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ impl FrameBounds {
}
}

pub trait FrameBoundsImpl {
fn validate(&self) -> Result<()>;
}

#[derive(Display, Debug, Clone, Eq, PartialEq, Hash)]
#[display("ROWS BETWEEN {start} AND {end}")]
pub struct RowsFrameBounds {
Expand All @@ -151,10 +155,6 @@ pub struct RowsFrameBounds {
}

impl RowsFrameBounds {
fn validate(&self) -> Result<()> {
FrameBound::validate_bounds(&self.start, &self.end)
}

/// Check if the `ROWS` frame is canonical.
///
/// A canonical `ROWS` frame is defined as:
Expand Down Expand Up @@ -190,6 +190,12 @@ impl RowsFrameBounds {
}
}

impl FrameBoundsImpl for RowsFrameBounds {
fn validate(&self) -> Result<()> {
FrameBound::validate_bounds(&self.start, &self.end)
}
}

#[derive(Display, Debug, Clone, Eq, PartialEq, Hash, EnumAsInner)]
#[display(style = "TITLE CASE")]
pub enum FrameBound<T> {
Expand Down
110 changes: 66 additions & 44 deletions src/expr/core/src/window_function/state/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,66 +22,82 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::{bail, must_match};
use smallvec::SmallVec;

use super::buffer::WindowBuffer;
use super::{StateEvictHint, StateKey, StatePos, WindowState};
use super::buffer::{RowsWindow, WindowBuffer, WindowImpl};
use super::{BoxedWindowState, StateEvictHint, StateKey, StatePos, WindowState};
use crate::aggregate::{
AggArgs, AggCall, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::window_function::{WindowFuncCall, WindowFuncKind};
use crate::window_function::{FrameBounds, WindowFuncCall, WindowFuncKind};
use crate::Result;

pub struct AggregateState {
type StateValue = SmallVec<[Datum; 2]>;

struct AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
agg_func: BoxedAggregateFunction,
agg_impl: AggImpl,
arg_data_types: Vec<DataType>,
buffer: WindowBuffer<StateKey, SmallVec<[Datum; 2]>>,
buffer: WindowBuffer<W>,
buffer_heap_size: KvSize,
}

impl AggregateState {
pub fn new(call: &WindowFuncCall) -> Result<Self> {
if call.frame.bounds.validate().is_err() {
bail!("the window frame must be valid");
}
let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
let arg_data_types = call.args.arg_types().to_vec();
let agg_call = AggCall {
kind: agg_kind,
args: match &call.args {
// convert args to [0] or [0, 1]
AggArgs::None => AggArgs::None,
AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
},
return_type: call.return_type.clone(),
column_orders: Vec::new(), // the input is already sorted
// TODO(rc): support filter on window function call
filter: None,
// TODO(rc): support distinct on window function call? PG doesn't support it either.
distinct: false,
direct_args: vec![],
pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
if call.frame.bounds.validate().is_err() {
bail!("the window frame must be valid");
}
let agg_kind = must_match!(call.kind, WindowFuncKind::Aggregate(agg_kind) => agg_kind);
let arg_data_types = call.args.arg_types().to_vec();
let agg_call = AggCall {
kind: agg_kind,
args: match &call.args {
// convert args to [0] or [0, 1]
AggArgs::None => AggArgs::None,
AggArgs::Unary(data_type, _) => AggArgs::Unary(data_type.to_owned(), 0),
AggArgs::Binary(data_types, _) => AggArgs::Binary(data_types.to_owned(), [0, 1]),
},
return_type: call.return_type.clone(),
column_orders: Vec::new(), // the input is already sorted
// TODO(rc): support filter on window function call
filter: None,
// TODO(rc): support distinct on window function call? PG doesn't support it either.
distinct: false,
direct_args: vec![],
};
let agg_func_sig = FUNCTION_REGISTRY
.get(agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state();
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
let agg_func_sig = FUNCTION_REGISTRY
.get(agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state();
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
Ok(Self {

let this = match &call.frame.bounds {
FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
agg_func,
agg_impl,
arg_data_types,
buffer: WindowBuffer::new(call.frame.clone(), enable_delta),
buffer: WindowBuffer::<RowsWindow<StateKey, StateValue>>::new(
RowsWindow::new(frame_bounds.clone()),
call.frame.exclusion,
enable_delta,
),
buffer_heap_size: KvSize::new(),
})
}
}) as BoxedWindowState,
};
Ok(this)
}

impl<W> AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn slide_inner(&mut self) -> StateEvictHint {
let removed_keys: BTreeSet<_> = self
.buffer
Expand All @@ -107,7 +123,10 @@ impl AggregateState {
}
}

impl WindowState for AggregateState {
impl<W> WindowState for AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn append(&mut self, key: StateKey, args: SmallVec<[Datum; 2]>) {
args.iter().for_each(|arg| {
self.buffer_heap_size.add_val(arg);
Expand Down Expand Up @@ -156,7 +175,10 @@ impl WindowState for AggregateState {
}
}

impl EstimateSize for AggregateState {
impl<W> EstimateSize for AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
fn estimated_heap_size(&self) -> usize {
// estimate `VecDeque` of `StreamWindowBuffer` internal size
// https://github.com/risingwavelabs/risingwave/issues/9713
Expand Down
Loading

0 comments on commit 8481ea7

Please sign in to comment.