diff --git a/src/expr/impl/src/aggregate/first_last_value.rs b/src/expr/impl/src/aggregate/first_last_value.rs new file mode 100644 index 0000000000000..841442148f722 --- /dev/null +++ b/src/expr/impl/src/aggregate/first_last_value.rs @@ -0,0 +1,88 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::types::{Datum, ScalarRefImpl}; +use risingwave_common_estimate_size::EstimateSize; +use risingwave_expr::aggregate; +use risingwave_expr::aggregate::AggStateDyn; + +/// Note that different from `min` and `max`, `first_value` doesn't ignore `NULL` values. +/// +/// ```slt +/// statement ok +/// create table t(v1 int, ts int); +/// +/// statement ok +/// insert into t values (null, 1), (2, 2), (null, 3); +/// +/// query I +/// select first_value(v1 order by ts) from t; +/// ---- +/// NULL +/// +/// statement ok +/// drop table t; +/// ``` +#[aggregate("first_value(any) -> any")] +fn first_value(state: &mut FirstValueState, input: Option>) { + if state.0.is_none() { + state.0 = Some(input.map(|x| x.into_scalar_impl())); + } +} + +#[derive(Debug, Clone, Default, EstimateSize)] +struct FirstValueState(Option); + +impl AggStateDyn for FirstValueState {} + +impl From<&FirstValueState> for Datum { + fn from(state: &FirstValueState) -> Self { + if let Some(state) = &state.0 { + state.clone() + } else { + None + } + } +} + +/// Note that different from `min` and `max`, `last_value` doesn't ignore `NULL` values. +/// +/// ```slt +/// statement ok +/// create table t(v1 int, ts int); +/// +/// statement ok +/// insert into t values (null, 1), (2, 2), (null, 3); +/// +/// query I +/// select last_value(v1 order by ts) from t; +/// ---- +/// NULL +/// +/// statement ok +/// drop table t; +/// ``` +#[aggregate("last_value(*) -> auto", state = "ref")] // TODO(rc): `last_value(any) -> any` +fn last_value(_: Option, input: Option) -> Option { + input +} + +#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)] +fn internal_last_seen_value(state: T, input: T, retract: bool) -> T { + if retract { + state + } else { + input + } +} diff --git a/src/expr/impl/src/aggregate/general.rs b/src/expr/impl/src/aggregate/general.rs index 0c94312335b4b..daaea5e782fd1 100644 --- a/src/expr/impl/src/aggregate/general.rs +++ b/src/expr/impl/src/aggregate/general.rs @@ -124,25 +124,6 @@ fn max(state: T, input: T) -> T { state.max(input) } -#[aggregate("first_value(*) -> auto", state = "ref")] -fn first_value(state: T, _: T) -> T { - state -} - -#[aggregate("last_value(*) -> auto", state = "ref")] -fn last_value(_: T, input: T) -> T { - input -} - -#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)] -fn internal_last_seen_value(state: T, input: T, retract: bool) -> T { - if retract { - state - } else { - input - } -} - /// Note the following corner cases: /// /// ```slt diff --git a/src/expr/impl/src/aggregate/mod.rs b/src/expr/impl/src/aggregate/mod.rs index 349574018fedf..881465b4cf82f 100644 --- a/src/expr/impl/src/aggregate/mod.rs +++ b/src/expr/impl/src/aggregate/mod.rs @@ -20,6 +20,7 @@ mod bit_or; mod bit_xor; mod bool_and; mod bool_or; +mod first_last_value; mod general; mod jsonb_agg; mod mode; diff --git a/src/tests/sqlsmith/src/sql_gen/scalar.rs b/src/tests/sqlsmith/src/sql_gen/scalar.rs index 62cd7218dcc90..a532f6138c596 100644 --- a/src/tests/sqlsmith/src/sql_gen/scalar.rs +++ b/src/tests/sqlsmith/src/sql_gen/scalar.rs @@ -81,11 +81,15 @@ impl SqlGenerator<'_, R> { data_type: AstDataType::SmallInt, value: self.gen_int(i16::MIN as isize, i16::MAX as isize), })), - T::Varchar => Expr::Value(Value::SingleQuotedString( - (0..10) - .map(|_| self.rng.sample(Alphanumeric) as char) - .collect(), - )), + T::Varchar => Expr::Cast { + // since we are generating random scalar literal, we should cast it to avoid unknown type + expr: Box::new(Expr::Value(Value::SingleQuotedString( + (0..10) + .map(|_| self.rng.sample(Alphanumeric) as char) + .collect(), + ))), + data_type: AstDataType::Varchar, + }, T::Decimal => Expr::Nested(Box::new(Expr::Value(Value::Number(self.gen_float())))), T::Float64 => Expr::Nested(Box::new(Expr::TypedString { data_type: AstDataType::Float(None),