Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agg): fix first_value and last_value to not ignore NULLs #19275

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/expr/impl/src/aggregate/first_last_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_expr::aggregate;

/// Aggregate that yields the first input value, `NULL` included.
///
/// Note that, unlike `min` and `max`, `first_value` doesn't ignore `NULL` values:
/// both parameters are `Option<T>`, so a `NULL` input is handed to this function
/// instead of being filtered out beforehand.
///
/// The function itself just returns the incoming value; the "first" semantics come
/// from the `shortcurcuit_if = "true"` attribute, which makes the generated update
/// loop stop right after a row has been accumulated.
///
/// ```slt
/// statement ok
/// create table t(v1 int, ts int);
///
/// statement ok
/// insert into t values (null, 1), (2, 2), (null, 3);
///
/// query I
/// select first_value(v1 order by ts) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[aggregate("first_value(*) -> auto", state = "ref", shortcurcuit_if = "true" /* always short-circuit */)]
fn first_value<T>(_: Option<T>, input: Option<T>) -> Option<T> {
input // always short-circuit immediately, so the output is always the first value
}

/// Aggregate that yields the most recent input value, `NULL` included.
///
/// Note that, unlike `min` and `max`, `last_value` doesn't ignore `NULL` values:
/// both parameters are `Option<T>`, so a `NULL` input is handed to this function
/// and replaces whatever the previous state was.
///
/// ```slt
/// statement ok
/// create table t(v1 int, ts int);
///
/// statement ok
/// insert into t values (null, 1), (2, 2), (null, 3);
///
/// query I
/// select last_value(v1 order by ts) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[aggregate("last_value(*) -> auto", state = "ref")] // TODO(rc): `last_value(any) -> any`
fn last_value<T>(_: Option<T>, input: Option<T>) -> Option<T> {
// The previous state is discarded; the newest input always wins.
input
}

/// Internal aggregate that tracks the most recently inserted value.
///
/// * On a regular (non-retract) row the new `input` becomes the state.
/// * On a retraction the current `state` is kept unchanged.
#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)]
fn internal_last_seen_value<T>(state: T, input: T, retract: bool) -> T {
    // Only non-retracting rows overwrite the remembered value.
    if !retract {
        return input;
    }
    state
}
19 changes: 0 additions & 19 deletions src/expr/impl/src/aggregate/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,6 @@ fn max<T: Ord>(state: T, input: T) -> T {
state.max(input)
}

/// Aggregate that keeps the existing state and discards every later input.
///
/// Both parameters are plain `T` (not `Option<T>`), so this variant only ever
/// sees non-`NULL` values.
#[aggregate("first_value(*) -> auto", state = "ref")]
fn first_value<T>(state: T, _: T) -> T {
// The state already holds the earliest value; the new input is ignored.
state
}

/// Aggregate that replaces the state with each new input.
///
/// Both parameters are plain `T` (not `Option<T>`), so this variant only ever
/// sees non-`NULL` values.
#[aggregate("last_value(*) -> auto", state = "ref")]
fn last_value<T>(_: T, input: T) -> T {
// The previous state is discarded; the newest input always wins.
input
}

/// Internal aggregate that tracks the most recently inserted value.
///
/// * On a regular (non-retract) row the new `input` becomes the state.
/// * On a retraction the current `state` is kept unchanged.
#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)]
fn internal_last_seen_value<T>(state: T, input: T, retract: bool) -> T {
    // Only non-retracting rows overwrite the remembered value.
    if !retract {
        return input;
    }
    state
}

