From 672b9116aa8e9c2db694163ddab0c17af0c43b02 Mon Sep 17 00:00:00 2001 From: Richard Chien Date: Fri, 1 Nov 2024 17:15:57 +0800 Subject: [PATCH] add shortcut for first_value and last_value in over window impl Signed-off-by: Richard Chien --- .../impl/src/window_function/aggregate.rs | 71 ++++++++++++------- src/expr/impl/src/window_function/buffer.rs | 21 ++++++ 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/src/expr/impl/src/window_function/aggregate.rs b/src/expr/impl/src/window_function/aggregate.rs index c173ae8c82995..213de14878627 100644 --- a/src/expr/impl/src/window_function/aggregate.rs +++ b/src/expr/impl/src/window_function/aggregate.rs @@ -14,6 +14,7 @@ use std::collections::BTreeSet; +use educe::Educe; use futures_util::FutureExt; use risingwave_common::array::{DataChunk, Op, StreamChunk}; use risingwave_common::types::{DataType, Datum}; @@ -22,7 +23,7 @@ use risingwave_common::{bail, must_match}; use risingwave_common_estimate_size::{EstimateSize, KvSize}; use risingwave_expr::aggregate::{ build_append_only, AggCall, AggType, AggregateFunction, AggregateState as AggImplState, - BoxedAggregateFunction, + BoxedAggregateFunction, PbAggKind, }; use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::window_function::{ @@ -40,7 +41,6 @@ struct AggregateState where W: WindowImpl, { - agg_func: BoxedAggregateFunction, agg_impl: AggImpl, arg_data_types: Vec, buffer: WindowBuffer, @@ -65,7 +65,9 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { direct_args: vec![], }; - let (agg_func, agg_impl, enable_delta) = match agg_type { + let (agg_impl, enable_delta) = match agg_type { + AggType::Builtin(PbAggKind::FirstValue) => (AggImpl::Shortcut(Shortcut::FirstValue), false), + AggType::Builtin(PbAggKind::LastValue) => (AggImpl::Shortcut(Shortcut::LastValue), false), AggType::Builtin(kind) => { let agg_func_sig = FUNCTION_REGISTRY .get(*kind, &arg_data_types, &call.return_type) @@ -74,28 +76,25 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { let (agg_impl, enable_delta) = if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() { let init_state = agg_func.create_state()?; - (AggImpl::Incremental(init_state), true) + (AggImpl::Incremental(agg_func, init_state), true) } else { - (AggImpl::Full, false) + (AggImpl::Full(agg_func), false) }; - (agg_func, agg_impl, enable_delta) + (agg_impl, enable_delta) } AggType::UserDefined(_) => { // TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation - let agg_func = build_append_only(&agg_call)?; - (agg_func, AggImpl::Full, false) + (AggImpl::Full(build_append_only(&agg_call)?), false) } AggType::WrapScalar(_) => { // we have to feed the wrapped scalar function with all the rows in the window, // instead of doing incremental aggregation - let agg_func = build_append_only(&agg_call)?; - (agg_func, AggImpl::Full, false) + (AggImpl::Full(build_append_only(&agg_call)?), false) } }; let this = match &call.frame.bounds { FrameBounds::Rows(frame_bounds) => Box::new(AggregateState { - agg_func, agg_impl, arg_data_types, buffer: WindowBuffer::>::new( @@ -106,7 +105,6 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { buffer_heap_size: KvSize::new(), }) as BoxedWindowState, FrameBounds::Range(frame_bounds) => Box::new(AggregateState { - agg_func, agg_impl, arg_data_types, buffer: WindowBuffer::>::new( @@ -117,7 +115,6 @@ pub(super) fn new(call: &WindowFuncCall) -> Result { buffer_heap_size: KvSize::new(), }) as BoxedWindowState, FrameBounds::Session(frame_bounds) => Box::new(AggregateState { - agg_func, agg_impl, arg_data_types, buffer: WindowBuffer::>::new( @@ -181,15 +178,31 @@ where } fn slide(&mut self) -> Result<(Datum, StateEvictHint)> { - let wrapper = AggregatorWrapper { - agg_func: self.agg_func.as_ref(), - arg_data_types: &self.arg_data_types, - }; let output = match self.agg_impl { - AggImpl::Full => wrapper.aggregate(self.buffer.curr_window_values()), - AggImpl::Incremental(ref mut state) => { + AggImpl::Full(ref agg_func) => { + let wrapper = AggregatorWrapper { + agg_func: agg_func.as_ref(), + arg_data_types: &self.arg_data_types, + }; + wrapper.aggregate(self.buffer.curr_window_values()) + } + AggImpl::Incremental(ref agg_func, ref mut state) => { + let wrapper = AggregatorWrapper { + agg_func: agg_func.as_ref(), + arg_data_types: &self.arg_data_types, + }; wrapper.update(state, self.buffer.consume_curr_window_values_delta()) } + AggImpl::Shortcut(shortcut) => match shortcut { + Shortcut::FirstValue => Ok(self + .buffer + .curr_window_first_value() + .and_then(|args| args[0].clone())), + Shortcut::LastValue => Ok(self + .buffer + .curr_window_last_value() + .and_then(|args| args[0].clone())), + }, }?; let evict_hint = self.slide_inner(); Ok((output, evict_hint)) @@ -197,16 +210,17 @@ where fn slide_no_output(&mut self) -> Result { match self.agg_impl { - AggImpl::Full => {} - AggImpl::Incremental(ref mut state) => { + AggImpl::Full(..) => {} + AggImpl::Incremental(ref agg_func, ref mut state) => { // for incremental agg, we need to update the state even if the caller doesn't need // the output let wrapper = AggregatorWrapper { - agg_func: self.agg_func.as_ref(), + agg_func: agg_func.as_ref(), arg_data_types: &self.arg_data_types, }; wrapper.update(state, self.buffer.consume_curr_window_values_delta())?; } + AggImpl::Shortcut(..) => {} }; Ok(self.slide_inner()) } @@ -223,9 +237,18 @@ where } } +#[derive(Educe)] +#[educe(Debug)] enum AggImpl { - Incremental(AggImplState), - Full, + Incremental(#[educe(Debug(ignore))] BoxedAggregateFunction, AggImplState), + Full(#[educe(Debug(ignore))] BoxedAggregateFunction), + Shortcut(Shortcut), +} + +#[derive(Debug, Clone, Copy)] +enum Shortcut { + FirstValue, + LastValue, } struct AggregatorWrapper<'a> { diff --git a/src/expr/impl/src/window_function/buffer.rs b/src/expr/impl/src/window_function/buffer.rs index 57217dda6fd4b..365f3ec4fbc84 100644 --- a/src/expr/impl/src/window_function/buffer.rs +++ b/src/expr/impl/src/window_function/buffer.rs @@ -132,6 +132,27 @@ impl WindowBuffer { .map(|Entry { value, .. }| value) } + /// Get the first value in the current window. Time complexity is O(1). + pub fn curr_window_first_value(&self) -> Option<&W::Value> { + self.curr_window_values().next() + } + + /// Get the last value in the current window. Time complexity is O(1). + pub fn curr_window_last_value(&self) -> Option<&W::Value> { + let (left, right) = self.curr_window_ranges(); + if !right.is_empty() { + self.buffer + .get(right.end - 1) + .map(|Entry { value, .. }| value) + } else if !left.is_empty() { + self.buffer + .get(left.end - 1) + .map(|Entry { value, .. }| value) + } else { + None + } + } + /// Consume the delta of values comparing the current window to the previous window. /// The delta is not guaranteed to be sorted, especially when frame exclusion is not `NoOthers`. pub fn consume_curr_window_values_delta(