Skip to content

Commit

Permalink
add shortcut for first_value and last_value in over window impl
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc committed Nov 1, 2024
1 parent 3b8b913 commit 672b911
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 24 deletions.
71 changes: 47 additions & 24 deletions src/expr/impl/src/window_function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::collections::BTreeSet;

use educe::Educe;
use futures_util::FutureExt;
use risingwave_common::array::{DataChunk, Op, StreamChunk};
use risingwave_common::types::{DataType, Datum};
Expand All @@ -22,7 +23,7 @@ use risingwave_common::{bail, must_match};
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_expr::aggregate::{
build_append_only, AggCall, AggType, AggregateFunction, AggregateState as AggImplState,
BoxedAggregateFunction,
BoxedAggregateFunction, PbAggKind,
};
use risingwave_expr::sig::FUNCTION_REGISTRY;
use risingwave_expr::window_function::{
Expand All @@ -40,7 +41,6 @@ struct AggregateState<W>
where
W: WindowImpl<Key = StateKey, Value = StateValue>,
{
agg_func: BoxedAggregateFunction,
agg_impl: AggImpl,
arg_data_types: Vec<DataType>,
buffer: WindowBuffer<W>,
Expand All @@ -65,7 +65,9 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
direct_args: vec![],
};

let (agg_func, agg_impl, enable_delta) = match agg_type {
let (agg_impl, enable_delta) = match agg_type {
AggType::Builtin(PbAggKind::FirstValue) => (AggImpl::Shortcut(Shortcut::FirstValue), false),
AggType::Builtin(PbAggKind::LastValue) => (AggImpl::Shortcut(Shortcut::LastValue), false),
AggType::Builtin(kind) => {
let agg_func_sig = FUNCTION_REGISTRY
.get(*kind, &arg_data_types, &call.return_type)
Expand All @@ -74,28 +76,25 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state()?;
(AggImpl::Incremental(init_state), true)
(AggImpl::Incremental(agg_func, init_state), true)
} else {
(AggImpl::Full, false)
(AggImpl::Full(agg_func), false)
};
(agg_func, agg_impl, enable_delta)
(agg_impl, enable_delta)
}
AggType::UserDefined(_) => {
// TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
(AggImpl::Full(build_append_only(&agg_call)?), false)
}
AggType::WrapScalar(_) => {
// we have to feed the wrapped scalar function with all the rows in the window,
// instead of doing incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
(AggImpl::Full(build_append_only(&agg_call)?), false)
}
};

let this = match &call.frame.bounds {
FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
agg_func,
agg_impl,
arg_data_types,
buffer: WindowBuffer::<RowsWindow<StateKey, StateValue>>::new(
Expand All @@ -106,7 +105,6 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
buffer_heap_size: KvSize::new(),
}) as BoxedWindowState,
FrameBounds::Range(frame_bounds) => Box::new(AggregateState {
agg_func,
agg_impl,
arg_data_types,
buffer: WindowBuffer::<RangeWindow<StateValue>>::new(
Expand All @@ -117,7 +115,6 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
buffer_heap_size: KvSize::new(),
}) as BoxedWindowState,
FrameBounds::Session(frame_bounds) => Box::new(AggregateState {
agg_func,
agg_impl,
arg_data_types,
buffer: WindowBuffer::<SessionWindow<StateValue>>::new(
Expand Down Expand Up @@ -181,32 +178,49 @@ where
}

fn slide(&mut self) -> Result<(Datum, StateEvictHint)> {
let wrapper = AggregatorWrapper {
agg_func: self.agg_func.as_ref(),
arg_data_types: &self.arg_data_types,
};
let output = match self.agg_impl {
AggImpl::Full => wrapper.aggregate(self.buffer.curr_window_values()),
AggImpl::Incremental(ref mut state) => {
AggImpl::Full(ref agg_func) => {
let wrapper = AggregatorWrapper {
agg_func: agg_func.as_ref(),
arg_data_types: &self.arg_data_types,
};
wrapper.aggregate(self.buffer.curr_window_values())
}
AggImpl::Incremental(ref agg_func, ref mut state) => {
let wrapper = AggregatorWrapper {
agg_func: agg_func.as_ref(),
arg_data_types: &self.arg_data_types,
};
wrapper.update(state, self.buffer.consume_curr_window_values_delta())
}
AggImpl::Shortcut(shortcut) => match shortcut {
Shortcut::FirstValue => Ok(self
.buffer
.curr_window_first_value()
.and_then(|args| args[0].clone())),
Shortcut::LastValue => Ok(self
.buffer
.curr_window_last_value()
.and_then(|args| args[0].clone())),
},
}?;
let evict_hint = self.slide_inner();
Ok((output, evict_hint))
}

fn slide_no_output(&mut self) -> Result<StateEvictHint> {
match self.agg_impl {
AggImpl::Full => {}
AggImpl::Incremental(ref mut state) => {
AggImpl::Full(..) => {}
AggImpl::Incremental(ref agg_func, ref mut state) => {
// for incremental agg, we need to update the state even if the caller doesn't need
// the output
let wrapper = AggregatorWrapper {
agg_func: self.agg_func.as_ref(),
agg_func: agg_func.as_ref(),
arg_data_types: &self.arg_data_types,
};
wrapper.update(state, self.buffer.consume_curr_window_values_delta())?;
}
AggImpl::Shortcut(..) => {}
};
Ok(self.slide_inner())
}
Expand All @@ -223,9 +237,18 @@ where
}
}

#[derive(Educe)]
#[educe(Debug)]
enum AggImpl {
Incremental(AggImplState),
Full,
Incremental(#[educe(Debug(ignore))] BoxedAggregateFunction, AggImplState),
Full(#[educe(Debug(ignore))] BoxedAggregateFunction),
Shortcut(Shortcut),
}

#[derive(Debug, Clone, Copy)]
enum Shortcut {
FirstValue,
LastValue,
}

struct AggregatorWrapper<'a> {
Expand Down
21 changes: 21 additions & 0 deletions src/expr/impl/src/window_function/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,27 @@ impl<W: WindowImpl> WindowBuffer<W> {
.map(|Entry { value, .. }| value)
}

/// Get the first value in the current window. Time complexity is O(1).
pub fn curr_window_first_value(&self) -> Option<&W::Value> {
self.curr_window_values().next()
}

/// Get the last value in the current window. Time complexity is O(1).
pub fn curr_window_last_value(&self) -> Option<&W::Value> {
let (left, right) = self.curr_window_ranges();
if !right.is_empty() {
self.buffer
.get(right.end - 1)
.map(|Entry { value, .. }| value)
} else if !left.is_empty() {
self.buffer
.get(left.end - 1)
.map(|Entry { value, .. }| value)
} else {
None
}
}

/// Consume the delta of values comparing the current window to the previous window.
/// The delta is not guaranteed to be sorted, especially when frame exclusion is not `NoOthers`.
pub fn consume_curr_window_values_delta(
Expand Down

0 comments on commit 672b911

Please sign in to comment.