/// Note the following corner cases:
///
/// ```slt
Expand Down
1 change: 1 addition & 0 deletions src/expr/impl/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod bit_or;
mod bit_xor;
mod bool_and;
mod bool_or;
mod first_last_value;
mod general;
mod jsonb_agg;
mod mode;
Expand Down
87 changes: 62 additions & 25 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,43 +826,52 @@ impl FunctionAttr {
ReturnTypeKind::ResultOption => quote! { #next_state? },
};
if user_fn.accumulate().args_option.iter().all(|b| !b) {
let first_state = if self.init_state.is_some() {
// if `init_state` is specified, the state will never be None
quote! { unreachable!() }
} else if let Some(s) = &self.state
&& s == "ref"
{
if self.args.is_empty() {
return Err(Error::new(
Span::call_site(),
"`state` cannot be `ref` if there's no argument",
));
}

// for min/max/..., the first state is the first non-NULL value
quote! { Some(v0) }
} else if let AggregateFnOrImpl::Impl(impl_) = user_fn
&& impl_.create_state.is_some()
{
// use user-defined create_state function
quote! {{
let state = self.function.create_state();
#next_state
}}
} else {
quote! {{
let state = #state_type::default();
#next_state
}}
};

match self.args.len() {
0 => {
// the only argument is the state itself, now it's non-Option
next_state = quote! {
match state {
Some(state) => #next_state,
None => state,
None => #first_state, // create first state on first input
}
};
}
1 => {
let first_state = if self.init_state.is_some() {
// for count, the state will never be None
quote! { unreachable!() }
} else if let Some(s) = &self.state
&& s == "ref"
{
// for min/max/first/last, the state is the first value
quote! { Some(v0) }
} else if let AggregateFnOrImpl::Impl(impl_) = user_fn
&& impl_.create_state.is_some()
{
// use user-defined create_state function
quote! {{
let state = self.function.create_state();
#next_state
}}
} else {
quote! {{
let state = #state_type::default();
#next_state
}}
};
next_state = quote! {
match (state, v0) {
(Some(state), Some(v0)) => #next_state,
(None, Some(v0)) => #first_state,
(state, None) => state,
(None, Some(v0)) => #first_state, // for the first non-NULL input, create first state
(state, None) => state, // ignoring NULL input
}
};
}
Expand All @@ -874,6 +883,31 @@ impl FunctionAttr {
} else {
quote! { state = #next_state; }
};
let shortcurcuit = if let Some(cond) = &self.shortcurcuit_if {
let cond: TokenStream2 = cond.parse().unwrap();
if user_fn.accumulate().args_option.iter().all(|b| !b) {
// non-`Option` arguments, unpack the state inner value
quote! {
match state {
Some(state) => {
if #cond {
break; // this will break the loop in `update`/`update_range`
}
}
None => {}
}
}
} else {
// if some arguments are `Option`, we evaluate the short-circuit condition as is
quote! {
if #cond {
break; // this will break the loop in `update`/`update_range`
}
}
}
} else {
quote! {}
};
let get_result = if custom_state.is_some() {
quote! { Ok(state.downcast_ref::<#state_type>().into()) }
} else if let AggregateFnOrImpl::Impl(impl_) = user_fn
Expand Down Expand Up @@ -953,6 +987,7 @@ impl FunctionAttr {
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
#update_state
#shortcurcuit
}
#restore_state
Ok(())
Expand All @@ -967,6 +1002,7 @@ impl FunctionAttr {
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
#update_state
#shortcurcuit
}
} else {
for row_id in input.visibility().iter_ones() {
Expand All @@ -978,6 +1014,7 @@ impl FunctionAttr {
let op = unsafe { *input.ops().get_unchecked(row_id) };
#(#let_values)*
#update_state
#shortcurcuit
}
}
#restore_state
Expand Down
3 changes: 3 additions & 0 deletions src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,9 @@ struct FunctionAttr {
/// Initial state value for aggregate function.
/// If not specified, it will be NULL.
init_state: Option<String>,
/// Short-circuit condition for aggregate function.
/// If not specified, there won't be any short-circuiting.
shortcurcuit_if: Option<String>,
/// Prebuild function for arguments.
/// This could be any Rust expression.
prebuild: Option<String>,
Expand Down
2 changes: 2 additions & 0 deletions src/expr/macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ impl Parse for FunctionAttr {
parsed.state = Some(get_value()?);
} else if meta.path().is_ident("init_state") {
parsed.init_state = Some(get_value()?);
} else if meta.path().is_ident("shortcurcuit_if") {
parsed.shortcurcuit_if = Some(get_value()?);
} else if meta.path().is_ident("prebuild") {
parsed.prebuild = Some(get_value()?);
} else if meta.path().is_ident("type_infer") {
Expand Down
Loading