diff --git a/e2e_test/batch/functions/array_sum.slt.part b/e2e_test/batch/functions/array_sum.slt.part index 7fc81e2105b99..11c126ef6a0d1 100644 --- a/e2e_test/batch/functions/array_sum.slt.part +++ b/e2e_test/batch/functions/array_sum.slt.part @@ -3,44 +3,44 @@ select array_sum(array[1, 2, 3]); ---- 6 -# Testing for int16 with positive numbers +# Testing for SMALLINT with positive numbers query I -select array_sum(array[10, 20, 30]); +select array_sum(array[10, 20, 30]::smallint[]); ---- 60 -# Testing for int16 with a mix of positive and negative numbers +# Testing for SMALLINT with a mix of positive and negative numbers query I -select array_sum(array[-10, 20, -30]); +select array_sum(array[-10, 20, -30]::smallint[]); ---- -20 -# Testing for int16 with all zeros +# Testing for SMALLINT with all zeros query I -select array_sum(array[0, 0, 0]); +select array_sum(array[0, 0, 0]::smallint[]); ---- 0 -# Testing for int32 with larger positive numbers +# Testing for INT with larger positive numbers query I -select array_sum(array[1000, 2000, 3000]); +select array_sum(array[1000, 2000, 3000]::int[]); ---- 6000 -# Testing for int32 with a mix of larger positive and negative numbers +# Testing for INT with a mix of larger positive and negative numbers query I -select array_sum(array[-1000, 2000, -3000]); +select array_sum(array[-1000, 2000, -3000]::int[]); ---- -2000 -# Testing for int64 with much larger numbers +# Testing for BIGINT with much larger numbers query I -select array_sum(array[1000000000, 2000000000, 3000000000]); +select array_sum(array[1000000000, 2000000000, 3000000000]::bigint[]); ---- 6000000000 -# Testing for int64 with a mix of much larger positive and negative numbers +# Testing for BIGINT with a mix of much larger positive and negative numbers query I -select array_sum(array[-1000000000, 2000000000, -3000000000]); +select array_sum(array[-1000000000, 2000000000, -3000000000]::bigint[]); ---- --2000000000 \ No newline at end of file +-2000000000 diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 2c4a8cf042548..6445ac8a156d3 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -606,6 +606,12 @@ impl ToText for ListRef<'_> { } } +impl<'a> From<&'a ListValue> for ListRef<'a> { + fn from(val: &'a ListValue) -> Self { + ListRef::ValueRef { val } + } +} + #[cfg(test)] mod tests { use more_asserts::{assert_gt, assert_lt}; diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 06b3f3d0bfa49..83d281c5238e6 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -352,6 +352,14 @@ impl DataType { DataTypeName::from(self).is_scalar() } + pub fn is_array(&self) -> bool { + matches!(self, DataType::List(_)) + } + + pub fn is_struct(&self) -> bool { + matches!(self, DataType::Struct(_)) + } + pub fn is_int(&self) -> bool { matches!(self, DataType::Int16 | DataType::Int32 | DataType::Int64) } @@ -950,7 +958,21 @@ impl ScalarImpl { }; Ok(res) } +} + +impl From> for ScalarImpl { + fn from(scalar_ref: ScalarRefImpl<'_>) -> Self { + scalar_ref.into_scalar_impl() + } +} + +impl<'a> From<&'a ScalarImpl> for ScalarRefImpl<'a> { + fn from(scalar: &'a ScalarImpl) -> Self { + scalar.as_scalar_ref_impl() + } +} +impl ScalarImpl { /// A lite version of casting from string to target type. Used by frontend to handle types that have /// to be created by casting. /// diff --git a/src/expr/core/src/aggregate/mod.rs b/src/expr/core/src/aggregate/mod.rs index d9276f264dfd6..74e3afdb0904c 100644 --- a/src/expr/core/src/aggregate/mod.rs +++ b/src/expr/core/src/aggregate/mod.rs @@ -16,11 +16,11 @@ use std::fmt::Debug; use std::ops::Range; use downcast_rs::{impl_downcast, Downcast}; +use itertools::Itertools; use risingwave_common::array::StreamChunk; use risingwave_common::estimate_size::EstimateSize; -use risingwave_common::types::{DataType, DataTypeName, Datum}; +use risingwave_common::types::{DataType, Datum}; -use crate::sig::FuncSigDebug; use crate::{ExprError, Result}; // aggregate definition @@ -136,25 +136,21 @@ pub fn build_retractable(agg: &AggCall) -> Result { /// NOTE: This function ignores argument indices, `column_orders`, `filter` and `distinct` in /// `AggCall`. Such operations should be done in batch or streaming executors. pub fn build(agg: &AggCall, append_only: bool) -> Result { - let args = (agg.args.arg_types().iter()) - .map(|t| t.into()) - .collect::>(); - let ret_type = (&agg.return_type).into(); - let desc = crate::sig::agg::AGG_FUNC_SIG_MAP - .get(agg.kind, &args, ret_type, append_only) + let desc = crate::sig::FUNCTION_REGISTRY + .get_aggregate( + agg.kind, + agg.args.arg_types(), + &agg.return_type, + append_only, + ) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( - "{:?}", - FuncSigDebug { - func: agg.kind, - inputs_type: &args, - ret_type, - set_returning: false, - deprecated: false, - append_only, - } + "{}({}) -> {}", + agg.kind.to_protobuf().as_str_name().to_ascii_lowercase(), + agg.args.arg_types().iter().format(", "), + agg.return_type, )) })?; - (desc.build)(agg) + desc.build_aggregate(agg) } diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index f6f08689aa5f2..1ea03bd36f42a 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -29,8 +29,7 @@ use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; use super::wrapper::{Checked, EvalErrorReport, NonStrict}; use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression}; -use crate::sig::func::FUNC_SIG_MAP; -use crate::sig::FuncSigDebug; +use crate::sig::FUNCTION_REGISTRY; use crate::{bail, ExprError, Result}; /// Build an expression from protobuf. @@ -195,27 +194,18 @@ pub fn build_func( return Ok(ArrayTransformExpression { array, lambda }.boxed()); } - let args = children - .iter() - .map(|c| c.return_type().into()) - .collect_vec(); - let desc = FUNC_SIG_MAP - .get(func, &args, (&ret_type).into()) + let args = children.iter().map(|c| c.return_type()).collect_vec(); + let desc = FUNCTION_REGISTRY + .get(func, &args, &ret_type) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( - "{:?}", - FuncSigDebug { - func: func.as_str_name(), - inputs_type: &args, - ret_type: (&ret_type).into(), - set_returning: false, - deprecated: false, - append_only: false, - } + "{}({}) -> {}", + func.as_str_name().to_ascii_lowercase(), + args.iter().format(", "), + ret_type, )) })?; - - (desc.build)(ret_type, children) + desc.build_scalar(ret_type, children) } /// Build an expression in `FuncCall` variant in non-strict mode. diff --git a/src/expr/core/src/sig/agg.rs b/src/expr/core/src/sig/agg.rs deleted file mode 100644 index 2df599c61464d..0000000000000 --- a/src/expr/core/src/sig/agg.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::collections::HashMap; -use std::fmt; -use std::sync::LazyLock; - -use risingwave_common::types::DataTypeName; - -use super::FuncSigDebug; -use crate::aggregate::{AggCall, AggKind, BoxedAggregateFunction}; -use crate::Result; - -pub static AGG_FUNC_SIG_MAP: LazyLock = LazyLock::new(|| unsafe { - let mut map = AggFuncSigMap::default(); - tracing::info!("{} aggregations loaded.", AGG_FUNC_SIG_MAP_INIT.len()); - for desc in AGG_FUNC_SIG_MAP_INIT.drain(..) { - map.insert(desc); - } - map -}); - -// Same as FuncSign in func.rs except this is for aggregate function -#[derive(PartialEq, Eq, Hash, Clone)] -pub struct AggFuncSig { - pub func: AggKind, - pub inputs_type: &'static [DataTypeName], - pub state_type: DataTypeName, - pub ret_type: DataTypeName, - pub build: fn(agg: &AggCall) -> Result, - pub append_only: bool, -} - -impl fmt::Debug for AggFuncSig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - FuncSigDebug { - func: self.func, - inputs_type: self.inputs_type, - ret_type: self.ret_type, - set_returning: false, - deprecated: false, - append_only: self.append_only, - } - .fmt(f) - } -} - -// Same as FuncSigMap in func.rs except this is for aggregate function -#[derive(Default)] -pub struct AggFuncSigMap(HashMap<(AggKind, usize), Vec>); - -impl AggFuncSigMap { - /// Inserts a function signature into the map. - fn insert(&mut self, sig: AggFuncSig) { - let arity = sig.inputs_type.len(); - self.0.entry((sig.func, arity)).or_default().push(sig); - } - - /// Returns a function signature with the given type, argument types, return type. - /// - /// The `append_only` flag only works when both append-only and retractable version exist. - /// Otherwise, return the signature of the only version. - pub fn get( - &self, - ty: AggKind, - args: &[DataTypeName], - ret: DataTypeName, - append_only: bool, - ) -> Option<&AggFuncSig> { - let v = self.0.get(&(ty, args.len()))?; - let mut iter = v - .iter() - .filter(|d| d.inputs_type == args && d.ret_type == ret); - if iter.clone().count() == 2 { - iter.find(|d| d.append_only == append_only) - } else { - iter.next() - } - } - - /// Returns the return type for the given function and arguments. - pub fn get_return_type(&self, ty: AggKind, args: &[DataTypeName]) -> Option { - let v = self.0.get(&(ty, args.len()))?; - v.iter().find(|d| d.inputs_type == args).map(|d| d.ret_type) - } -} - -/// The table of function signatures. -pub fn agg_func_sigs() -> impl Iterator { - AGG_FUNC_SIG_MAP.0.values().flatten() -} - -/// Register a function into global registry. -/// -/// # Safety -/// -/// This function must be called sequentially. -/// -/// It is designed to be used by `#[aggregate]` macro. -/// Users SHOULD NOT call this function. -#[doc(hidden)] -pub unsafe fn _register(desc: AggFuncSig) { - AGG_FUNC_SIG_MAP_INIT.push(desc); -} - -/// The global registry of function signatures on initialization. -/// -/// `#[aggregate]` macro will generate a `#[ctor]` function to register the signature into this -/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into -/// `AGG_FUNC_SIG_MAP` on the first access of `AGG_FUNC_SIG_MAP`. -static mut AGG_FUNC_SIG_MAP_INIT: Vec = Vec::new(); diff --git a/src/expr/core/src/sig/func.rs b/src/expr/core/src/sig/func.rs deleted file mode 100644 index f6557686c8ca9..0000000000000 --- a/src/expr/core/src/sig/func.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Function signatures. - -use std::collections::HashMap; -use std::fmt; -use std::sync::LazyLock; - -use risingwave_common::types::{DataType, DataTypeName}; -use risingwave_pb::expr::expr_node::PbType; - -use super::FuncSigDebug; -use crate::error::Result; -use crate::expr::BoxedExpression; - -pub static FUNC_SIG_MAP: LazyLock = LazyLock::new(|| unsafe { - let mut map = FuncSigMap::default(); - tracing::info!("{} function signatures loaded.", FUNC_SIG_MAP_INIT.len()); - for desc in FUNC_SIG_MAP_INIT.drain(..) { - map.insert(desc); - } - map -}); - -/// The table of function signatures. -pub fn func_sigs() -> impl Iterator { - FUNC_SIG_MAP.0.values().flatten() -} - -#[derive(Default, Clone, Debug)] -pub struct FuncSigMap(HashMap>); - -impl FuncSigMap { - /// Inserts a function signature. - pub fn insert(&mut self, desc: FuncSign) { - self.0.entry(desc.func).or_default().push(desc) - } - - /// Returns a function signature with the same type, argument types and return type. - /// Deprecated functions are included. - pub fn get(&self, ty: PbType, args: &[DataTypeName], ret: DataTypeName) -> Option<&FuncSign> { - let v = self.0.get(&ty)?; - v.iter() - .find(|d| (d.variadic || d.inputs_type == args) && d.ret_type == ret) - } - - /// Returns all function signatures with the same type and number of arguments. - /// Deprecated functions are excluded. - pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> Vec<&FuncSign> { - match self.0.get(&ty) { - Some(v) => v - .iter() - .filter(|d| (d.variadic || d.inputs_type.len() == nargs) && !d.deprecated) - .collect(), - None => vec![], - } - } -} - -/// A function signature. -#[derive(Clone)] -pub struct FuncSign { - pub func: PbType, - pub inputs_type: &'static [DataTypeName], - pub variadic: bool, - pub ret_type: DataTypeName, - pub build: fn(return_type: DataType, children: Vec) -> Result, - /// Whether the function is deprecated and should not be used in the frontend. - /// For backward compatibility, it is still available in the backend. - pub deprecated: bool, -} - -impl fmt::Debug for FuncSign { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - FuncSigDebug { - func: self.func.as_str_name(), - inputs_type: self.inputs_type, - ret_type: self.ret_type, - set_returning: false, - deprecated: self.deprecated, - append_only: false, - } - .fmt(f) - } -} - -/// Register a function into global registry. -/// -/// # Safety -/// -/// This function must be called sequentially. -/// -/// It is designed to be used by `#[function]` macro. -/// Users SHOULD NOT call this function. -#[doc(hidden)] -pub unsafe fn _register(desc: FuncSign) { - FUNC_SIG_MAP_INIT.push(desc) -} - -/// The global registry of function signatures on initialization. -/// -/// `#[function]` macro will generate a `#[ctor]` function to register the signature into this -/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into -/// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`. -static mut FUNC_SIG_MAP_INIT: Vec = Vec::new(); diff --git a/src/expr/core/src/sig/mod.rs b/src/expr/core/src/sig/mod.rs index 052b31ad060bc..c2e71b585d49c 100644 --- a/src/expr/core/src/sig/mod.rs +++ b/src/expr/core/src/sig/mod.rs @@ -14,36 +14,174 @@ //! Metadata of expressions. +use std::collections::HashMap; +use std::fmt; +use std::sync::LazyLock; + use itertools::Itertools; -use risingwave_common::types::DataTypeName; +use risingwave_common::types::DataType; +use risingwave_pb::expr::expr_node::PbType as ScalarFunctionType; +use risingwave_pb::expr::table_function::PbType as TableFunctionType; + +use crate::aggregate::{AggCall, AggKind as AggregateFunctionType, BoxedAggregateFunction}; +use crate::error::Result; +use crate::expr::BoxedExpression; +use crate::table_function::BoxedTableFunction; +use crate::ExprError; -pub mod agg; pub mod cast; -pub mod func; -pub mod table_function; - -/// Utility struct for debug printing of function signature. -pub(crate) struct FuncSigDebug<'a, T> { - pub func: T, - pub inputs_type: &'a [DataTypeName], - pub ret_type: DataTypeName, - pub set_returning: bool, + +/// The global registry of all function signatures. +pub static FUNCTION_REGISTRY: LazyLock = LazyLock::new(|| unsafe { + // SAFETY: this function is called after all `#[ctor]` functions are called. + let mut map = FunctionRegistry::default(); + tracing::info!("found {} functions", FUNCTION_REGISTRY_INIT.len()); + for sig in FUNCTION_REGISTRY_INIT.drain(..) { + map.insert(sig); + } + map +}); + +/// A set of function signatures. +#[derive(Default, Clone, Debug)] +pub struct FunctionRegistry(HashMap>); + +impl FunctionRegistry { + /// Inserts a function signature. + pub fn insert(&mut self, sig: FuncSign) { + self.0.entry(sig.name).or_default().push(sig) + } + + /// Returns a function signature with the same type, argument types and return type. + /// Deprecated functions are included. + pub fn get( + &self, + name: impl Into, + args: &[DataType], + ret: &DataType, + ) -> Option<&FuncSign> { + let v = self.0.get(&name.into())?; + v.iter().find(|d| d.match_args_ret(args, ret)) + } + + /// Returns all function signatures with the same type and number of arguments. + /// Deprecated functions are excluded. + pub fn get_with_arg_nums(&self, name: impl Into, nargs: usize) -> Vec<&FuncSign> { + match self.0.get(&name.into()) { + Some(v) => v + .iter() + .filter(|d| d.match_number_of_args(nargs) && !d.deprecated) + .collect(), + None => vec![], + } + } + + /// Returns a function signature with the given type, argument types, return type. + /// + /// The `prefer_append_only` flag only works when both append-only and retractable version exist. + /// Otherwise, return the signature of the only version. + pub fn get_aggregate( + &self, + ty: AggregateFunctionType, + args: &[DataType], + ret: &DataType, + prefer_append_only: bool, + ) -> Option<&FuncSign> { + let v = self.0.get(&ty.into())?; + let mut iter = v.iter().filter(|d| d.match_args_ret(args, ret)); + if iter.clone().count() == 2 { + iter.find(|d| d.append_only == prefer_append_only) + } else { + iter.next() + } + } + + /// Returns the return type for the given function and arguments. + pub fn get_return_type( + &self, + name: impl Into, + args: &[DataType], + ) -> Result { + let name = name.into(); + let v = self + .0 + .get(&name) + .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?; + let sig = v + .iter() + .find(|d| d.match_args(args)) + .ok_or_else(|| ExprError::UnsupportedFunction(name.to_string()))?; + (sig.type_infer)(args) + } + + /// Returns an iterator of all function signatures. + pub fn iter(&self) -> impl Iterator { + self.0.values().flatten() + } + + /// Returns an iterator of all scalar functions. + pub fn iter_scalars(&self) -> impl Iterator { + self.iter().filter(|d| d.is_scalar()) + } + + /// Returns an iterator of all aggregate functions. + pub fn iter_aggregates(&self) -> impl Iterator { + self.iter().filter(|d| d.is_aggregate()) + } +} + +/// A function signature. +#[derive(Clone)] +pub struct FuncSign { + /// The name of the function. + pub name: FuncName, + + /// The argument types. + pub inputs_type: Vec, + + /// Whether the function is variadic. + pub variadic: bool, + + /// The return type. + pub ret_type: SigDataType, + + /// A function to build the expression. + pub build: FuncBuilder, + + /// A function to infer the return type from argument types. + pub type_infer: fn(args: &[DataType]) -> Result, + + /// Whether the function is deprecated and should not be used in the frontend. + /// For backward compatibility, it is still available in the backend. pub deprecated: bool, + + /// The state type of the aggregate function. + /// `None` means equal to the return type. + pub state_type: Option, + + /// Whether the aggregate function is append-only. pub append_only: bool, } -impl<'a, T: std::fmt::Display> std::fmt::Debug for FuncSigDebug<'a, T> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let s = format!( - "{}({:?}) -> {}{:?}", - self.func, +impl fmt::Debug for FuncSign { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}({}{}) -> {}{}", + self.name.as_str_name().to_ascii_lowercase(), self.inputs_type.iter().format(", "), - if self.set_returning { "setof " } else { "" }, + if self.variadic { + if self.inputs_type.is_empty() { + "..." + } else { + ", ..." + } + } else { + "" + }, + if self.name.is_table() { "setof " } else { "" }, self.ret_type, - ) - .to_ascii_lowercase(); - - f.write_str(&s)?; + )?; if self.append_only { write!(f, " [append-only]")?; } @@ -53,3 +191,229 @@ impl<'a, T: std::fmt::Display> std::fmt::Debug for FuncSigDebug<'a, T> { Ok(()) } } + +impl FuncSign { + /// Returns true if the argument types match the function signature. + pub fn match_args(&self, args: &[DataType]) -> bool { + if !self.match_number_of_args(args.len()) { + return false; + } + // allow `zip` as the length of `args` may be larger than `inputs_type` + #[allow(clippy::disallowed_methods)] + self.inputs_type + .iter() + .zip(args.iter()) + .all(|(matcher, arg)| matcher.matches(arg)) + } + + /// Returns true if the argument types match the function signature. + fn match_args_ret(&self, args: &[DataType], ret: &DataType) -> bool { + self.match_args(args) && self.ret_type.matches(ret) + } + + /// Returns true if the number of arguments matches the function signature. + fn match_number_of_args(&self, n: usize) -> bool { + if self.variadic { + n >= self.inputs_type.len() + } else { + n == self.inputs_type.len() + } + } + + /// Returns true if the function is a scalar function. + pub const fn is_scalar(&self) -> bool { + matches!(self.name, FuncName::Scalar(_)) + } + + /// Returns true if the function is a table function. + pub const fn is_table_function(&self) -> bool { + matches!(self.name, FuncName::Table(_)) + } + + /// Returns true if the function is a aggregate function. + pub const fn is_aggregate(&self) -> bool { + matches!(self.name, FuncName::Aggregate(_)) + } + + /// Builds the scalar function. + pub fn build_scalar( + &self, + return_type: DataType, + children: Vec, + ) -> Result { + match self.build { + FuncBuilder::Scalar(f) => f(return_type, children), + _ => panic!("Expected a scalar function"), + } + } + + /// Builds the table function. + pub fn build_table( + &self, + return_type: DataType, + chunk_size: usize, + children: Vec, + ) -> Result { + match self.build { + FuncBuilder::Table(f) => f(return_type, chunk_size, children), + _ => panic!("Expected a table function"), + } + } + + /// Builds the aggregate function. + pub fn build_aggregate(&self, agg: &AggCall) -> Result { + match self.build { + FuncBuilder::Aggregate(f) => f(agg), + _ => panic!("Expected an aggregate function"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum FuncName { + Scalar(ScalarFunctionType), + Table(TableFunctionType), + Aggregate(AggregateFunctionType), +} + +impl From for FuncName { + fn from(ty: ScalarFunctionType) -> Self { + Self::Scalar(ty) + } +} + +impl From for FuncName { + fn from(ty: TableFunctionType) -> Self { + Self::Table(ty) + } +} + +impl From for FuncName { + fn from(ty: AggregateFunctionType) -> Self { + Self::Aggregate(ty) + } +} + +impl fmt::Display for FuncName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str_name().to_ascii_lowercase()) + } +} + +impl FuncName { + /// Returns the name of the function in `UPPER_CASE` style. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Scalar(ty) => ty.as_str_name(), + Self::Table(ty) => ty.as_str_name(), + Self::Aggregate(ty) => ty.to_protobuf().as_str_name(), + } + } + + /// Returns true if the function is a table function. + const fn is_table(&self) -> bool { + matches!(self, Self::Table(_)) + } + + pub fn as_scalar(&self) -> ScalarFunctionType { + match self { + Self::Scalar(ty) => *ty, + _ => panic!("Expected a scalar function"), + } + } + + pub fn as_aggregate(&self) -> AggregateFunctionType { + match self { + Self::Aggregate(ty) => *ty, + _ => panic!("Expected an aggregate function"), + } + } +} + +/// An extended data type that can be used to declare a function's argument or result type. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SigDataType { + /// Exact data type + Exact(DataType), + /// Accepts any data type + Any, + /// Accepts any array data type + AnyArray, + /// Accepts any struct data type + AnyStruct, +} + +impl From for SigDataType { + fn from(dt: DataType) -> Self { + SigDataType::Exact(dt) + } +} + +impl std::fmt::Display for SigDataType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Exact(dt) => write!(f, "{}", dt), + Self::Any => write!(f, "any"), + Self::AnyArray => write!(f, "anyarray"), + Self::AnyStruct => write!(f, "anystruct"), + } + } +} + +impl SigDataType { + /// Returns true if the data type matches. + pub fn matches(&self, dt: &DataType) -> bool { + match self { + Self::Exact(ty) => ty == dt, + Self::Any => true, + Self::AnyArray => dt.is_array(), + Self::AnyStruct => dt.is_struct(), + } + } + + /// Returns the exact data type. + pub fn as_exact(&self) -> &DataType { + match self { + Self::Exact(ty) => ty, + t => panic!("expected data type, but got: {t}"), + } + } + + /// Returns true if the data type is exact. + pub fn is_exact(&self) -> bool { + matches!(self, Self::Exact(_)) + } +} + +#[derive(Clone, Copy)] +pub enum FuncBuilder { + Scalar(fn(return_type: DataType, children: Vec) -> Result), + Table( + fn( + return_type: DataType, + chunk_size: usize, + children: Vec, + ) -> Result, + ), + Aggregate(fn(agg: &AggCall) -> Result), +} + +/// Register a function into global registry. +/// +/// # Safety +/// +/// This function must be called sequentially. +/// +/// It is designed to be used by `#[function]` macro. +/// Users SHOULD NOT call this function. +#[doc(hidden)] +pub unsafe fn _register(sig: FuncSign) { + FUNCTION_REGISTRY_INIT.push(sig) +} + +/// The global registry of function signatures on initialization. +/// +/// `#[function]` macro will generate a `#[ctor]` function to register the signature into this +/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into +/// `FUNCTION_REGISTRY` on the first access of `FUNCTION_REGISTRY`. +static mut FUNCTION_REGISTRY_INIT: Vec = Vec::new(); diff --git a/src/expr/core/src/sig/table_function.rs b/src/expr/core/src/sig/table_function.rs deleted file mode 100644 index f4b2ac6de1e4b..0000000000000 --- a/src/expr/core/src/sig/table_function.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Function signatures. - -use std::collections::HashMap; -use std::fmt; -use std::ops::Deref; -use std::sync::LazyLock; - -use risingwave_common::types::{DataType, DataTypeName}; -use risingwave_pb::expr::table_function::PbType; - -use super::FuncSigDebug; -use crate::error::Result; -use crate::expr::BoxedExpression; -use crate::table_function::BoxedTableFunction; - -pub static FUNC_SIG_MAP: LazyLock = LazyLock::new(|| unsafe { - let mut map = FuncSigMap::default(); - tracing::info!( - "{} table function signatures loaded.", - FUNC_SIG_MAP_INIT.len() - ); - for desc in FUNC_SIG_MAP_INIT.drain(..) { - map.insert(desc); - } - map -}); - -/// The table of function signatures. -pub fn func_sigs() -> impl Iterator { - FUNC_SIG_MAP.0.values().flatten() -} - -#[derive(Default, Clone, Debug)] -pub struct FuncSigMap(HashMap<(PbType, usize), Vec>); - -impl FuncSigMap { - /// Inserts a function signature. - pub fn insert(&mut self, desc: FuncSign) { - self.0 - .entry((desc.func, desc.inputs_type.len())) - .or_default() - .push(desc) - } - - /// Returns a function signature with the same type and argument types. - pub fn get(&self, ty: PbType, args: &[DataTypeName]) -> Option<&FuncSign> { - let v = self.0.get(&(ty, args.len()))?; - v.iter().find(|d| d.inputs_type == args) - } - - /// Returns all function signatures with the same type and number of arguments. - pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> &[FuncSign] { - self.0.get(&(ty, nargs)).map_or(&[], Deref::deref) - } -} - -/// A function signature. -#[derive(Clone)] -pub struct FuncSign { - pub func: PbType, - pub inputs_type: &'static [DataTypeName], - pub ret_type: DataTypeName, - pub build: fn( - return_type: DataType, - chunk_size: usize, - children: Vec, - ) -> Result, - /// A function to infer the return type from argument types. - pub type_infer: fn(args: &[DataType]) -> Result, -} - -impl fmt::Debug for FuncSign { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - FuncSigDebug { - func: self.func.as_str_name(), - inputs_type: self.inputs_type, - ret_type: self.ret_type, - set_returning: true, - deprecated: false, - append_only: false, - } - .fmt(f) - } -} - -/// Register a function into global registry. -/// -/// # Safety -/// -/// This function must be called sequentially. -/// -/// It is designed to be used by `#[table_function]` macro. -/// Users SHOULD NOT call this function. -#[doc(hidden)] -pub unsafe fn _register(desc: FuncSign) { - FUNC_SIG_MAP_INIT.push(desc) -} - -/// The global registry of function signatures on initialization. -/// -/// `#[table_function]` macro will generate a `#[ctor]` function to register the signature into this -/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into -/// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`. -static mut FUNC_SIG_MAP_INIT: Vec = Vec::new(); diff --git a/src/expr/core/src/table_function/mod.rs b/src/expr/core/src/table_function/mod.rs index daba4e3c5c987..2a3028a72e3b5 100644 --- a/src/expr/core/src/table_function/mod.rs +++ b/src/expr/core/src/table_function/mod.rs @@ -18,14 +18,13 @@ use futures_util::stream::BoxStream; use futures_util::StreamExt; use itertools::Itertools; use risingwave_common::array::{Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk}; -use risingwave_common::types::{DataType, DataTypeName, DatumRef}; +use risingwave_common::types::{DataType, DatumRef}; use risingwave_pb::expr::project_set_select_item::SelectItem; use risingwave_pb::expr::table_function::PbType; use risingwave_pb::expr::{PbProjectSetSelectItem, PbTableFunction}; use super::{ExprError, Result}; use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression}; -use crate::sig::FuncSigDebug; mod empty; mod repeat; @@ -130,26 +129,18 @@ pub fn build( chunk_size: usize, children: Vec, ) -> Result { - let args = children - .iter() - .map(|t| t.return_type().into()) - .collect::>(); - let desc = crate::sig::table_function::FUNC_SIG_MAP - .get(func, &args) + let args = children.iter().map(|t| t.return_type()).collect_vec(); + let desc = crate::sig::FUNCTION_REGISTRY + .get(func, &args, &return_type) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( - "{:?}", - FuncSigDebug { - func: func.as_str_name(), - inputs_type: &args, - ret_type: (&return_type).into(), - set_returning: true, - deprecated: false, - append_only: false, - } + "{}({}) -> setof {}", + func.as_str_name().to_ascii_lowercase(), + args.iter().format(", "), + return_type, )) })?; - (desc.build)(return_type, chunk_size, children) + desc.build_table(return_type, chunk_size, children) } /// See also [`PbProjectSetSelectItem`] diff --git a/src/expr/core/src/window_function/state/mod.rs b/src/expr/core/src/window_function/state/mod.rs index 977a04b2a7a70..971fb97f66cdc 100644 --- a/src/expr/core/src/window_function/state/mod.rs +++ b/src/expr/core/src/window_function/state/mod.rs @@ -14,6 +14,7 @@ use std::collections::BTreeSet; +use itertools::Itertools; use risingwave_common::estimate_size::EstimateSize; use risingwave_common::row::OwnedRow; use risingwave_common::types::{Datum, DefaultOrdered}; @@ -21,7 +22,6 @@ use risingwave_common::util::memcmp_encoding::MemcmpEncoded; use smallvec::SmallVec; use super::{WindowFuncCall, WindowFuncKind}; -use crate::sig::FuncSigDebug; use crate::{ExprError, Result}; mod buffer; @@ -116,19 +116,11 @@ pub fn create_window_state(call: &WindowFuncCall) -> Result Box::new(row_number::RowNumberState::new(call)), Aggregate(_) => Box::new(aggregate::AggregateState::new(call)?), kind => { - let args = (call.args.arg_types().iter()) - .map(|t| t.into()) - .collect::>(); return Err(ExprError::UnsupportedFunction(format!( - "{:?}", - FuncSigDebug { - func: kind, - inputs_type: &args, - ret_type: call.return_type.clone().into(), - set_returning: false, - deprecated: false, - append_only: false, - } + "{}({}) -> {}", + kind, + call.args.arg_types().iter().format(", "), + &call.return_type, ))); } }) diff --git a/src/expr/impl/benches/expr.rs b/src/expr/impl/benches/expr.rs index 76a00b0974c49..1e84d8d8e4825 100644 --- a/src/expr/impl/benches/expr.rs +++ b/src/expr/impl/benches/expr.rs @@ -28,8 +28,7 @@ use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::*; use risingwave_expr::aggregate::{build_append_only, AggArgs, AggCall, AggKind}; use risingwave_expr::expr::*; -use risingwave_expr::sig::agg::agg_func_sigs; -use risingwave_expr::sig::func::func_sigs; +use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::ExprError; use risingwave_pb::expr::expr_node::PbType; @@ -229,10 +228,10 @@ fn bench_expr(c: &mut Criterion) { InputRefExpression::new(DataType::Jsonb, 26), InputRefExpression::new(DataType::Int256, 27), ]; - let input_index_for_type = |ty: DataType| { + let input_index_for_type = |ty: &DataType| { inputrefs .iter() - .find(|r| r.return_type() == ty) + .find(|r| &r.return_type() == ty) .unwrap_or_else(|| panic!("expression not found for {ty:?}")) .index() }; @@ -267,19 +266,20 @@ fn bench_expr(c: &mut Criterion) { c.bench_function("extract(constant)", |bencher| { let extract = build_from_pretty(format!( "(extract:decimal HOUR:varchar ${}:timestamp)", - input_index_for_type(DataType::Timestamp) + input_index_for_type(&DataType::Timestamp) )); bencher .to_async(FuturesExecutor) .iter(|| extract.eval(&input)) }); - let sigs = func_sigs(); - let sigs = sigs.sorted_by_cached_key(|sig| format!("{sig:?}")); + let sigs = FUNCTION_REGISTRY + .iter_scalars() + .sorted_by_cached_key(|sig| format!("{sig:?}")); 'sig: for sig in sigs { if (sig.inputs_type.iter()) - .chain(&[sig.ret_type]) - .any(|t| matches!(t, DataTypeName::Struct | DataTypeName::List)) + .chain([&sig.ret_type]) + .any(|t| !t.is_exact()) { // TODO: support struct and list println!("todo: {sig:?}"); @@ -302,8 +302,8 @@ fn bench_expr(c: &mut Criterion) { let mut children = vec![]; for (i, t) in sig.inputs_type.iter().enumerate() { - use DataTypeName::*; - let idx = match (sig.func, i) { + use DataType::*; + let idx = match (sig.name.as_scalar(), i) { (PbType::ToTimestamp1, 0) => TIMESTAMP_FORMATTED_STRING, (PbType::ToChar | PbType::ToTimestamp1, 1) => { children.push(string_literal("YYYY/MM/DD HH:MM:SS")); @@ -317,7 +317,7 @@ fn bench_expr(c: &mut Criterion) { children.push(string_literal("VALUE")); continue; } - (PbType::Cast, 0) if *t == DataTypeName::Varchar => match sig.ret_type { + (PbType::Cast, 0) if t.as_exact() == &Varchar => match sig.ret_type.as_exact() { Boolean => BOOL_STRING, Int16 | Int32 | Int64 | Float32 | Float64 | Decimal => NUMBER_STRING, Date => DATE_STRING, @@ -334,7 +334,7 @@ fn bench_expr(c: &mut Criterion) { (PbType::AtTimeZone, 1) => TIMEZONE, (PbType::DateTrunc, 0) => TIME_FIELD, (PbType::DateTrunc, 2) => TIMEZONE, - (PbType::Extract, 0) => match sig.inputs_type[1] { + (PbType::Extract, 0) => match sig.inputs_type[1].as_exact() { Date => EXTRACT_FIELD_DATE, Time => EXTRACT_FIELD_TIME, Timestamp => EXTRACT_FIELD_TIMESTAMP, @@ -342,38 +342,46 @@ fn bench_expr(c: &mut Criterion) { Interval => EXTRACT_FIELD_INTERVAL, t => panic!("unexpected type: {t:?}"), }, - _ => input_index_for_type((*t).into()), + _ => input_index_for_type(t.as_exact()), }; - children.push(InputRefExpression::new(DataType::from(*t), idx).boxed()); + children.push(InputRefExpression::new(t.as_exact().clone(), idx).boxed()); } - let expr = build_func(sig.func, sig.ret_type.into(), children).unwrap(); + let expr = build_func( + sig.name.as_scalar(), + sig.ret_type.as_exact().clone(), + children, + ) + .unwrap(); c.bench_function(&format!("{sig:?}"), |bencher| { bencher.to_async(FuturesExecutor).iter(|| expr.eval(&input)) }); } - let sigs = agg_func_sigs(); - let sigs = sigs.sorted_by_cached_key(|sig| format!("{sig:?}")); + let sigs = FUNCTION_REGISTRY + .iter_aggregates() + .sorted_by_cached_key(|sig| format!("{sig:?}")); for sig in sigs { - if matches!(sig.func, AggKind::PercentileDisc | AggKind::PercentileCont) - || (sig.inputs_type.iter()) - .chain(&[sig.ret_type]) - .any(|t| matches!(t, DataTypeName::Struct | DataTypeName::List)) + if matches!( + sig.name.as_aggregate(), + AggKind::PercentileDisc | AggKind::PercentileCont + ) || (sig.inputs_type.iter()) + .chain([&sig.ret_type]) + .any(|t| !t.is_exact()) { println!("todo: {sig:?}"); continue; } let agg = match build_append_only(&AggCall { - kind: sig.func, - args: match sig.inputs_type { + kind: sig.name.as_aggregate(), + args: match sig.inputs_type.as_slice() { [] => AggArgs::None, - [t] => AggArgs::Unary((*t).into(), input_index_for_type((*t).into())), + [t] => AggArgs::Unary(t.as_exact().clone(), input_index_for_type(t.as_exact())), _ => { println!("todo: {sig:?}"); continue; } }, - return_type: sig.ret_type.into(), + return_type: sig.ret_type.as_exact().clone(), column_orders: vec![], filter: None, distinct: false, diff --git a/src/expr/impl/src/aggregate/array_agg.rs b/src/expr/impl/src/aggregate/array_agg.rs index 21355418810f9..963d56ed08621 100644 --- a/src/expr/impl/src/aggregate/array_agg.rs +++ b/src/expr/impl/src/aggregate/array_agg.rs @@ -13,13 +13,13 @@ // limitations under the License. use risingwave_common::array::ListValue; -use risingwave_common::types::{Datum, ScalarRef}; +use risingwave_common::types::{Datum, ScalarRefImpl, ToOwnedDatum}; use risingwave_expr::aggregate; -#[aggregate("array_agg(*) -> list")] -fn array_agg<'a>(state: Option, value: Option>) -> ListValue { +#[aggregate("array_agg(any) -> anyarray")] +fn array_agg(state: Option, value: Option>) -> ListValue { let mut state: Vec = state.unwrap_or_default().into(); - state.push(value.map(|v| v.to_owned_scalar().into())); + state.push(value.to_owned_datum()); state.into() } diff --git a/src/expr/impl/src/aggregate/jsonb_agg.rs b/src/expr/impl/src/aggregate/jsonb_agg.rs index 3a9aea4016ccd..8385e2c6a060b 100644 --- a/src/expr/impl/src/aggregate/jsonb_agg.rs +++ b/src/expr/impl/src/aggregate/jsonb_agg.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::bail; use risingwave_common::types::JsonbVal; use risingwave_expr::{aggregate, ExprError, Result}; use serde_json::Value; diff --git a/src/expr/impl/src/aggregate/mode.rs b/src/expr/impl/src/aggregate/mode.rs index cfdad33f98664..927fb5c801f87 100644 --- a/src/expr/impl/src/aggregate/mode.rs +++ b/src/expr/impl/src/aggregate/mode.rs @@ -23,7 +23,7 @@ use risingwave_expr::aggregate::{ }; use risingwave_expr::{build_aggregate, Result}; -#[build_aggregate("mode(*) -> auto")] +#[build_aggregate("mode(any) -> any")] fn build(agg: &AggCall) -> Result { Ok(Box::new(Mode { return_type: agg.return_type.clone(), diff --git a/src/expr/impl/src/aggregate/percentile_disc.rs b/src/expr/impl/src/aggregate/percentile_disc.rs index daac4085967af..c9143dcf8e640 100644 --- a/src/expr/impl/src/aggregate/percentile_disc.rs +++ b/src/expr/impl/src/aggregate/percentile_disc.rs @@ -67,7 +67,7 @@ use risingwave_expr::{build_aggregate, Result}; /// statement ok /// drop table t; /// ``` -#[build_aggregate("percentile_disc(*) -> auto")] +#[build_aggregate("percentile_disc(any) -> any")] fn build(agg: &AggCall) -> Result { let fractions = agg.direct_args[0] .literal() diff --git a/src/expr/impl/src/aggregate/string_agg.rs b/src/expr/impl/src/aggregate/string_agg.rs index d24303fe1f708..6bd9c8e82ee3d 100644 --- a/src/expr/impl/src/aggregate/string_agg.rs +++ b/src/expr/impl/src/aggregate/string_agg.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::bail; use risingwave_expr::aggregate; #[aggregate("string_agg(varchar, varchar) -> varchar")] diff --git a/src/expr/impl/src/scalar/array.rs b/src/expr/impl/src/scalar/array.rs index a0020b7ea4e26..7df75de8809f3 100644 --- a/src/expr/impl/src/scalar/array.rs +++ b/src/expr/impl/src/scalar/array.rs @@ -17,12 +17,12 @@ use risingwave_common::row::Row; use risingwave_common::types::ToOwnedDatum; use risingwave_expr::function; -#[function("array(...) -> list")] +#[function("array(...) -> anyarray", type_infer = "panic")] fn array(row: impl Row) -> ListValue { ListValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) } -#[function("row(...) -> struct")] +#[function("row(...) -> struct", type_infer = "panic")] fn row_(row: impl Row) -> StructValue { StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) } diff --git a/src/expr/impl/src/scalar/array_access.rs b/src/expr/impl/src/scalar/array_access.rs index 925ea496181b4..929eb19b9318b 100644 --- a/src/expr/impl/src/scalar/array_access.rs +++ b/src/expr/impl/src/scalar/array_access.rs @@ -13,22 +13,17 @@ // limitations under the License. use risingwave_common::array::ListRef; -use risingwave_common::types::{Scalar, ToOwnedDatum}; -use risingwave_expr::{function, Result}; +use risingwave_common::types::ScalarRefImpl; +use risingwave_expr::function; -#[function("array_access(list, int4) -> *")] -pub fn array_access(list: ListRef<'_>, index: i32) -> Result> { +#[function("array_access(anyarray, int4) -> any")] +fn array_access(list: ListRef<'_>, index: i32) -> Option> { // index must be greater than 0 following a one-based numbering convention for arrays if index < 1 { - return Ok(None); + return None; } // returns `NULL` if index is out of bounds - let datum_ref = list.elem_at(index as usize - 1).flatten(); - if let Some(scalar) = datum_ref.to_owned_datum() { - Ok(Some(scalar.try_into()?)) - } else { - Ok(None) - } + list.elem_at(index as usize - 1).flatten() } #[cfg(test)] @@ -48,10 +43,10 @@ mod tests { ]); let l1 = ListRef::ValueRef { val: &v1 }; - assert_eq!(array_access::(l1, 1).unwrap(), Some(1)); - assert_eq!(array_access::(l1, -1).unwrap(), None); - assert_eq!(array_access::(l1, 0).unwrap(), None); - assert_eq!(array_access::(l1, 4).unwrap(), None); + assert_eq!(array_access(l1, 1), Some(1.into())); + assert_eq!(array_access(l1, -1), None); + assert_eq!(array_access(l1, 0), None); + assert_eq!(array_access(l1, 4), None); } #[test] @@ -72,18 +67,9 @@ mod tests { let l2 = ListRef::ValueRef { val: &v2 }; let l3 = ListRef::ValueRef { val: &v3 }; - assert_eq!( - array_access::>(l1, 1).unwrap(), - Some("来自".into()) - ); - assert_eq!( - array_access::>(l2, 2).unwrap(), - Some("荷兰".into()) - ); - assert_eq!( - array_access::>(l3, 3).unwrap(), - Some("的爱".into()) - ); + assert_eq!(array_access(l1, 1), Some("来自".into())); + assert_eq!(array_access(l2, 2), Some("荷兰".into())); + assert_eq!(array_access(l3, 3), Some("的爱".into())); } #[test] @@ -100,11 +86,14 @@ mod tests { ]); let l = ListRef::ValueRef { val: &v }; assert_eq!( - array_access::(l, 1).unwrap(), - Some(ListValue::new(vec![ - Some(ScalarImpl::Utf8("foo".into())), - Some(ScalarImpl::Utf8("bar".into())), - ])) + array_access(l, 1), + Some( + ListRef::from(&ListValue::new(vec![ + Some(ScalarImpl::Utf8("foo".into())), + Some(ScalarImpl::Utf8("bar".into())), + ])) + .into() + ) ); } } diff --git a/src/expr/impl/src/scalar/array_concat.rs b/src/expr/impl/src/scalar/array_concat.rs index b63ae05a4fc32..ab63c6f1f7529 100644 --- a/src/expr/impl/src/scalar/array_concat.rs +++ b/src/expr/impl/src/scalar/array_concat.rs @@ -87,7 +87,7 @@ use risingwave_expr::function; /// ---- /// NULL /// ``` -#[function("array_cat(list, list) -> list")] +#[function("array_cat(anyarray, anyarray) -> anyarray")] fn array_cat( left: Option>, right: Option>, @@ -153,15 +153,12 @@ fn array_cat( /// ---- /// {NULL} /// ``` -#[function("array_append(list, *) -> list")] -fn array_append<'a>( - left: Option>, - right: Option>>, -) -> ListValue { +#[function("array_append(anyarray, any) -> anyarray")] +fn array_append(left: Option>, right: Option>) -> ListValue { ListValue::new( left.iter() .flat_map(|list| list.iter()) - .chain(std::iter::once(right.map(Into::into))) + .chain(std::iter::once(right)) .map(|x| x.map(ScalarRefImpl::into_scalar_impl)) .collect(), ) @@ -193,13 +190,10 @@ fn array_append<'a>( /// ---- /// {NULL} /// ``` -#[function("array_prepend(*, list) -> list")] -fn array_prepend<'a>( - left: Option>>, - right: Option>, -) -> ListValue { +#[function("array_prepend(any, anyarray) -> anyarray")] +fn array_prepend(left: Option>, right: Option>) -> ListValue { ListValue::new( - std::iter::once(left.map(Into::into)) + std::iter::once(left) .chain(right.iter().flat_map(|list| list.iter())) .map(|x| x.map(ScalarRefImpl::into_scalar_impl)) .collect(), diff --git a/src/expr/impl/src/scalar/array_distinct.rs b/src/expr/impl/src/scalar/array_distinct.rs index 573648006eb98..a334506631bf8 100644 --- a/src/expr/impl/src/scalar/array_distinct.rs +++ b/src/expr/impl/src/scalar/array_distinct.rs @@ -50,7 +50,7 @@ use risingwave_expr::function; /// select array_distinct(null); /// ``` -#[function("array_distinct(list) -> list")] +#[function("array_distinct(anyarray) -> anyarray")] pub fn array_distinct(list: ListRef<'_>) -> ListValue { ListValue::new(list.iter().unique().map(|x| x.to_owned_datum()).collect()) } diff --git a/src/expr/impl/src/scalar/array_length.rs b/src/expr/impl/src/scalar/array_length.rs index 465ac50b4d4fb..770744cbc0f46 100644 --- a/src/expr/impl/src/scalar/array_length.rs +++ b/src/expr/impl/src/scalar/array_length.rs @@ -15,7 +15,7 @@ use std::fmt::Write; use risingwave_common::array::ListRef; -use risingwave_expr::{function, ExprError}; +use risingwave_expr::{function, ExprError, Result}; /// Returns the length of an array. /// @@ -59,9 +59,9 @@ use risingwave_expr::{function, ExprError}; /// query error Cannot implicitly cast /// select array_length(null); /// ``` -#[function("array_length(list) -> int4")] -#[function("array_length(list) -> int8", deprecated)] -fn array_length>(array: ListRef<'_>) -> Result { +#[function("array_length(anyarray) -> int4")] +#[function("array_length(anyarray) -> int8", deprecated)] +fn array_length>(array: ListRef<'_>) -> Result { array .len() .try_into() @@ -127,8 +127,8 @@ fn array_length>(array: ListRef<'_>) -> Result { /// statement error /// select array_length(array[null, array[2]], 2); /// ``` -#[function("array_length(list, int4) -> int4")] -fn array_length_of_dim(array: ListRef<'_>, d: i32) -> Result, ExprError> { +#[function("array_length(anyarray, int4) -> int4")] +fn array_length_of_dim(array: ListRef<'_>, d: i32) -> Result> { match d { ..=0 => Ok(None), 1 => array_length(array).map(Some), @@ -184,7 +184,7 @@ fn array_length_of_dim(array: ListRef<'_>, d: i32) -> Result, ExprEr /// statement error /// select array_dims(array[array[]::int[]]); -- would be `[1:1][1:0]` after multidimensional support /// ``` -#[function("array_dims(list) -> varchar")] +#[function("array_dims(anyarray) -> varchar")] fn array_dims(array: ListRef<'_>, writer: &mut impl Write) { write!(writer, "[1:{}]", array.len()).unwrap(); } diff --git a/src/expr/impl/src/scalar/array_min_max.rs b/src/expr/impl/src/scalar/array_min_max.rs index 86652b8123970..db54139ec081b 100644 --- a/src/expr/impl/src/scalar/array_min_max.rs +++ b/src/expr/impl/src/scalar/array_min_max.rs @@ -13,45 +13,17 @@ // limitations under the License. use risingwave_common::array::*; -use risingwave_common::types::{DefaultOrdered, Scalar, ToOwnedDatum}; -use risingwave_expr::{function, Result}; +use risingwave_common::types::{DefaultOrdered, ScalarRefImpl}; +use risingwave_expr::function; -/// FIXME: #[`function("array_min(list`) -> any")] supports -/// In this way we could avoid manual macro expansion -#[function("array_min(list) -> *int")] -#[function("array_min(list) -> *float")] -#[function("array_min(list) -> decimal")] -#[function("array_min(list) -> serial")] -#[function("array_min(list) -> int256")] -#[function("array_min(list) -> date")] -#[function("array_min(list) -> time")] -#[function("array_min(list) -> timestamp")] -#[function("array_min(list) -> timestamptz")] -#[function("array_min(list) -> varchar")] -#[function("array_min(list) -> bytea")] -pub fn array_min(list: ListRef<'_>) -> Result> { +#[function("array_min(anyarray) -> any")] +pub fn array_min(list: ListRef<'_>) -> Option> { let min_value = list.iter().flatten().map(DefaultOrdered).min(); - match min_value.map(|v| v.0).to_owned_datum() { - Some(s) => Ok(Some(s.try_into()?)), - None => Ok(None), - } + min_value.map(|v| v.0) } -#[function("array_max(list) -> *int")] -#[function("array_max(list) -> *float")] -#[function("array_max(list) -> decimal")] -#[function("array_max(list) -> serial")] -#[function("array_max(list) -> int256")] -#[function("array_max(list) -> date")] -#[function("array_max(list) -> time")] -#[function("array_max(list) -> timestamp")] -#[function("array_max(list) -> timestamptz")] -#[function("array_max(list) -> varchar")] -#[function("array_max(list) -> bytea")] -pub fn array_max(list: ListRef<'_>) -> Result> { +#[function("array_max(anyarray) -> any")] +pub fn array_max(list: ListRef<'_>) -> Option> { let max_value = list.iter().flatten().map(DefaultOrdered).max(); - match max_value.map(|v| v.0).to_owned_datum() { - Some(s) => Ok(Some(s.try_into()?)), - None => Ok(None), - } + max_value.map(|v| v.0) } diff --git a/src/expr/impl/src/scalar/array_positions.rs b/src/expr/impl/src/scalar/array_positions.rs index bfef1fd8c6f13..6058730344485 100644 --- a/src/expr/impl/src/scalar/array_positions.rs +++ b/src/expr/impl/src/scalar/array_positions.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::array::{ListRef, ListValue}; -use risingwave_common::types::{ScalarImpl, ScalarRef}; +use risingwave_common::types::{ScalarImpl, ScalarRefImpl}; use risingwave_expr::{function, ExprError, Result}; /// Returns the subscript of the first occurrence of the second argument in the array, or `NULL` if @@ -65,10 +65,10 @@ use risingwave_expr::{function, ExprError, Result}; /// ---- /// 2 /// ``` -#[function("array_position(list, *) -> int4")] -fn array_position<'a, T: ScalarRef<'a>>( +#[function("array_position(anyarray, any) -> int4")] +fn array_position( array: Option>, - element: Option, + element: Option>, ) -> Result> { array_position_common(array, element, 0) } @@ -96,10 +96,10 @@ fn array_position<'a, T: ScalarRef<'a>>( /// 4 4 /// 5 NULL /// ``` -#[function("array_position(list, *, int4) -> int4")] -fn array_position_start<'a, T: ScalarRef<'a>>( +#[function("array_position(anyarray, any, int4) -> int4")] +fn array_position_start( array: Option>, - element: Option, + element: Option>, start: Option, ) -> Result> { let start = match start { @@ -114,9 +114,9 @@ fn array_position_start<'a, T: ScalarRef<'a>>( array_position_common(array, element, start) } -fn array_position_common<'a, T: ScalarRef<'a>>( +fn array_position_common( array: Option>, - element: Option, + element: Option>, skip: usize, ) -> Result> { let Some(left) = array else { return Ok(None) }; @@ -127,7 +127,7 @@ fn array_position_common<'a, T: ScalarRef<'a>>( Ok(left .iter() .skip(skip) - .position(|item| item == element.map(Into::into)) + .position(|item| item == element) .map(|idx| (idx + 1 + skip) as _)) } @@ -181,25 +181,23 @@ fn array_position_common<'a, T: ScalarRef<'a>>( /// statement error /// select array_positions(ARRAY[array[1],array[2],array[3],array[2],null], array[true]); /// ``` -#[function("array_positions(list, *) -> list")] -fn array_positions<'a, T: ScalarRef<'a>>( +#[function("array_positions(anyarray, any) -> int4[]")] +fn array_positions( array: Option>, - element: Option, + element: Option>, ) -> Result> { - match array { - Some(left) => { - let values = left.iter(); - match TryInto::::try_into(values.len()) { - Ok(_) => Ok(Some(ListValue::new( - values - .enumerate() - .filter(|(_, item)| item == &element.map(|x| x.into())) - .map(|(idx, _)| Some(ScalarImpl::Int32((idx + 1) as _))) - .collect(), - ))), - Err(_) => Err(ExprError::CastOutOfRange("invalid array length")), - } - } - _ => Ok(None), + let Some(array) = array else { + return Ok(None); + }; + let values = array.iter(); + if values.len() - 1 > i32::MAX as usize { + return Err(ExprError::CastOutOfRange("invalid array length")); } + Ok(Some(ListValue::new( + values + .enumerate() + .filter(|(_, item)| item == &element) + .map(|(idx, _)| Some(ScalarImpl::Int32((idx + 1) as _))) + .collect(), + ))) } diff --git a/src/expr/impl/src/scalar/array_range_access.rs b/src/expr/impl/src/scalar/array_range_access.rs index 785160a74c5ce..2782a97a7147f 100644 --- a/src/expr/impl/src/scalar/array_range_access.rs +++ b/src/expr/impl/src/scalar/array_range_access.rs @@ -14,22 +14,22 @@ use risingwave_common::array::{ListRef, ListValue}; use risingwave_common::types::ToOwnedDatum; -use risingwave_expr::{function, Result}; +use risingwave_expr::function; /// If the case is `array[1,2,3][:2]`, then start will be 0 set by the frontend /// If the case is `array[1,2,3][1:]`, then end will be `i32::MAX` set by the frontend -#[function("array_range_access(list, int4, int4) -> list")] -pub fn array_range_access(list: ListRef<'_>, start: i32, end: i32) -> Result> { +#[function("array_range_access(anyarray, int4, int4) -> anyarray")] +pub fn array_range_access(list: ListRef<'_>, start: i32, end: i32) -> Option { let mut data = vec![]; let list_all_values = list.iter(); let start = std::cmp::max(start, 1) as usize; let end = std::cmp::min(std::cmp::max(0, end), list_all_values.len() as i32) as usize; if start > end { - return Ok(Some(ListValue::new(data))); + return Some(ListValue::new(data)); } for datumref in list_all_values.take(end).skip(start - 1) { data.push(datumref.to_owned_datum()); } - Ok(Some(ListValue::new(data))) + Some(ListValue::new(data)) } diff --git a/src/expr/impl/src/scalar/array_remove.rs b/src/expr/impl/src/scalar/array_remove.rs index 9c799fd71dd91..b26be50cc3954 100644 --- a/src/expr/impl/src/scalar/array_remove.rs +++ b/src/expr/impl/src/scalar/array_remove.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::array::{ListRef, ListValue}; -use risingwave_common::types::{ScalarRef, ToOwnedDatum}; +use risingwave_common::types::{ScalarRefImpl, ToOwnedDatum}; use risingwave_expr::function; /// Removes all elements equal to the given value from the array. @@ -66,17 +66,13 @@ use risingwave_expr::function; /// statement error /// select array_remove(ARRAY[array[1],array[2],array[3],array[2],null], array[true]); /// ``` -#[function("array_remove(list, *) -> list")] -fn array_remove<'a, T: ScalarRef<'a>>( - arr: Option>, - elem: Option, -) -> Option { - arr.map(|arr| { - ListValue::new( - arr.iter() - .filter(|x| x != &elem.map(Into::into)) - .map(|x| x.to_owned_datum()) - .collect(), - ) - }) +#[function("array_remove(anyarray, any) -> anyarray")] +fn array_remove(array: Option>, elem: Option>) -> Option { + Some(ListValue::new( + array? + .iter() + .filter(|x| x != &elem) + .map(|x| x.to_owned_datum()) + .collect(), + )) } diff --git a/src/expr/impl/src/scalar/array_replace.rs b/src/expr/impl/src/scalar/array_replace.rs index 38760bae2871e..9764a3937fc42 100644 --- a/src/expr/impl/src/scalar/array_replace.rs +++ b/src/expr/impl/src/scalar/array_replace.rs @@ -13,7 +13,7 @@ // limitations under the License. use risingwave_common::array::{ListRef, ListValue}; -use risingwave_common::types::{ScalarRef, ToOwnedDatum}; +use risingwave_common::types::{ScalarRefImpl, ToOwnedDatum}; use risingwave_expr::function; /// Replaces each array element equal to the second argument with the third argument. @@ -54,37 +54,19 @@ use risingwave_expr::function; /// statement error /// select array_replace(array[array[array[0, 1], array[2, 3]], array[array[4, 5], array[6, 7]]], array[4, 5], array[8, 9]); /// ``` -#[function("array_replace(list, boolean, boolean) -> list")] -#[function("array_replace(list, int2, int2) -> list")] -#[function("array_replace(list, int4, int4) -> list")] -#[function("array_replace(list, int8, int8) -> list")] -#[function("array_replace(list, decimal, decimal) -> list")] -#[function("array_replace(list, float4, float4) -> list")] -#[function("array_replace(list, float8, float8) -> list")] -#[function("array_replace(list, varchar, varchar) -> list")] -#[function("array_replace(list, bytea, bytea) -> list")] -#[function("array_replace(list, time, time) -> list")] -#[function("array_replace(list, interval, interval) -> list")] -#[function("array_replace(list, date, date) -> list")] -#[function("array_replace(list, timestamp, timestamp) -> list")] -#[function("array_replace(list, timestamptz, timestamptz) -> list")] -#[function("array_replace(list, list, list) -> list")] -#[function("array_replace(list, struct, struct) -> list")] -#[function("array_replace(list, jsonb, jsonb) -> list")] -#[function("array_replace(list, int256, int256) -> list")] -fn array_replace<'a, T: ScalarRef<'a>>( - arr: Option>, - elem_from: Option, - elem_to: Option, +#[function("array_replace(anyarray, any, any) -> anyarray")] +fn array_replace( + array: Option>, + elem_from: Option>, + elem_to: Option>, ) -> Option { - arr.map(|arr| { - ListValue::new( - arr.iter() - .map(|x| match x == elem_from.map(Into::into) { - true => elem_to.map(Into::into).to_owned_datum(), - false => x.to_owned_datum(), - }) - .collect(), - ) - }) + Some(ListValue::new( + array? + .iter() + .map(|x| match x == elem_from { + true => elem_to.to_owned_datum(), + false => x.to_owned_datum(), + }) + .collect(), + )) } diff --git a/src/expr/impl/src/scalar/array_sort.rs b/src/expr/impl/src/scalar/array_sort.rs index 293954bd6139a..c48fe7608076e 100644 --- a/src/expr/impl/src/scalar/array_sort.rs +++ b/src/expr/impl/src/scalar/array_sort.rs @@ -13,19 +13,15 @@ // limitations under the License. use risingwave_common::array::*; -use risingwave_common::types::{Datum, DatumRef, DefaultOrdered, ToOwnedDatum}; +use risingwave_common::types::{DatumRef, DefaultOrdered, ToOwnedDatum}; use risingwave_expr::function; -#[function("array_sort(list) -> list")] +#[function("array_sort(anyarray) -> anyarray")] pub fn array_sort(list: ListRef<'_>) -> ListValue { let mut v = list .iter() .map(DefaultOrdered) .collect::>>>(); v.sort(); - ListValue::new( - v.into_iter() - .map(|x| x.0.to_owned_datum()) - .collect::>(), - ) + ListValue::new(v.into_iter().map(|x| x.0.to_owned_datum()).collect()) } diff --git a/src/expr/impl/src/scalar/array_sum.rs b/src/expr/impl/src/scalar/array_sum.rs index f6751ad6af8f0..fc45435f71200 100644 --- a/src/expr/impl/src/scalar/array_sum.rs +++ b/src/expr/impl/src/scalar/array_sum.rs @@ -13,70 +13,50 @@ // limitations under the License. use risingwave_common::array::{ArrayError, ListRef}; -use risingwave_common::types::{CheckedAdd, Decimal, Scalar, ScalarImpl, ScalarRefImpl}; +use risingwave_common::types::{CheckedAdd, Decimal, ScalarRefImpl}; use risingwave_expr::{function, ExprError, Result}; -/// `array_sum(int2`[]) -> int8 -/// `array_sum(int4`[]) -> int8 -#[function("array_sum(list) -> int8")] -#[function("array_sum(list) -> float4")] -#[function("array_sum(list) -> float8")] -/// `array_sum(int8`[]) -> decimal -/// `array_sum(decimal`[]) -> decimal -#[function("array_sum(list) -> decimal")] -#[function("array_sum(list) -> interval")] -fn array_sum(list: ListRef<'_>) -> Result> +#[function("array_sum(int2[]) -> int8")] +fn array_sum_int2(list: ListRef<'_>) -> Result> { + array_sum_general::(list) +} + +#[function("array_sum(int4[]) -> int8")] +fn array_sum_int4(list: ListRef<'_>) -> Result> { + array_sum_general::(list) +} + +#[function("array_sum(int8[]) -> decimal")] +fn array_sum_int8(list: ListRef<'_>) -> Result> { + array_sum_general::(list) +} + +#[function("array_sum(float4[]) -> float4")] +#[function("array_sum(float8[]) -> float8")] +#[function("array_sum(decimal[]) -> decimal")] +#[function("array_sum(interval[]) -> interval")] +fn array_sum(list: ListRef<'_>) -> Result> where - T: Default + for<'a> TryFrom, Error = ArrayError> + CheckedAdd, + T: for<'a> TryFrom, Error = ArrayError>, + T: Default + From + CheckedAdd, { - let flag = match list.iter().flatten().next() { - Some(v) => match v { - ScalarRefImpl::Int16(_) | ScalarRefImpl::Int32(_) => 1, - ScalarRefImpl::Int64(_) => 2, - _ => 0, - }, - None => return Ok(None), - }; + array_sum_general::(list) +} - if flag != 0 { - match flag { - 1 => { - let mut sum = 0; - for e in list.iter().flatten() { - sum = sum - .checked_add(match e { - ScalarRefImpl::Int16(v) => v as i64, - ScalarRefImpl::Int32(v) => v as i64, - _ => panic!("Expect ScalarRefImpl::Int16 or ScalarRefImpl::Int32"), - }) - .ok_or_else(|| ExprError::NumericOutOfRange)?; - } - Ok(Some(ScalarImpl::from(sum).try_into()?)) - } - 2 => { - let mut sum = Decimal::Normalized(0.into()); - for e in list.iter().flatten() { - sum = sum - .checked_add(match e { - ScalarRefImpl::Int64(v) => Decimal::Normalized(v.into()), - ScalarRefImpl::Decimal(v) => v, - // FIXME: We can't panic here due to the macro expansion - _ => Decimal::Normalized(0.into()), - }) - .ok_or_else(|| ExprError::NumericOutOfRange)?; - } - Ok(Some(ScalarImpl::from(sum).try_into()?)) - } - _ => Ok(None), - } - } else { - let mut sum = T::default(); - for e in list.iter().flatten() { - let v = e.try_into()?; - sum = sum - .checked_add(v) - .ok_or_else(|| ExprError::NumericOutOfRange)?; - } - Ok(Some(sum)) +fn array_sum_general(list: ListRef<'_>) -> Result> +where + S: for<'a> TryFrom, Error = ArrayError>, + T: Default + From + CheckedAdd, +{ + if list.iter().flatten().next().is_none() { + return Ok(None); + } + let mut sum = T::default(); + for e in list.iter().flatten() { + let v: S = e.try_into()?; + sum = sum + .checked_add(v.into()) + .ok_or_else(|| ExprError::NumericOutOfRange)?; } + Ok(Some(sum)) } diff --git a/src/expr/impl/src/scalar/array_to_string.rs b/src/expr/impl/src/scalar/array_to_string.rs index 6e3ea72c5f4ed..0c98dcba262ea 100644 --- a/src/expr/impl/src/scalar/array_to_string.rs +++ b/src/expr/impl/src/scalar/array_to_string.rs @@ -81,7 +81,7 @@ use risingwave_expr::function; /// ---- /// one,*,three,four /// ``` -#[function("array_to_string(list, varchar) -> varchar")] +#[function("array_to_string(anyarray, varchar) -> varchar")] fn array_to_string(array: ListRef<'_>, delimiter: &str, ctx: &Context, writer: &mut impl Write) { let element_data_type = ctx.arg_types[0].unnest_list(); let mut first = true; @@ -96,7 +96,7 @@ fn array_to_string(array: ListRef<'_>, delimiter: &str, ctx: &Context, writer: & } } -#[function("array_to_string(list, varchar, varchar) -> varchar")] +#[function("array_to_string(anyarray, varchar, varchar) -> varchar")] fn array_to_string_with_null( array: ListRef<'_>, delimiter: &str, diff --git a/src/expr/impl/src/scalar/cardinality.rs b/src/expr/impl/src/scalar/cardinality.rs index c3b539cdeb4ea..a24aee30e8294 100644 --- a/src/expr/impl/src/scalar/cardinality.rs +++ b/src/expr/impl/src/scalar/cardinality.rs @@ -57,8 +57,8 @@ use risingwave_expr::{function, ExprError, Result}; /// query error Cannot implicitly cast /// select cardinality(null); /// ``` -#[function("cardinality(list) -> int4")] -#[function("cardinality(list) -> int8", deprecated)] +#[function("cardinality(anyarray) -> int4")] +#[function("cardinality(anyarray) -> int8", deprecated)] fn cardinality>(array: ListRef<'_>) -> Result { array .flatten() diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index 89d643e524eac..889cc43fe6b18 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -149,7 +149,7 @@ pub fn int_to_bool(input: i32) -> bool { #[function("cast(timestamp) -> varchar")] #[function("cast(jsonb) -> varchar")] #[function("cast(bytea) -> varchar")] -#[function("cast(list) -> varchar")] +#[function("cast(anyarray) -> varchar")] pub fn general_to_text(elem: impl ToText, mut writer: &mut impl Write) { elem.write(&mut writer).unwrap(); } @@ -213,7 +213,7 @@ fn unnest(input: &str) -> Result> { Ok(items) } -#[function("cast(varchar) -> list")] +#[function("cast(varchar) -> anyarray", type_infer = "panic")] fn str_to_list(input: &str, ctx: &Context) -> Result { let cast = build_func( PbType::Cast, @@ -233,7 +233,7 @@ fn str_to_list(input: &str, ctx: &Context) -> Result { } /// Cast array with `source_elem_type` into array with `target_elem_type` by casting each element. -#[function("cast(list) -> list")] +#[function("cast(anyarray) -> anyarray", type_infer = "panic")] fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result { let cast = build_func( PbType::Cast, @@ -254,7 +254,7 @@ fn list_cast(input: ListRef<'_>, ctx: &Context) -> Result { } /// Cast struct of `source_elem_type` to `target_elem_type` by casting each element. -#[function("cast(struct) -> struct")] +#[function("cast(struct) -> struct", type_infer = "panic")] fn struct_cast(input: StructRef<'_>, ctx: &Context) -> Result { let fields = (input.iter_fields_ref()) .zip_eq_fast(ctx.arg_types[0].as_struct().types()) diff --git a/src/expr/impl/src/scalar/cmp.rs b/src/expr/impl/src/scalar/cmp.rs index 2bf0a64d6fe02..b5a38af0d44a6 100644 --- a/src/expr/impl/src/scalar/cmp.rs +++ b/src/expr/impl/src/scalar/cmp.rs @@ -35,7 +35,7 @@ use risingwave_expr::function; #[function("equal(interval, time) -> boolean")] #[function("equal(varchar, varchar) -> boolean")] #[function("equal(bytea, bytea) -> boolean")] -#[function("equal(list, list) -> boolean")] +#[function("equal(anyarray, anyarray) -> boolean")] #[function("equal(struct, struct) -> boolean")] pub fn general_eq(l: T1, r: T2) -> bool where @@ -63,7 +63,7 @@ where #[function("not_equal(interval, time) -> boolean")] #[function("not_equal(varchar, varchar) -> boolean")] #[function("not_equal(bytea, bytea) -> boolean")] -#[function("not_equal(list, list) -> boolean")] +#[function("not_equal(anyarray, anyarray) -> boolean")] #[function("not_equal(struct, struct) -> boolean")] pub fn general_ne(l: T1, r: T2) -> bool where @@ -94,7 +94,7 @@ where #[function("greater_than_or_equal(interval, time) -> boolean")] #[function("greater_than_or_equal(varchar, varchar) -> boolean")] #[function("greater_than_or_equal(bytea, bytea) -> boolean")] -#[function("greater_than_or_equal(list, list) -> boolean")] +#[function("greater_than_or_equal(anyarray, anyarray) -> boolean")] #[function("greater_than_or_equal(struct, struct) -> boolean")] pub fn general_ge(l: T1, r: T2) -> bool where @@ -122,7 +122,7 @@ where #[function("greater_than(interval, time) -> boolean")] #[function("greater_than(varchar, varchar) -> boolean")] #[function("greater_than(bytea, bytea) -> boolean")] -#[function("greater_than(list, list) -> boolean")] +#[function("greater_than(anyarray, anyarray) -> boolean")] #[function("greater_than(struct, struct) -> boolean")] pub fn general_gt(l: T1, r: T2) -> bool where @@ -153,7 +153,7 @@ where #[function("less_than_or_equal(interval, time) -> boolean")] #[function("less_than_or_equal(varchar, varchar) -> boolean")] #[function("less_than_or_equal(bytea, bytea) -> boolean")] -#[function("less_than_or_equal(list, list) -> boolean")] +#[function("less_than_or_equal(anyarray, anyarray) -> boolean")] #[function("less_than_or_equal(struct, struct) -> boolean")] pub fn general_le(l: T1, r: T2) -> bool where @@ -181,7 +181,7 @@ where #[function("less_than(interval, time) -> boolean")] #[function("less_than(varchar, varchar) -> boolean")] #[function("less_than(bytea, bytea) -> boolean")] -#[function("less_than(list, list) -> boolean")] +#[function("less_than(anyarray, anyarray) -> boolean")] #[function("less_than(struct, struct) -> boolean")] pub fn general_lt(l: T1, r: T2) -> bool where @@ -212,7 +212,7 @@ where #[function("is_distinct_from(interval, time) -> boolean")] #[function("is_distinct_from(varchar, varchar) -> boolean")] #[function("is_distinct_from(bytea, bytea) -> boolean")] -#[function("is_distinct_from(list, list) -> boolean")] +#[function("is_distinct_from(anyarray, anyarray) -> boolean")] #[function("is_distinct_from(struct, struct) -> boolean")] pub fn general_is_distinct_from(l: Option, r: Option) -> bool where @@ -243,7 +243,7 @@ where #[function("is_not_distinct_from(interval, time) -> boolean")] #[function("is_not_distinct_from(varchar, varchar) -> boolean")] #[function("is_not_distinct_from(bytea, bytea) -> boolean")] -#[function("is_not_distinct_from(list, list) -> boolean")] +#[function("is_not_distinct_from(anyarray, anyarray) -> boolean")] #[function("is_not_distinct_from(struct, struct) -> boolean")] pub fn general_is_not_distinct_from(l: Option, r: Option) -> bool where diff --git a/src/expr/impl/src/scalar/string_to_array.rs b/src/expr/impl/src/scalar/string_to_array.rs index a638c8ef12bf6..a61cbda3ddfbe 100644 --- a/src/expr/impl/src/scalar/string_to_array.rs +++ b/src/expr/impl/src/scalar/string_to_array.rs @@ -35,39 +35,33 @@ fn string_to_array_inner<'a>( } // Use cases shown in `e2e_test/batch/functions/string_to_array.slt.part` -#[function("string_to_array(varchar, varchar) -> list")] +#[function("string_to_array(varchar, varchar) -> varchar[]")] pub fn string_to_array2(s: Option<&str>, sep: Option<&str>) -> Option { - s.map(|s| { - ListValue::new( - string_to_array_inner(s, sep) - .map(|x| Some(ScalarImpl::Utf8(x.into()))) - .collect_vec(), - ) - }) + Some(ListValue::new( + string_to_array_inner(s?, sep) + .map(|x| Some(ScalarImpl::Utf8(x.into()))) + .collect_vec(), + )) } -#[function("string_to_array(varchar, varchar, varchar) -> list")] +#[function("string_to_array(varchar, varchar, varchar) -> varchar[]")] pub fn string_to_array3( s: Option<&str>, sep: Option<&str>, null: Option<&str>, ) -> Option { - s.map(|s| { - null.map_or_else( - || string_to_array2(Some(s), sep).unwrap(), - |null| { - ListValue::new( - string_to_array_inner(s, sep) - .map(|x| { - if x == null { - None - } else { - Some(ScalarImpl::Utf8(x.into())) - } - }) - .collect_vec(), - ) - }, - ) - }) + let Some(null) = null else { + return string_to_array2(s, sep); + }; + Some(ListValue::new( + string_to_array_inner(s?, sep) + .map(|x| { + if x == null { + None + } else { + Some(ScalarImpl::Utf8(x.into())) + } + }) + .collect_vec(), + )) } diff --git a/src/expr/impl/src/scalar/trim_array.rs b/src/expr/impl/src/scalar/trim_array.rs index 4b7a1d53304ec..3a9bbed9c0562 100644 --- a/src/expr/impl/src/scalar/trim_array.rs +++ b/src/expr/impl/src/scalar/trim_array.rs @@ -70,7 +70,7 @@ use risingwave_expr::{function, ExprError, Result}; /// statement error /// select trim_array(array[1,2,3,4,5,null], true); /// ``` -#[function("trim_array(list, int4) -> list")] +#[function("trim_array(anyarray, int4) -> anyarray")] fn trim_array(array: ListRef<'_>, n: i32) -> Result { let values = array.iter(); let len_to_trim: usize = n.try_into().map_err(|_| ExprError::InvalidParam { diff --git a/src/expr/impl/src/table_function/generate_subscripts.rs b/src/expr/impl/src/table_function/generate_subscripts.rs index 0fe3937aae8c9..53123489d7976 100644 --- a/src/expr/impl/src/table_function/generate_subscripts.rs +++ b/src/expr/impl/src/table_function/generate_subscripts.rs @@ -56,7 +56,7 @@ use risingwave_expr::function; /// ---- /// 1 /// ``` -#[function("generate_subscripts(list, int4, boolean) -> setof int4")] +#[function("generate_subscripts(anyarray, int4, boolean) -> setof int4")] fn generate_subscripts_reverse( array: ListRef<'_>, dim: i32, @@ -104,7 +104,7 @@ fn generate_subscripts_reverse( /// ---- /// 1 /// ``` -#[function("generate_subscripts(list, int4) -> setof int4")] +#[function("generate_subscripts(anyarray, int4) -> setof int4")] fn generate_subscripts(array: ListRef<'_>, dim: i32) -> impl Iterator { generate_subscripts_iterator(array, dim, false) } diff --git a/src/expr/impl/src/table_function/pg_expandarray.rs b/src/expr/impl/src/table_function/pg_expandarray.rs index 4ed8516921654..bf0107b703647 100644 --- a/src/expr/impl/src/table_function/pg_expandarray.rs +++ b/src/expr/impl/src/table_function/pg_expandarray.rs @@ -33,7 +33,7 @@ use risingwave_expr::{function, Result}; /// three 3 /// ``` #[function( - "_pg_expandarray(list) -> setof struct", + "_pg_expandarray(anyarray) -> setof struct", type_infer = "infer_type" )] fn _pg_expandarray(array: ListRef<'_>) -> impl Iterator>, i32)> { diff --git a/src/expr/impl/src/table_function/unnest.rs b/src/expr/impl/src/table_function/unnest.rs index 019f6b08d591e..7534b903565dd 100644 --- a/src/expr/impl/src/table_function/unnest.rs +++ b/src/expr/impl/src/table_function/unnest.rs @@ -17,7 +17,7 @@ use risingwave_common::types::ScalarRefImpl; use risingwave_expr::function; #[function( - "unnest(list) -> setof any", + "unnest(anyarray) -> setof any", type_infer = "|args| Ok(args[0].unnest_list().clone())" )] fn unnest(list: ListRef<'_>) -> impl Iterator>> { diff --git a/src/expr/impl/tests/sig.rs b/src/expr/impl/tests/sig.rs index 34a4bf16beba2..1a227e9472042 100644 --- a/src/expr/impl/tests/sig.rs +++ b/src/expr/impl/tests/sig.rs @@ -12,96 +12,69 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::HashMap; risingwave_expr_impl::enable!(); use itertools::Itertools; -use risingwave_common::types::DataTypeName; -use risingwave_expr::sig::func::{func_sigs, FuncSign}; -use risingwave_pb::expr::expr_node::PbType; - +use risingwave_expr::sig::{FuncName, FuncSign, SigDataType, FUNCTION_REGISTRY}; #[test] fn test_func_sig_map() { // convert FUNC_SIG_MAP to a more convenient map for testing - let mut new_map: BTreeMap, Vec>> = BTreeMap::new(); - for sig in func_sigs() { + let mut new_map: HashMap, Vec>> = HashMap::new(); + for sig in FUNCTION_REGISTRY.iter_scalars() { // exclude deprecated functions if sig.deprecated { continue; } new_map - .entry(sig.func) + .entry(sig.name) .or_default() .entry(sig.inputs_type.to_vec()) .or_default() .push(sig.clone()); } - let duplicated: BTreeMap<_, Vec<_>> = new_map - .into_iter() - .filter_map(|(k, funcs_with_same_name)| { - let funcs_with_same_name_type: Vec<_> = funcs_with_same_name - .into_values() - .filter_map(|v| { - if v.len() > 1 { - Some( - format!( - "{:}({:?}) -> {:?}", - v[0].func.as_str_name(), - v[0].inputs_type.iter().format(", "), - v.iter().map(|sig| sig.ret_type).format("/") - ) - .to_ascii_lowercase(), - ) - } else { - None - } - }) - .collect(); - if !funcs_with_same_name_type.is_empty() { - Some((k, funcs_with_same_name_type)) - } else { - None - } + let mut duplicated: Vec<_> = new_map + .into_values() + .flat_map(|funcs_with_same_name| { + funcs_with_same_name.into_values().filter_map(|v| { + if v.len() > 1 { + Some(format!( + "{}({}) -> {}", + v[0].name.as_str_name().to_ascii_lowercase(), + v[0].inputs_type.iter().format(", "), + v.iter().map(|sig| &sig.ret_type).format("/") + )) + } else { + None + } + }) }) .collect(); + duplicated.sort(); // This snapshot shows the function signatures without a unique match. Frontend has to // handle them specially without relying on FuncSigMap. let expected = expect_test::expect![[r#" - { - Cast: [ - "cast(boolean) -> int32/varchar", - "cast(int16) -> int256/decimal/float64/float32/int64/int32/varchar", - "cast(int32) -> int256/int16/decimal/float64/float32/int64/boolean/varchar", - "cast(int64) -> int256/int32/int16/decimal/float64/float32/varchar", - "cast(float32) -> decimal/int64/int32/int16/float64/varchar", - "cast(float64) -> decimal/float32/int64/int32/int16/varchar", - "cast(decimal) -> float64/float32/int64/int32/int16/varchar", - "cast(date) -> timestamp/varchar", - "cast(varchar) -> jsonb/interval/timestamp/time/date/int256/float32/float64/decimal/int16/int32/int64/varchar/boolean/bytea/list", - "cast(time) -> interval/varchar", - "cast(timestamp) -> time/date/varchar", - "cast(interval) -> time/varchar", - "cast(list) -> varchar/list", - "cast(jsonb) -> boolean/float64/float32/decimal/int64/int32/int16/varchar", - "cast(int256) -> float64/varchar", - ], - ArrayAccess: [ - "array_access(list, int32) -> boolean/int16/int32/int64/int256/float32/float64/decimal/serial/date/time/timestamp/timestamptz/interval/varchar/bytea/jsonb/list/struct", - ], - ArrayMin: [ - "array_min(list) -> bytea/varchar/timestamptz/timestamp/time/date/int256/serial/decimal/float32/float64/int16/int32/int64", - ], - ArrayMax: [ - "array_max(list) -> bytea/varchar/timestamptz/timestamp/time/date/int256/serial/decimal/float32/float64/int16/int32/int64", - ], - ArraySum: [ - "array_sum(list) -> interval/decimal/float64/float32/int64", - ], - } - "#]]; + [ + "cast(anyarray) -> character varying/anyarray", + "cast(bigint) -> rw_int256/integer/smallint/numeric/double precision/real/character varying", + "cast(boolean) -> integer/character varying", + "cast(character varying) -> jsonb/interval/timestamp without time zone/time without time zone/date/rw_int256/real/double precision/numeric/smallint/integer/bigint/character varying/boolean/bytea/anyarray", + "cast(date) -> timestamp without time zone/character varying", + "cast(double precision) -> numeric/real/bigint/integer/smallint/character varying", + "cast(integer) -> rw_int256/smallint/numeric/double precision/real/bigint/boolean/character varying", + "cast(interval) -> time without time zone/character varying", + "cast(jsonb) -> boolean/double precision/real/numeric/bigint/integer/smallint/character varying", + "cast(numeric) -> double precision/real/bigint/integer/smallint/character varying", + "cast(real) -> numeric/bigint/integer/smallint/double precision/character varying", + "cast(rw_int256) -> double precision/character varying", + "cast(smallint) -> rw_int256/numeric/double precision/real/bigint/integer/character varying", + "cast(time without time zone) -> interval/character varying", + "cast(timestamp without time zone) -> time without time zone/date/character varying", + ] + "#]]; expected.assert_debug_eq(&duplicated); } diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 50c276e6685c8..083f184add5e7 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -47,6 +47,49 @@ impl FunctionAttr { attrs } + /// Generate the type infer function. + fn generate_type_infer_fn(&self) -> Result { + if let Some(func) = &self.type_infer { + if func == "panic" { + return Ok(quote! { |_| panic!("type inference function is not implemented") }); + } + // use the user defined type inference function + return Ok(func.parse().unwrap()); + } else if self.ret == "any" { + // TODO: if there are multiple "any", they should be the same type + if let Some(i) = self.args.iter().position(|t| t == "any") { + // infer as the type of "any" argument + return Ok(quote! { |args| Ok(args[#i].clone()) }); + } + if let Some(i) = self.args.iter().position(|t| t == "anyarray") { + // infer as the element type of "anyarray" argument + return Ok(quote! { |args| Ok(args[#i].as_list().clone()) }); + } + } else if self.ret == "anyarray" { + if let Some(i) = self.args.iter().position(|t| t == "anyarray") { + // infer as the type of "anyarray" argument + return Ok(quote! { |args| Ok(args[#i].clone()) }); + } + if let Some(i) = self.args.iter().position(|t| t == "any") { + // infer as the array type of "any" argument + return Ok(quote! { |args| Ok(DataType::List(Box::new(args[#i].clone()))) }); + } + } else if self.ret == "struct" { + if let Some(i) = self.args.iter().position(|t| t == "struct") { + // infer as the type of "struct" argument + return Ok(quote! { |args| Ok(args[#i].clone()) }); + } + } else { + // the return type is fixed + let ty = data_type(&self.ret); + return Ok(quote! { |_| Ok(#ty) }); + } + Err(Error::new( + Span::call_site(), + "type inference function is required", + )) + } + /// Generate a descriptor of the scalar or table function. /// /// The types of arguments and return value should not contain wildcard. @@ -65,31 +108,37 @@ impl FunctionAttr { false => &self.args[..], } .iter() - .map(|ty| data_type_name(ty)) + .map(|ty| sig_data_type(ty)) .collect_vec(); - let ret = data_type_name(&self.ret); + let ret = sig_data_type(&self.ret); let pb_type = format_ident!("{}", utils::to_camel_case(&name)); let ctor_name = format_ident!("{}", self.ident_name()); - let descriptor_type = quote! { risingwave_expr::sig::func::FuncSign }; let build_fn = if build_fn { let name = format_ident!("{}", user_fn.name); quote! { #name } } else { self.generate_build_scalar_function(user_fn, true)? }; + let type_infer_fn = self.generate_type_infer_fn()?; let deprecated = self.deprecated; + Ok(quote! { #[risingwave_expr::codegen::ctor] fn #ctor_name() { use risingwave_common::types::{DataType, DataTypeName}; - unsafe { risingwave_expr::sig::func::_register(#descriptor_type { - func: risingwave_pb::expr::expr_node::Type::#pb_type, - inputs_type: &[#(#args),*], + use risingwave_expr::sig::{_register, FuncSign, SigDataType, FuncBuilder}; + + unsafe { _register(FuncSign { + name: risingwave_pb::expr::expr_node::Type::#pb_type.into(), + inputs_type: vec![#(#args),*], variadic: #variadic, ret_type: #ret, - build: #build_fn, + build: FuncBuilder::Scalar(#build_fn), + type_infer: #type_infer_fn, deprecated: #deprecated, + state_type: None, + append_only: false, }) }; } }) @@ -478,12 +527,15 @@ impl FunctionAttr { let mut args = Vec::with_capacity(self.args.len()); for ty in &self.args { - args.push(data_type_name(ty)); + args.push(sig_data_type(ty)); } - let ret = data_type_name(&self.ret); + let ret = sig_data_type(&self.ret); let state_type = match &self.state { - Some(ty) if ty != "ref" => data_type_name(ty), - _ => data_type_name(&self.ret), + Some(ty) if ty != "ref" => { + let ty = data_type(ty); + quote! { Some(#ty) } + } + _ => quote! { None }, }; let append_only = match build_fn { false => !user_fn.has_retract(), @@ -495,24 +547,31 @@ impl FunctionAttr { false => format_ident!("{}", self.ident_name()), true => format_ident!("{}_append_only", self.ident_name()), }; - let descriptor_type = quote! { risingwave_expr::sig::agg::AggFuncSig }; let build_fn = if build_fn { let name = format_ident!("{}", user_fn.as_fn().name); quote! { #name } } else { self.generate_agg_build_fn(user_fn)? }; + let type_infer_fn = self.generate_type_infer_fn()?; + let deprecated = self.deprecated; + Ok(quote! { #[risingwave_expr::codegen::ctor] fn #ctor_name() { use risingwave_common::types::{DataType, DataTypeName}; - unsafe { risingwave_expr::sig::agg::_register(#descriptor_type { - func: risingwave_expr::aggregate::AggKind::#pb_type, - inputs_type: &[#(#args),*], - state_type: #state_type, + use risingwave_expr::sig::{_register, FuncSign, SigDataType, FuncBuilder}; + + unsafe { _register(FuncSign { + name: risingwave_expr::aggregate::AggKind::#pb_type.into(), + inputs_type: vec![#(#args),*], + variadic: false, ret_type: #ret, - build: #build_fn, + build: FuncBuilder::Aggregate(#build_fn), + type_infer: #type_infer_fn, + state_type: #state_type, append_only: #append_only, + deprecated: #deprecated, }) }; } }) @@ -531,11 +590,9 @@ impl FunctionAttr { .enumerate() .map(|(i, arg)| { let array = format_ident!("a{i}"); - let variant: TokenStream2 = types::variant(arg).parse().unwrap(); + let array_type: TokenStream2 = types::array_type(arg).parse().unwrap(); quote! { - let ArrayImpl::#variant(#array) = &**input.column_at(#i) else { - bail!("input type mismatch. expect: {}", stringify!(#variant)); - }; + let #array: &#array_type = input.column_at(#i).as_ref().into(); } }) .collect_vec(); @@ -759,41 +816,37 @@ impl FunctionAttr { let name = self.name.clone(); let mut args = Vec::with_capacity(self.args.len()); for ty in &self.args { - args.push(data_type_name(ty)); + args.push(sig_data_type(ty)); } - let ret = data_type_name(&self.ret); + let ret = sig_data_type(&self.ret); let pb_type = format_ident!("{}", utils::to_camel_case(&name)); let ctor_name = format_ident!("{}", self.ident_name()); - let descriptor_type = quote! { risingwave_expr::sig::table_function::FuncSign }; let build_fn = if build_fn { let name = format_ident!("{}", user_fn.name); quote! { #name } } else { self.generate_build_table_function(user_fn)? }; - let type_infer_fn = if let Some(func) = &self.type_infer { - func.parse().unwrap() - } else { - if matches!(self.ret.as_str(), "any" | "list" | "struct") { - return Err(Error::new( - Span::call_site(), - format!("type inference function is required for {}", self.ret), - )); - } - let ty = data_type(&self.ret); - quote! { |_| Ok(#ty) } - }; + let type_infer_fn = self.generate_type_infer_fn()?; + let deprecated = self.deprecated; + Ok(quote! { #[risingwave_expr::codegen::ctor] fn #ctor_name() { use risingwave_common::types::{DataType, DataTypeName}; - unsafe { risingwave_expr::sig::table_function::_register(#descriptor_type { - func: risingwave_pb::expr::table_function::Type::#pb_type, - inputs_type: &[#(#args),*], + use risingwave_expr::sig::{_register, FuncSign, SigDataType, FuncBuilder}; + + unsafe { _register(FuncSign { + name: risingwave_pb::expr::table_function::Type::#pb_type.into(), + inputs_type: vec![#(#args),*], + variadic: false, ret_type: #ret, - build: #build_fn, + build: FuncBuilder::Table(#build_fn), type_infer: #type_infer_fn, + deprecated: #deprecated, + state_type: None, + append_only: false, }) }; } }) @@ -989,9 +1042,17 @@ impl FunctionAttr { } } -fn data_type_name(ty: &str) -> TokenStream2 { - let variant = format_ident!("{}", types::data_type(ty)); - quote! { DataTypeName::#variant } +fn sig_data_type(ty: &str) -> TokenStream2 { + match ty { + "any" => quote! { SigDataType::Any }, + "anyarray" => quote! { SigDataType::AnyArray }, + "struct" => quote! { SigDataType::AnyStruct }, + _ if ty.starts_with("struct") && ty.contains("any") => quote! { SigDataType::AnyStruct }, + _ => { + let datatype = data_type(ty); + quote! { SigDataType::Exact(#datatype) } + } + } } fn data_type(ty: &str) -> TokenStream2 { diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 51ccd19f6468d..cb57d0cf75383 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -154,7 +154,7 @@ mod utils; /// /// ```ignore /// #[function( -/// "unnest(list) -> setof any", +/// "unnest(anyarray) -> setof any", /// type_infer = "|args| Ok(args[0].unnest_list())" /// )] /// ``` @@ -172,7 +172,7 @@ mod utils; /// For instance: /// /// ```ignore -/// #[function("trim_array(list, int32) -> list")] +/// #[function("trim_array(anyarray, int32) -> anyarray")] /// fn trim_array(array: ListRef<'_>, n: i32) -> ListValue {...} /// ``` /// @@ -182,7 +182,7 @@ mod utils; /// to be considered, the `Option` type can be used: /// /// ```ignore -/// #[function("trim_array(list, int32) -> list")] +/// #[function("trim_array(anyarray, int32) -> anyarray")] /// fn trim_array(array: Option>, n: Option) -> ListValue {...} /// ``` /// @@ -390,7 +390,7 @@ mod utils; /// /// | name | SQL type | owned type | reference type | /// | ---------------------- | -------------------- | ------------- | ------------------ | -/// | list | `any[]` | `ListValue` | `ListRef<'_>` | +/// | anyarray | `any[]` | `ListValue` | `ListRef<'_>` | /// | struct | `record` | `StructValue` | `StructRef<'_>` | /// | T[^1][] | `T[]` | `ListValue` | `ListRef<'_>` | /// | struct | `struct` | `(T, ..)` | `(&T, ..)` | @@ -586,7 +586,7 @@ impl FunctionAttr { /// Return a unique name that can be used as an identifier. fn ident_name(&self) -> String { format!("{}_{}_{}", self.name, self.args.join("_"), self.ret) - .replace("[]", "list") + .replace("[]", "array") .replace("...", "variadic") .replace(['<', '>', ' ', ','], "_") .replace("__", "_") diff --git a/src/expr/macro/src/types.rs b/src/expr/macro/src/types.rs index 6dfa0338cfd4a..9dcd37b401f35 100644 --- a/src/expr/macro/src/types.rs +++ b/src/expr/macro/src/types.rs @@ -14,72 +14,58 @@ //! This module provides utility functions for SQL data type conversion and manipulation. -// name data type variant array type owned type ref type primitive +// name data type array type owned type ref type primitive const TYPE_MATRIX: &str = " - boolean Boolean Bool BoolArray bool bool _ - int2 Int16 Int16 I16Array i16 i16 y - int4 Int32 Int32 I32Array i32 i32 y - int8 Int64 Int64 I64Array i64 i64 y - int256 Int256 Int256 Int256Array Int256 Int256Ref<'_> _ - float4 Float32 Float32 F32Array F32 F32 y - float8 Float64 Float64 F64Array F64 F64 y - decimal Decimal Decimal DecimalArray Decimal Decimal y - serial Serial Serial SerialArray Serial Serial y - date Date Date DateArray Date Date y - time Time Time TimeArray Time Time y - timestamp Timestamp Timestamp TimestampArray Timestamp Timestamp y - timestamptz Timestamptz Timestamptz TimestamptzArray Timestamptz Timestamptz y - interval Interval Interval IntervalArray Interval Interval y - varchar Varchar Utf8 Utf8Array Box &str _ - bytea Bytea Bytea BytesArray Box<[u8]> &[u8] _ - jsonb Jsonb Jsonb JsonbArray JsonbVal JsonbRef<'_> _ - list List List ListArray ListValue ListRef<'_> _ - struct Struct Struct StructArray StructValue StructRef<'_> _ + boolean Boolean BoolArray bool bool _ + int2 Int16 I16Array i16 i16 y + int4 Int32 I32Array i32 i32 y + int8 Int64 I64Array i64 i64 y + int256 Int256 Int256Array Int256 Int256Ref<'_> _ + float4 Float32 F32Array F32 F32 y + float8 Float64 F64Array F64 F64 y + decimal Decimal DecimalArray Decimal Decimal y + serial Serial SerialArray Serial Serial y + date Date DateArray Date Date y + time Time TimeArray Time Time y + timestamp Timestamp TimestampArray Timestamp Timestamp y + timestamptz Timestamptz TimestamptzArray Timestamptz Timestamptz y + interval Interval IntervalArray Interval Interval y + varchar Varchar Utf8Array Box &str _ + bytea Bytea BytesArray Box<[u8]> &[u8] _ + jsonb Jsonb JsonbArray JsonbVal JsonbRef<'_> _ + anyarray List ListArray ListValue ListRef<'_> _ + struct Struct StructArray StructValue StructRef<'_> _ + any ??? ArrayImpl ScalarImpl ScalarRefImpl<'_> _ "; /// Maps a data type to its corresponding data type name. pub fn data_type(ty: &str) -> &str { - // XXX: - // For functions that contain `any` type, or `...` variable arguments, - // there are special handlings in the frontend, and the signature won't be accessed. - // So we simply return a placeholder here. - if ty == "any" || ty == "..." { - return "Int32"; - } lookup_matrix(ty, 1) } -/// Maps a data type to its corresponding variant name. -pub fn variant(ty: &str) -> &str { - lookup_matrix(ty, 2) -} - /// Maps a data type to its corresponding array type name. pub fn array_type(ty: &str) -> &str { - if ty == "any" { - return "ArrayImpl"; - } - lookup_matrix(ty, 3) + lookup_matrix(ty, 2) } /// Maps a data type to its corresponding `Scalar` type name. pub fn owned_type(ty: &str) -> &str { - lookup_matrix(ty, 4) + lookup_matrix(ty, 3) } /// Maps a data type to its corresponding `ScalarRef` type name. pub fn ref_type(ty: &str) -> &str { - lookup_matrix(ty, 5) + lookup_matrix(ty, 4) } /// Checks if a data type is primitive. pub fn is_primitive(ty: &str) -> bool { - lookup_matrix(ty, 6) == "y" + lookup_matrix(ty, 5) == "y" } fn lookup_matrix(mut ty: &str, idx: usize) -> &str { if ty.ends_with("[]") { - ty = "list"; + ty = "anyarray"; } else if ty.starts_with("struct") { ty = "struct"; } else if ty == "void" { @@ -105,6 +91,7 @@ pub fn expand_type_wildcard(ty: &str) -> Vec<&str> { .trim() .lines() .map(|l| l.split_whitespace().next().unwrap()) + .filter(|l| *l != "any") .collect(), "*int" => vec!["int2", "int4", "int8"], "*float" => vec!["float4", "float8"], diff --git a/src/frontend/planner_test/tests/testdata/output/struct_query.yaml b/src/frontend/planner_test/tests/testdata/output/struct_query.yaml index 26faf08c19bb3..bb7f8b95cc9ff 100644 --- a/src/frontend/planner_test/tests/testdata/output/struct_query.yaml +++ b/src/frontend/planner_test/tests/testdata/output/struct_query.yaml @@ -274,7 +274,7 @@ Bind error: failed to bind expression: (country + country) Caused by: - Feature is not yet implemented: Add[Struct, Struct] + Feature is not yet implemented: Add[Struct(StructType { field_names: ["address", "zipcode"], field_types: [Varchar, Varchar] }), Struct(StructType { field_names: ["address", "zipcode"], field_types: [Varchar, Varchar] })] Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 create_source: format: plain diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index f5633f95ee332..c9fe56b841290 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -16,7 +16,7 @@ use itertools::Itertools; use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::DataType; use risingwave_expr::aggregate::AggKind; -use risingwave_expr::sig::agg::AGG_FUNC_SIG_MAP; +use risingwave_expr::sig::FUNCTION_REGISTRY; use super::{Expr, ExprImpl, Literal, OrderBy}; use crate::utils::Condition; @@ -70,11 +70,6 @@ impl AggCall { // min/max allowed for all types except for bool and jsonb (#7981) (AggKind::Min | AggKind::Max, [DataType::Jsonb]) => return Err(err()), - // may return list or struct type - (AggKind::Min | AggKind::Max | AggKind::FirstValue | AggKind::LastValue, [input]) => { - input.clone() - } - (AggKind::ArrayAgg, [input]) => List(Box::new(input.clone())), // functions that are rewritten in the frontend and don't exist in the expr crate (AggKind::Avg, [input]) => match input { Int16 | Int32 | Int64 | Decimal => Decimal, @@ -90,21 +85,9 @@ impl AggCall { Float32 | Float64 | Int256 => Float64, _ => return Err(err()), }, - // Ordered-Set Aggregation - (AggKind::PercentileCont, [input]) => match input { - Float64 => Float64, - _ => return Err(err()), - }, - (AggKind::PercentileDisc | AggKind::Mode, [input]) => input.clone(), (AggKind::Grouping, _) => Int32, // other functions are handled by signature map - _ => { - let args = args.iter().map(|t| t.into()).collect::>(); - return match AGG_FUNC_SIG_MAP.get_return_type(agg_kind, &args) { - Some(t) => Ok(t.into()), - None => Err(err()), - }; - } + _ => FUNCTION_REGISTRY.get_return_type(agg_kind, args)?, }) } diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index 71efe8063c7a5..f5e618892fc5e 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -211,15 +211,6 @@ impl FunctionCall { match expr_type { ExprType::Some | ExprType::All => { let return_type = infer_some_all(func_types, &mut inputs)?; - - if return_type != DataType::Boolean { - return Err(ErrorCode::BindError(format!( - "op SOME/ANY/ALL (array) requires operator to yield boolean, but got {:?}", - return_type - )) - .into()); - } - Ok(FunctionCall::new_unchecked(expr_type, inputs, return_type).into()) } ExprType::Not | ExprType::IsNotNull | ExprType::IsNull => Ok(FunctionCall::new( diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 0af30ccb364f8..f7dc01e2eef35 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::iter::once; - use enum_as_inner::EnumAsInner; use fixedbitset::FixedBitSet; use futures::FutureExt; @@ -68,8 +66,8 @@ pub use session_timezone::SessionTimezone; pub use subquery::{Subquery, SubqueryKind}; pub use table_function::{TableFunction, TableFunctionType}; pub use type_inference::{ - agg_func_sigs, align_types, cast_map_array, cast_ok, cast_sigs, func_sigs, infer_some_all, - infer_type, least_restrictive, AggFuncSig, CastContext, CastSig, FuncSign, + align_types, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, least_restrictive, + CastContext, CastSig, FuncSign, }; pub use user_defined_function::UserDefinedFunction; pub use utils::*; @@ -202,7 +200,7 @@ impl ExprImpl { /// # Panics /// Panics if `input_ref >= input_col_num`. pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet { - collect_input_refs(input_col_num, once(self)) + collect_input_refs(input_col_num, [self]) } /// Check if the expression has no side effects and output is deterministic @@ -254,6 +252,11 @@ impl ExprImpl { FunctionCall::cast_mut(self, target, CastContext::Implicit) } + /// Shorthand to inplace cast expr to `target` type in explicit context. + pub fn cast_explicit_mut(&mut self, target: DataType) -> Result<(), CastError> { + FunctionCall::cast_mut(self, target, CastContext::Explicit) + } + /// Ensure the return type of this expression is an array of some type. pub fn ensure_array_type(&self) -> Result<(), ErrorCode> { if self.is_untyped() { diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index 6b01ca2bc98cb..dfb028d605705 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -15,9 +15,8 @@ use std::sync::Arc; use itertools::Itertools; -use risingwave_common::error::ErrorCode; use risingwave_common::types::DataType; -use risingwave_expr::sig::table_function::FUNC_SIG_MAP; +use risingwave_expr::sig::FUNCTION_REGISTRY; pub use risingwave_pb::expr::table_function::PbType as TableFunctionType; use risingwave_pb::expr::{ TableFunction as TableFunctionPb, UserDefinedTableFunction as UserDefinedTableFunctionPb, @@ -44,20 +43,10 @@ impl TableFunction { /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of /// `inputs`. pub fn new(func_type: TableFunctionType, args: Vec) -> RwResult { - let arg_types = args.iter().map(|c| c.return_type()).collect_vec(); - let signature = FUNC_SIG_MAP - .get( - func_type, - &args.iter().map(|c| c.return_type().into()).collect_vec(), - ) - .ok_or_else(|| { - ErrorCode::BindError(format!( - "table function not found: {:?}({})", - func_type, - arg_types.iter().map(|t| format!("{:?}", t)).join(", "), - )) - })?; - let return_type = (signature.type_infer)(&arg_types)?; + let return_type = FUNCTION_REGISTRY.get_return_type( + func_type, + &args.iter().map(|c| c.return_type()).collect_vec(), + )?; Ok(TableFunction { args, return_type, diff --git a/src/frontend/src/expr/type_inference/cast.rs b/src/frontend/src/expr/type_inference/cast.rs index b7e3749236f9f..b941732a2a720 100644 --- a/src/frontend/src/expr/type_inference/cast.rs +++ b/src/frontend/src/expr/type_inference/cast.rs @@ -111,11 +111,13 @@ pub fn align_array_and_element( pub fn cast_ok(source: &DataType, target: &DataType, allows: CastContext) -> bool { cast_ok_struct(source, target, allows) || cast_ok_array(source, target, allows) - || cast_ok_base(source.into(), target.into(), allows) + || cast_ok_base(source, target, allows) } -pub fn cast_ok_base(source: DataTypeName, target: DataTypeName, allows: CastContext) -> bool { - matches!(CAST_MAP.get(&(source, target)), Some(context) if *context <= allows) +/// Checks whether casting from `source` to `target` is ok in `allows` context. +/// Both `source` and `target` must be base types, i.e. not struct or array. +pub fn cast_ok_base(source: &DataType, target: &DataType, allows: CastContext) -> bool { + matches!(CAST_MAP.get(&(source.into(), target.into())), Some(context) if *context <= allows) } fn cast_ok_struct(source: &DataType, target: &DataType, allows: CastContext) -> bool { diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 69be8376b46ca..2e7eebf42362f 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -14,10 +14,10 @@ use itertools::Itertools as _; use num_integer::Integer as _; -use risingwave_common::error::{ErrorCode, Result, RwError}; -use risingwave_common::types::{DataType, DataTypeName, StructType}; +use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::types::{DataType, StructType}; use risingwave_common::util::iter_util::ZipEqFast; -pub use risingwave_expr::sig::func::*; +pub use risingwave_expr::sig::*; use super::{align_types, cast_ok_base, CastContext}; use crate::expr::type_inference::cast::align_array_and_element; @@ -27,7 +27,7 @@ use crate::expr::{cast_ok, is_row_function, Expr as _, ExprImpl, ExprType, Funct /// is not supported on backend. /// /// It also mutates the `inputs` by adding necessary casts. -pub fn infer_type(func_type: ExprType, inputs: &mut Vec) -> Result { +pub fn infer_type(func_type: ExprType, inputs: &mut [ExprImpl]) -> Result { if let Some(res) = infer_type_for_special(func_type, inputs).transpose() { return res; } @@ -36,30 +36,28 @@ pub fn infer_type(func_type: ExprType, inputs: &mut Vec) -> Result None, - false => Some(e.return_type().into()), + false => Some(e.return_type()), }) .collect_vec(); - let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?; - let inputs_owned = std::mem::take(inputs); - *inputs = inputs_owned - .into_iter() - .zip_eq_fast(sig.inputs_type) - .map(|(expr, t)| { - if expr.is_untyped() || DataTypeName::from(expr.return_type()) != *t { - if t.is_scalar() { - return expr.cast_implicit((*t).into()).map_err(Into::into); - } else { - return Err(ErrorCode::BindError(format!( - "Cannot implicitly cast '{:?}' to polymorphic type {:?}", - &expr, t - )) - .into()); - } + let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; + + // add implicit casts to inputs + for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) { + if expr.is_untyped() || !t.matches(&expr.return_type()) { + if let SigDataType::Exact(t) = t { + expr.cast_implicit_mut(t.clone())?; + } else { + return Err(ErrorCode::BindError(format!( + "Cannot implicitly cast '{expr:?}' to polymorphic type {t:?}", + )) + .into()); } - Ok(expr) - }) - .try_collect::<_, _, RwError>()?; - Ok(sig.ret_type.into()) + } + } + + let input_types = inputs.iter().map(|expr| expr.return_type()).collect_vec(); + let return_type = (sig.type_infer)(&input_types)?; + Ok(return_type) } pub fn infer_some_all( @@ -69,7 +67,7 @@ pub fn infer_some_all( let element_type = if inputs[1].is_untyped() { None } else if let DataType::List(datatype) = inputs[1].return_type() { - Some(DataTypeName::from(*datatype)) + Some(*datatype) } else { return Err(ErrorCode::BindError( "op SOME/ANY/ALL (array) requires array on right side".to_string(), @@ -79,37 +77,38 @@ pub fn infer_some_all( let final_type = func_types.pop().unwrap(); let actuals = vec![ - (!inputs[0].is_untyped()).then_some(inputs[0].return_type().into()), - element_type, + (!inputs[0].is_untyped()).then_some(inputs[0].return_type()), + element_type.clone(), ]; - let sig = infer_type_name(&FUNC_SIG_MAP, final_type, &actuals)?; - if DataTypeName::from(inputs[0].return_type()) != sig.inputs_type[0] { - if matches!( - sig.inputs_type[0], - DataTypeName::List | DataTypeName::Struct - ) { + let sig = infer_type_name(&FUNCTION_REGISTRY, final_type, &actuals)?; + if sig.ret_type != DataType::Boolean.into() { + return Err(ErrorCode::BindError(format!( + "op SOME/ANY/ALL (array) requires operator to yield boolean, but got {}", + sig.ret_type + )) + .into()); + } + if !sig.inputs_type[0].matches(&inputs[0].return_type()) { + let SigDataType::Exact(t) = &sig.inputs_type[0] else { return Err(ErrorCode::BindError( "array of array/struct on right are not supported yet".into(), ) .into()); - } - inputs[0].cast_implicit_mut(sig.inputs_type[0].into())?; + }; + inputs[0].cast_implicit_mut(t.clone())?; } - if element_type != Some(sig.inputs_type[1]) { - if matches!( - sig.inputs_type[1], - DataTypeName::List | DataTypeName::Struct - ) { + if !matches!(&element_type, Some(e) if sig.inputs_type[1].matches(e)) { + let SigDataType::Exact(t) = &sig.inputs_type[1] else { return Err( ErrorCode::BindError("array/struct on left are not supported yet".into()).into(), ); - } - inputs[1].cast_implicit_mut(DataType::List(Box::new(sig.inputs_type[1].into())))?; + }; + inputs[1].cast_implicit_mut(DataType::List(Box::new(t.clone())))?; } let inputs_owned = std::mem::take(inputs); let mut func_call = - FunctionCall::new_unchecked(final_type, inputs_owned, sig.ret_type.into()).into(); + FunctionCall::new_unchecked(final_type, inputs_owned, DataType::Boolean).into(); while let Some(func_type) = func_types.pop() { func_call = FunctionCall::new(func_type, vec![func_call])?.into(); } @@ -273,17 +272,17 @@ fn infer_struct_cast_target_type( (NestedType::Infer(l), NestedType::Infer(r)) => { // Both sides are *unknown*, using the sig_map to infer the return type. let actuals = vec![None, None]; - let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; Ok(( sig.ret_type != l.into(), sig.ret_type != r.into(), - sig.ret_type.into(), + sig.ret_type.as_exact().clone(), )) } } } -/// Special exprs that cannot be handled by [`infer_type_name`] and [`FuncSigMap`] are handled here. +/// Special exprs that cannot be handled by [`infer_type_name`] and [`FunctionRegistry`] are handled here. /// These include variadic functions, list and struct type, as well as non-implicit cast. /// /// We should aim for enhancing the general inferring framework and reduce the special cases here. @@ -294,7 +293,7 @@ fn infer_struct_cast_target_type( /// * `Ok(None)` when no special rule matches and it should try general rules later fn infer_type_for_special( func_type: ExprType, - inputs: &mut Vec, + inputs: &mut [ExprImpl], ) -> Result> { match func_type { ExprType::Case => { @@ -321,55 +320,31 @@ fn infer_type_for_special( } ExprType::ConcatWs => { ensure_arity!("concat_ws", 2 <= | inputs |); - let inputs_owned = std::mem::take(inputs); - *inputs = inputs_owned - .into_iter() - .enumerate() - .map(|(i, input)| match i { - // 0-th arg must be string - 0 => input.cast_implicit(DataType::Varchar).map_err(Into::into), - // subsequent can be any type, using the output format - _ => input.cast_output(), - }) - .try_collect()?; + // 0-th arg must be string + inputs[0].cast_implicit_mut(DataType::Varchar)?; + for input in inputs.iter_mut().skip(1) { + // subsequent can be any type, using the output format + let owned = input.take(); + *input = owned.cast_output()?; + } Ok(Some(DataType::Varchar)) } ExprType::ConcatOp => { - let inputs_owned = std::mem::take(inputs); - *inputs = inputs_owned - .into_iter() - .map(|input| input.cast_explicit(DataType::Varchar)) - .try_collect()?; + for input in inputs { + input.cast_explicit_mut(DataType::Varchar)?; + } Ok(Some(DataType::Varchar)) } ExprType::Format => { ensure_arity!("format", 1 <= | inputs |); - let inputs_owned = std::mem::take(inputs); - *inputs = inputs_owned - .into_iter() - .enumerate() - .map(|(i, input)| match i { - // 0-th arg must be string - 0 => input.cast_implicit(DataType::Varchar).map_err(Into::into), - // subsequent can be any type, using the output format - _ => input.cast_output(), - }) - .try_collect()?; - Ok(Some(DataType::Varchar)) - } - ExprType::IsNotNull => { - ensure_arity!("is_not_null", | inputs | == 1); - match inputs[0].return_type() { - DataType::Struct(_) | DataType::List { .. } => Ok(Some(DataType::Boolean)), - _ => Ok(None), - } - } - ExprType::IsNull => { - ensure_arity!("is_null", | inputs | == 1); - match inputs[0].return_type() { - DataType::Struct(_) | DataType::List { .. } => Ok(Some(DataType::Boolean)), - _ => Ok(None), + // 0-th arg must be string + inputs[0].cast_implicit_mut(DataType::Varchar)?; + for input in inputs.iter_mut().skip(1) { + // subsequent can be any type, using the output format + let owned = input.take(); + *input = owned.cast_output()?; } + Ok(Some(DataType::Varchar)) } ExprType::Equal | ExprType::NotEqual @@ -431,15 +406,6 @@ fn infer_type_for_special( .into()) } } - ExprType::RegexpMatch => { - ensure_arity!("regexp_match", 2 <= | inputs | <= 3); - inputs[0].cast_implicit_mut(DataType::Varchar)?; - inputs[1].cast_implicit_mut(DataType::Varchar)?; - if let Some(flags) = inputs.get_mut(2) { - flags.cast_implicit_mut(DataType::Varchar)?; - } - Ok(Some(DataType::List(Box::new(DataType::Varchar)))) - } ExprType::ArrayCat => { ensure_arity!("array_cat", | inputs | == 2); let left_type = (!inputs[0].is_untyped()).then(|| inputs[0].return_type()); @@ -451,11 +417,9 @@ fn infer_type_for_special( // when neither type is available, default to `varchar[]` // when one side is unknown and other side is list, use that list type let t = t.unwrap_or_else(|| DataType::List(DataType::Varchar.into())); - let inputs_owned = std::mem::take(inputs); - *inputs = inputs_owned - .into_iter() - .map(|e| e.cast_implicit(t.clone())) - .try_collect()?; + for input in &mut *inputs { + input.cast_implicit_mut(t.clone())?; + } Some(t) } (Some(DataType::List(_)), Some(DataType::List(_))) => { @@ -557,24 +521,6 @@ fn infer_type_for_special( .into()), } } - ExprType::ArrayDistinct => { - ensure_arity!("array_distinct", | inputs | == 1); - inputs[0].ensure_array_type()?; - - Ok(Some(inputs[0].return_type())) - } - ExprType::ArrayMin => { - ensure_arity!("array_min", | inputs | == 1); - inputs[0].ensure_array_type()?; - - Ok(Some(inputs[0].return_type().as_list().clone())) - } - ExprType::ArraySort => { - ensure_arity!("array_sort", | inputs | == 1); - inputs[0].ensure_array_type()?; - - Ok(Some(inputs[0].return_type())) - } ExprType::ArrayDims => { ensure_arity!("array_dims", | inputs | == 1); inputs[0].ensure_array_type()?; @@ -587,44 +533,6 @@ fn infer_type_for_special( } Ok(Some(DataType::Varchar)) } - ExprType::ArrayMax => { - ensure_arity!("array_max", | inputs | == 1); - inputs[0].ensure_array_type()?; - - Ok(Some(inputs[0].return_type().as_list().clone())) - } - ExprType::ArraySum => { - ensure_arity!("array_sum", | inputs | == 1); - inputs[0].ensure_array_type()?; - - let return_type = match inputs[0].return_type().as_list().clone() { - DataType::Int16 | DataType::Int32 => DataType::Int64, - DataType::Int64 | DataType::Decimal => DataType::Decimal, - DataType::Float32 => DataType::Float32, - DataType::Float64 => DataType::Float64, - DataType::Interval => DataType::Interval, - _ => return Err(ErrorCode::InvalidParameterValue("".to_string()).into()), - }; - - Ok(Some(return_type)) - } - ExprType::StringToArray => { - ensure_arity!("string_to_array", 2 <= | inputs | <= 3); - - if !inputs.iter().all(|e| e.return_type() == DataType::Varchar) { - return Ok(None); - } - - Ok(Some(DataType::List(Box::new(DataType::Varchar)))) - } - ExprType::TrimArray => { - ensure_arity!("trim_array", | inputs | == 2); - inputs[0].ensure_array_type()?; - - inputs[1].cast_implicit_mut(DataType::Int32)?; - - Ok(Some(inputs[0].return_type())) - } ExprType::Vnode => { ensure_arity!("vnode", 1 <= | inputs |); Ok(Some(DataType::Int16)) @@ -653,22 +561,24 @@ fn infer_type_for_special( /// 5. Attempt to narrow down candidates by assuming all arguments are same type. This covers Rule /// 4f in `PostgreSQL`. See [`narrow_same_type`] for details. fn infer_type_name<'a>( - sig_map: &'a FuncSigMap, + sig_map: &'a FunctionRegistry, func_type: ExprType, - inputs: &[Option], + inputs: &[Option], ) -> Result<&'a FuncSign> { let candidates = sig_map.get_with_arg_nums(func_type, inputs.len()); // Binary operators have a special `unknown` handling rule for exact match. We do not // distinguish operators from functions as of now. if inputs.len() == 2 { - let t = match (inputs[0], inputs[1]) { + let t = match (&inputs[0], &inputs[1]) { (None, t) => Ok(t), (t, None) => Ok(t), (Some(_), Some(_)) => Err(()), }; if let Ok(Some(t)) = t { - let exact = candidates.iter().find(|sig| sig.inputs_type == [t, t]); + let exact = candidates + .iter() + .find(|sig| sig.inputs_type[0].matches(t) && sig.inputs_type[1].matches(t)); if let Some(sig) = exact { return Ok(sig); } @@ -710,11 +620,11 @@ fn infer_type_name<'a>( /// Checks if `t` is a preferred type in any type category, as defined by `PostgreSQL`: /// . -fn is_preferred(t: DataTypeName) -> bool { - use DataTypeName as T; +fn is_preferred(t: &SigDataType) -> bool { + use DataType as T; matches!( t, - T::Float64 | T::Boolean | T::Varchar | T::Timestamptz | T::Interval + SigDataType::Exact(T::Float64 | T::Boolean | T::Varchar | T::Timestamptz | T::Interval) ) } @@ -724,8 +634,9 @@ fn is_preferred(t: DataTypeName) -> bool { /// /// Sometimes it is more convenient to include equality when checking whether a formal parameter can /// accept an actual argument. So we introduced `eq_ok` to control this behavior. -fn implicit_ok(source: DataTypeName, target: DataTypeName, eq_ok: bool) -> bool { - eq_ok && source == target || cast_ok_base(source, target, CastContext::Implicit) +fn implicit_ok(source: &DataType, target: &SigDataType, eq_ok: bool) -> bool { + eq_ok && target.matches(source) + || target.is_exact() && cast_ok_base(source, target.as_exact(), CastContext::Implicit) } /// Find the top `candidates` that match `inputs` on most non-null positions. This covers Rule 2, @@ -754,10 +665,7 @@ fn implicit_ok(source: DataTypeName, target: DataTypeName, eq_ok: bool) -> bool /// [rule 4a src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L907-L947 /// [rule 4c src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L1062-L1104 /// [rule 4d src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L1106-L1153 -fn top_matches<'a>( - candidates: &[&'a FuncSign], - inputs: &[Option], -) -> Vec<&'a FuncSign> { +fn top_matches<'a>(candidates: &[&'a FuncSign], inputs: &[Option]) -> Vec<&'a FuncSign> { let mut best_exact = 0; let mut best_preferred = 0; let mut best_candidates = Vec::new(); @@ -768,13 +676,13 @@ fn top_matches<'a>( let mut castable = true; for (formal, actual) in sig.inputs_type.iter().zip_eq_fast(inputs) { let Some(actual) = actual else { continue }; - if formal == actual { + if formal.matches(actual) { n_exact += 1; - } else if !implicit_ok(*actual, *formal, false) { + } else if !implicit_ok(actual, formal, false) { castable = false; break; } - if is_preferred(*formal) { + if is_preferred(formal) { n_preferred += 1; } } @@ -816,9 +724,9 @@ fn top_matches<'a>( /// [rule 4e src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L1164-L1298 fn narrow_category<'a>( candidates: Vec<&'a FuncSign>, - inputs: &[Option], + inputs: &[Option], ) -> Vec<&'a FuncSign> { - const BIASED_TYPE: DataTypeName = DataTypeName::Varchar; + const BIASED_TYPE: SigDataType = SigDataType::Exact(DataType::Varchar); let Ok(categories) = inputs .iter() .enumerate() @@ -835,21 +743,21 @@ fn narrow_category<'a>( if actual.is_some() { return Ok(None); } - let mut category = Ok(candidates[0].inputs_type[i]); + let mut category = Ok(&candidates[0].inputs_type[i]); for sig in &candidates[1..] { - let formal = sig.inputs_type[i]; - if formal == BIASED_TYPE || category == Ok(BIASED_TYPE) { - category = Ok(BIASED_TYPE); + let formal = &sig.inputs_type[i]; + if formal == &BIASED_TYPE || category == Ok(&BIASED_TYPE) { + category = Ok(&BIASED_TYPE); break; } // formal != BIASED_TYPE && category.is_err(): // - Category conflict err can only be solved by a later varchar. Skip this // candidate. - let Ok(selected) = category else { continue }; + let Ok(selected) = &category else { continue }; // least_restrictive or mark temporary conflict err - if implicit_ok(formal, selected, true) { + if formal.is_exact() && implicit_ok(formal.as_exact(), selected, true) { // noop - } else if implicit_ok(selected, formal, false) { + } else if selected.is_exact() && implicit_ok(selected.as_exact(), formal, false) { category = Ok(formal); } else { category = Err(()); @@ -874,8 +782,10 @@ fn narrow_category<'a>( let Some(selected) = category else { return true; }; - *formal == *selected - || !is_preferred(*selected) && implicit_ok(*formal, *selected, false) + formal == *selected + || !is_preferred(selected) + && formal.is_exact() + && implicit_ok(formal.as_exact(), selected, false) }) }) .copied() @@ -907,12 +817,12 @@ fn narrow_category<'a>( /// [Rule 2]: https://www.postgresql.org/docs/current/typeconv-oper.html#:~:text=then%20assume%20it%20is%20the%20same%20type%20as%20the%20other%20argument%20for%20this%20check fn narrow_same_type<'a>( candidates: Vec<&'a FuncSign>, - inputs: &[Option], + inputs: &[Option], ) -> Vec<&'a FuncSign> { let Ok(Some(same_type)) = inputs.iter().try_fold(None, |acc, cur| match (acc, cur) { - (None, t) => Ok(*t), + (None, t) => Ok(t.as_ref()), (t, None) => Ok(t), - (Some(l), Some(r)) if l == *r => Ok(Some(l)), + (Some(l), Some(r)) if l == r => Ok(Some(l)), _ => Err(()), }) else { return candidates; @@ -922,7 +832,7 @@ fn narrow_same_type<'a>( .filter(|sig| { sig.inputs_type .iter() - .all(|formal| implicit_ok(same_type, *formal, true)) + .all(|formal| implicit_ok(same_type, formal, true)) }) .copied() .collect_vec(); @@ -932,7 +842,7 @@ fn narrow_same_type<'a>( } } -struct TypeDebug<'a>(&'a Option); +struct TypeDebug<'a>(&'a Option); impl<'a> std::fmt::Debug for TypeDebug<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self.0 { @@ -967,7 +877,7 @@ mod tests { ) .into() }) - .collect(); + .collect_vec(); infer_type(func_type, &mut inputs) } @@ -1108,10 +1018,7 @@ mod tests { #[test] fn test_match_implicit() { - use DataTypeName as T; - // func_name and ret_type does not affect the overload resolution logic - const DUMMY_FUNC: ExprType = ExprType::Add; - const DUMMY_RET: T = T::Int32; + use DataType as T; let testcases = [ ( "Binary special rule prefers arguments of same type.", @@ -1213,21 +1120,28 @@ mod tests { ), ]; for (desc, candidates, inputs, expected) in testcases { - let mut sig_map = FuncSigMap::default(); + let mut sig_map = FunctionRegistry::default(); for formals in candidates { sig_map.insert(FuncSign { - func: DUMMY_FUNC, - inputs_type: formals, + // func_name does not affect the overload resolution logic + name: ExprType::Add.into(), + inputs_type: formals.iter().map(|t| t.clone().into()).collect(), variadic: false, - ret_type: DUMMY_RET, - build: |_, _| unreachable!(), + // ret_type does not affect the overload resolution logic + ret_type: T::Int32.into(), + build: FuncBuilder::Scalar(|_, _| unreachable!()), + type_infer: |_| unreachable!(), deprecated: false, + state_type: None, + append_only: false, }); } - let result = infer_type_name(&sig_map, DUMMY_FUNC, inputs); + let result = infer_type_name(&sig_map, ExprType::Add, inputs); match (expected, result) { (Ok(expected), Ok(found)) => { - assert_eq!(expected, found.inputs_type, "case `{}`", desc) + if !found.match_args(expected) { + panic!("case `{}` expect {:?} != found {:?}", desc, expected, found) + } } (Ok(_), Err(err)) => panic!("case `{}` unexpected error: {:?}", desc, err), (Err(_), Ok(f)) => panic!( diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 8135787ea7cdd..08007b14a2751 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -21,5 +21,4 @@ pub use cast::{ align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, least_restrictive, CastContext, CastSig, }; -pub use func::{func_sigs, infer_some_all, infer_type, FuncSign}; -pub use risingwave_expr::sig::agg::{agg_func_sigs, AggFuncSig}; +pub use func::{infer_some_all, infer_type, FuncSign}; diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index e02c99858b7ce..97bd1435fb093 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -24,7 +24,7 @@ use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; use risingwave_common::util::value_encoding::DatumToProtoExt; use risingwave_expr::aggregate::{agg_kinds, AggKind}; -use risingwave_expr::sig::agg::AGG_FUNC_SIG_MAP; +use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as AggCallStatePb}; @@ -488,15 +488,15 @@ impl Agg { .iter() .zip_eq_fast(&mut out_fields[self.group_key.len()..]) { - let sig = AGG_FUNC_SIG_MAP - .get( + let sig = FUNCTION_REGISTRY + .get_aggregate( agg_call.agg_kind, &agg_call .inputs .iter() - .map(|input| (&input.data_type).into()) + .map(|input| input.data_type.clone()) .collect_vec(), - (&agg_call.return_type).into(), + &agg_call.return_type, in_append_only, ) .expect("agg not found"); @@ -505,7 +505,10 @@ impl Agg { // for backward compatibility, the state type is same as the return type. // its values in the intermediate state table are always null. } else { - field.data_type = sig.state_type.into(); + field.data_type = sig + .state_type + .clone() + .unwrap_or(sig.ret_type.as_exact().clone()); } } let in_dist_key = self.input.distribution().dist_column_indices().to_vec(); diff --git a/src/tests/sqlsmith/src/sql_gen/agg.rs b/src/tests/sqlsmith/src/sql_gen/agg.rs index f26c37b4bd619..c42eb6c7b0ffc 100644 --- a/src/tests/sqlsmith/src/sql_gen/agg.rs +++ b/src/tests/sqlsmith/src/sql_gen/agg.rs @@ -16,6 +16,7 @@ use rand::seq::SliceRandom; use rand::Rng; use risingwave_common::types::DataType; use risingwave_expr::aggregate::AggKind; +use risingwave_expr::sig::SigDataType; use risingwave_sqlparser::ast::{ Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, OrderByExpr, }; @@ -30,13 +31,12 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { Some(funcs) => funcs, }; let func = funcs.choose(&mut self.rng).unwrap(); - if matches!( - (func.func, func.inputs_type.as_slice()), - ( - AggKind::Min | AggKind::Max, - [DataType::Boolean | DataType::Jsonb] + if matches!(func.name.as_aggregate(), AggKind::Min | AggKind::Max) + && matches!( + func.ret_type, + SigDataType::Exact(DataType::Boolean | DataType::Jsonb) ) - ) { + { return self.gen_simple_scalar(ret); } @@ -45,13 +45,13 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { let exprs: Vec = func .inputs_type .iter() - .map(|t| self.gen_expr(t, context)) + .map(|t| self.gen_expr(t.as_exact(), context)) .collect(); // DISTINCT now only works with agg kinds except `ApproxCountDistinct`, and with at least // one argument and only the first being non-constant. See `Binder::bind_normal_agg` // for more details. - let distinct_allowed = func.func != AggKind::ApproxCountDistinct + let distinct_allowed = func.name.as_aggregate() != AggKind::ApproxCountDistinct && !exprs.is_empty() && exprs.iter().skip(1).all(|e| matches!(e, Expr::Value(_))); let distinct = distinct_allowed && self.flip_coin(); @@ -79,7 +79,7 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { } else { vec![] }; - self.make_agg_expr(func.func, &exprs, distinct, filter, order_by) + self.make_agg_expr(func.name.as_aggregate(), &exprs, distinct, filter, order_by) .unwrap_or_else(|| self.gen_simple_scalar(ret)) } diff --git a/src/tests/sqlsmith/src/sql_gen/expr.rs b/src/tests/sqlsmith/src/sql_gen/expr.rs index f9772c97d4b5c..9999dcd9ea641 100644 --- a/src/tests/sqlsmith/src/sql_gen/expr.rs +++ b/src/tests/sqlsmith/src/sql_gen/expr.rs @@ -16,7 +16,8 @@ use itertools::Itertools; use rand::seq::SliceRandom; use rand::Rng; use risingwave_common::types::{DataType, DataTypeName, StructType}; -use risingwave_frontend::expr::{agg_func_sigs, cast_sigs, func_sigs}; +use risingwave_expr::sig::cast::cast_sigs; +use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_sqlparser::ast::{Expr, Ident, OrderByExpr, Value}; use crate::sql_gen::types::data_type_to_ast_data_type; @@ -302,29 +303,25 @@ pub(crate) fn sql_null() -> Expr { // Add variadic function signatures. Can add these functions // to a FUNC_TABLE too. pub fn print_function_table() -> String { - let func_str = func_sigs() + let func_str = FUNCTION_REGISTRY + .iter_scalars() .map(|sign| { format!( - "{:?}({}) -> {:?}", - sign.func, - sign.inputs_type - .iter() - .map(|arg| format!("{:?}", arg)) - .join(", "), + "{}({}) -> {}", + sign.name, + sign.inputs_type.iter().format(", "), sign.ret_type, ) }) .join("\n"); - let agg_func_str = agg_func_sigs() + let agg_func_str = FUNCTION_REGISTRY + .iter_aggregates() .map(|sign| { format!( - "{:?}({}) -> {:?}", - sign.func, - sign.inputs_type - .iter() - .map(|arg| format!("{:?}", arg)) - .join(", "), + "{}({}) -> {}", + sign.name, + sign.inputs_type.iter().format(", "), sign.ret_type, ) }) diff --git a/src/tests/sqlsmith/src/sql_gen/functions.rs b/src/tests/sqlsmith/src/sql_gen/functions.rs index 6af491bd8a64d..3583b820f1204 100644 --- a/src/tests/sqlsmith/src/sql_gen/functions.rs +++ b/src/tests/sqlsmith/src/sql_gen/functions.rs @@ -127,30 +127,30 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { Some(funcs) => funcs, }; let func = funcs.choose(&mut self.rng).unwrap(); - let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.func); + let can_implicit_cast = INVARIANT_FUNC_SET.contains(&func.name.as_scalar()); let exprs: Vec = func .inputs_type .iter() .map(|t| { - if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t) + if let Some(from_tys) = IMPLICIT_CAST_TABLE.get(t.as_exact()) && can_implicit_cast && self.flip_coin() { let from_ty = &from_tys.choose(&mut self.rng).unwrap().from_type; self.gen_implicit_cast(from_ty, context) } else { - self.gen_expr(t, context) + self.gen_expr(t.as_exact(), context) } }) .collect(); let expr = if exprs.len() == 1 { - make_unary_op(func.func, &exprs[0]) + make_unary_op(func.name.as_scalar(), &exprs[0]) } else if exprs.len() == 2 { - make_bin_op(func.func, &exprs) + make_bin_op(func.name.as_scalar(), &exprs) } else { None }; - expr.or_else(|| make_general_expr(func.func, exprs)) + expr.or_else(|| make_general_expr(func.name.as_scalar(), exprs)) .unwrap_or_else(|| self.gen_simple_scalar(ret)) } } diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs index 57e6a64f0e93e..ea3c00e45e1da 100644 --- a/src/tests/sqlsmith/src/sql_gen/types.rs +++ b/src/tests/sqlsmith/src/sql_gen/types.rs @@ -20,9 +20,8 @@ use std::sync::LazyLock; use itertools::Itertools; use risingwave_common::types::{DataType, DataTypeName}; use risingwave_expr::aggregate::AggKind; -use risingwave_expr::sig::agg::{agg_func_sigs, AggFuncSig as RwAggFuncSig}; use risingwave_expr::sig::cast::{cast_sigs, CastContext, CastSig as RwCastSig}; -use risingwave_expr::sig::func::{func_sigs, FuncSign as RwFuncSig}; +use risingwave_expr::sig::{FuncSign, FUNCTION_REGISTRY}; use risingwave_frontend::expr::ExprType; use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField}; @@ -104,66 +103,6 @@ impl TryFrom for CastSig { } } -/// Provide internal `FuncSig` which can be used for `struct` and `list`. -#[derive(Clone)] -pub struct FuncSig { - pub func: ExprType, - pub inputs_type: Vec, - pub ret_type: DataType, -} - -impl TryFrom<&RwFuncSig> for FuncSig { - type Error = String; - - fn try_from(value: &RwFuncSig) -> Result { - if let Some(inputs_type) = value - .inputs_type - .iter() - .map(data_type_name_to_ast_data_type) - .collect() - && let Some(ret_type) = data_type_name_to_ast_data_type(&value.ret_type) - { - Ok(FuncSig { - inputs_type, - ret_type, - func: value.func, - }) - } else { - Err(format!("unsupported func sig: {:?}", value)) - } - } -} - -/// Provide internal `AggFuncSig` which can be used for `struct` and `list`. -#[derive(Clone)] -pub struct AggFuncSig { - pub func: AggKind, - pub inputs_type: Vec, - pub ret_type: DataType, -} - -impl TryFrom<&RwAggFuncSig> for AggFuncSig { - type Error = String; - - fn try_from(value: &RwAggFuncSig) -> Result { - if let Some(inputs_type) = value - .inputs_type - .iter() - .map(data_type_name_to_ast_data_type) - .collect() - && let Some(ret_type) = data_type_name_to_ast_data_type(&value.ret_type) - { - Ok(AggFuncSig { - inputs_type, - ret_type, - func: value.func, - }) - } else { - Err(format!("unsupported agg_func sig: {:?}", value)) - } - } -} - /// Function ban list. /// These functions should be generated eventually, by adding expression constraints. /// If we naively generate arguments for these functions, it will affect sqlsmith @@ -178,26 +117,35 @@ static FUNC_BAN_LIST: LazyLock> = LazyLock::new(|| { /// Table which maps functions' return types to possible function signatures. // ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 -pub(crate) static FUNC_TABLE: LazyLock>> = LazyLock::new(|| { - let mut funcs = HashMap::>::new(); - func_sigs() - .filter(|func| { - func.inputs_type - .iter() - .all(|t| *t != DataTypeName::Timestamptz) - && !FUNC_BAN_LIST.contains(&func.func) - && !func.deprecated // deprecated functions are not accepted by frontend - }) - .filter_map(|func| func.try_into().ok()) - .for_each(|func: FuncSig| funcs.entry(func.ret_type.clone()).or_default().push(func)); - funcs -}); +pub(crate) static FUNC_TABLE: LazyLock>> = + LazyLock::new(|| { + let mut funcs = HashMap::>::new(); + FUNCTION_REGISTRY + .iter_scalars() + .filter(|func| { + func.inputs_type.iter().all(|t| { + t.is_exact() + && t.as_exact() != &DataType::Timestamptz + && t.as_exact() != &DataType::Serial + }) && func.ret_type.is_exact() + && !FUNC_BAN_LIST.contains(&func.name.as_scalar()) + && !func.deprecated // deprecated functions are not accepted by frontend + }) + .for_each(|func| { + funcs + .entry(func.ret_type.as_exact().clone()) + .or_default() + .push(func) + }); + funcs + }); /// Set of invariant functions // ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 pub(crate) static INVARIANT_FUNC_SET: LazyLock> = LazyLock::new(|| { - func_sigs() - .map(|sig| sig.func) + FUNCTION_REGISTRY + .iter_scalars() + .map(|sig| sig.name.as_scalar()) .counts() .into_iter() .filter(|(_key, count)| *count == 1) @@ -207,14 +155,16 @@ pub(crate) static INVARIANT_FUNC_SET: LazyLock> = LazyLock::ne /// Table which maps aggregate functions' return types to possible function signatures. // ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 -pub(crate) static AGG_FUNC_TABLE: LazyLock>> = LazyLock::new( - || { - let mut funcs = HashMap::>::new(); - agg_func_sigs() +pub(crate) static AGG_FUNC_TABLE: LazyLock>> = + LazyLock::new(|| { + let mut funcs = HashMap::>::new(); + FUNCTION_REGISTRY + .iter_aggregates() .filter(|func| { func.inputs_type .iter() - .all(|t| *t != DataTypeName::Timestamptz) + .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz && t.as_exact() != &DataType::Serial) + && func.ret_type.is_exact() // Ignored functions && ![ AggKind::Sum0, // Used internally @@ -226,25 +176,23 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock>> = AggKind::PercentileDisc, AggKind::Mode, ] - .contains(&func.func) + .contains(&func.name.as_aggregate()) // Exclude 2 phase agg global sum. // Sum(Int64) -> Int64. // Otherwise it conflicts with normal aggregation: // Sum(Int64) -> Decimal. // And sqlsmith will generate expressions with wrong types. - && if func.func == AggKind::Sum { - !(func.inputs_type[0] == DataTypeName::Int64 && func.ret_type == DataTypeName::Int64) + && if func.name.as_aggregate() == AggKind::Sum { + !(func.inputs_type[0].as_exact() == &DataType::Int64 && func.ret_type.as_exact() == &DataType::Int64) } else { true } }) - .filter_map(|func| func.try_into().ok()) - .for_each(|func: AggFuncSig| { - funcs.entry(func.ret_type.clone()).or_default().push(func) + .for_each(|func| { + funcs.entry(func.ret_type.as_exact().clone()).or_default().push(func) }); funcs - }, -); + }); /// Build a cast map from return types to viable cast-signatures. /// NOTE: We avoid cast from varchar to other datatypes apart from itself. @@ -299,28 +247,24 @@ pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock< HashMap<(DataType, DataType), Vec>, > = LazyLock::new(|| { let mut funcs = HashMap::<(DataType, DataType), Vec>::new(); - func_sigs() + FUNCTION_REGISTRY + .iter_scalars() .filter(|func| { - !FUNC_BAN_LIST.contains(&func.func) - && func.ret_type == DataTypeName::Boolean + !FUNC_BAN_LIST.contains(&func.name.as_scalar()) + && func.ret_type == DataType::Boolean.into() && func.inputs_type.len() == 2 && func .inputs_type .iter() - .all(|t| *t != DataTypeName::Timestamptz) + .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz) }) .filter_map(|func| { - let Some(lhs) = data_type_name_to_ast_data_type(&func.inputs_type[0]) else { - return None; - }; - let Some(rhs) = data_type_name_to_ast_data_type(&func.inputs_type[1]) else { - return None; - }; - let args = (lhs, rhs); - let Some(op) = expr_type_to_inequality_op(func.func) else { + let lhs = func.inputs_type[0].as_exact().clone(); + let rhs = func.inputs_type[1].as_exact().clone(); + let Some(op) = expr_type_to_inequality_op(func.name.as_scalar()) else { return None; }; - Some((args, op)) + Some(((lhs, rhs), op)) }) .for_each(|(args, op)| funcs.entry(args).or_default().push(op)); funcs