From 31fccb7aaae1750a70e2ee790c932073d4221d81 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 12 Sep 2023 01:31:20 +0800 Subject: [PATCH] rename function registry Signed-off-by: Runji Wang --- src/expr/benches/expr.rs | 10 ++--- src/expr/src/agg/mod.rs | 2 +- src/expr/src/expr/build.rs | 20 +++++---- src/expr/src/sig/mod.rs | 45 +++++++++++++------- src/expr/src/table_function/mod.rs | 2 +- src/frontend/src/expr/agg_call.rs | 4 +- src/frontend/src/expr/mod.rs | 6 +-- src/frontend/src/expr/table_function.rs | 4 +- src/frontend/src/expr/type_inference/func.rs | 10 ++--- src/frontend/src/expr/type_inference/mod.rs | 2 +- src/tests/sqlsmith/src/sql_gen/expr.rs | 9 ++-- src/tests/sqlsmith/src/sql_gen/types.rs | 25 +++++------ 12 files changed, 71 insertions(+), 68 deletions(-) diff --git a/src/expr/benches/expr.rs b/src/expr/benches/expr.rs index d83a511e5f23b..a159025cb87ac 100644 --- a/src/expr/benches/expr.rs +++ b/src/expr/benches/expr.rs @@ -26,7 +26,7 @@ use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::*; use risingwave_expr::agg::{build as build_agg, AggArgs, AggCall, AggKind}; use risingwave_expr::expr::*; -use risingwave_expr::sig::func_sigs; +use risingwave_expr::sig::{aggregate_functions, scalar_functions}; use risingwave_expr::ExprError; use risingwave_pb::expr::expr_node::PbType; @@ -262,9 +262,7 @@ fn bench_expr(c: &mut Criterion) { .iter(|| extract.eval(&input)) }); - let sigs = func_sigs() - .filter(|s| s.is_scalar()) - .sorted_by_cached_key(|sig| format!("{sig:?}")); + let sigs = scalar_functions().sorted_by_cached_key(|sig| format!("{sig:?}")); 'sig: for sig in sigs { if (sig.inputs_type.iter()) .chain([&sig.ret_type]) @@ -340,9 +338,7 @@ fn bench_expr(c: &mut Criterion) { }); } - let sigs = func_sigs() - .filter(|s| s.is_aggregate()) - .sorted_by_cached_key(|sig| format!("{sig:?}")); + let sigs = aggregate_functions().sorted_by_cached_key(|sig| format!("{sig:?}")); for sig in sigs { if matches!( sig.name.as_aggregate(), diff --git a/src/expr/src/agg/mod.rs b/src/expr/src/agg/mod.rs index e8892f9c4c174..09a5b903fb8e0 100644 --- a/src/expr/src/agg/mod.rs +++ b/src/expr/src/agg/mod.rs @@ -123,7 +123,7 @@ pub type BoxedAggregateFunction = Box; /// NOTE: This function ignores argument indices, `column_orders`, `filter` and `distinct` in /// `AggCall`. Such operations should be done in batch or streaming executors. pub fn build(agg: &AggCall) -> Result { - let desc = crate::sig::FUNC_SIG_MAP + let desc = crate::sig::FUNCTION_REGISTRY .get(agg.kind, agg.args.arg_types(), &agg.return_type) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( diff --git a/src/expr/src/expr/build.rs b/src/expr/src/expr/build.rs index 5c40b237266e1..c83839cb99cfb 100644 --- a/src/expr/src/expr/build.rs +++ b/src/expr/src/expr/build.rs @@ -30,7 +30,7 @@ use super::expr_vnode::VnodeExpression; use crate::expr::{ BoxedExpression, Expression, InputRefExpression, LiteralExpression, TryFromExprNodeBoxed, }; -use crate::sig::FUNC_SIG_MAP; +use crate::sig::FUNCTION_REGISTRY; use crate::{bail, ExprError, Result}; /// Build an expression from protobuf. @@ -82,14 +82,16 @@ pub fn build_func( } let args = children.iter().map(|c| c.return_type()).collect_vec(); - let desc = FUNC_SIG_MAP.get(func, &args, &ret_type).ok_or_else(|| { - ExprError::UnsupportedFunction(format!( - "{}({}) -> {}", - func.as_str_name().to_ascii_lowercase(), - args.iter().format(", "), - ret_type, - )) - })?; + let desc = FUNCTION_REGISTRY + .get(func, &args, &ret_type) + .ok_or_else(|| { + ExprError::UnsupportedFunction(format!( + "{}({}) -> {}", + func.as_str_name().to_ascii_lowercase(), + args.iter().format(", "), + ret_type, + )) + })?; desc.build_scalar(ret_type, children) } diff --git a/src/expr/src/sig/mod.rs b/src/expr/src/sig/mod.rs index 40f935a87b3da..6d7fa0d0514f3 100644 --- a/src/expr/src/sig/mod.rs +++ b/src/expr/src/sig/mod.rs @@ -31,27 +31,40 @@ use crate::ExprError; pub mod cast; -pub static FUNC_SIG_MAP: LazyLock = LazyLock::new(|| unsafe { - let mut map = FuncSigMap::default(); - tracing::info!("{} function signatures loaded.", FUNC_SIG_MAP_INIT.len()); - for desc in FUNC_SIG_MAP_INIT.drain(..) { - map.insert(desc); +/// The global registry of all function signatures. +pub static FUNCTION_REGISTRY: LazyLock = LazyLock::new(|| unsafe { + // SAFETY: this function is called after all `#[ctor]` functions are called. + let mut map = FunctionRegistry::default(); + tracing::info!("found {} functions", FUNCTION_REGISTRY_INIT.len()); + for sig in FUNCTION_REGISTRY_INIT.drain(..) { + map.insert(sig); } map }); -/// The table of function signatures. -pub fn func_sigs() -> impl Iterator { - FUNC_SIG_MAP.0.values().flatten() +/// Returns an iterator of all function signatures. +pub fn all_functions() -> impl Iterator { + FUNCTION_REGISTRY.0.values().flatten() } +/// Returns an iterator of all scalar functions. +pub fn scalar_functions() -> impl Iterator { + all_functions().filter(|d| d.is_scalar()) +} + +/// Returns an iterator of all aggregate functions. +pub fn aggregate_functions() -> impl Iterator { + all_functions().filter(|d| d.is_aggregate()) +} + +/// A set of function signatures. #[derive(Default, Clone, Debug)] -pub struct FuncSigMap(HashMap>); +pub struct FunctionRegistry(HashMap>); -impl FuncSigMap { +impl FunctionRegistry { /// Inserts a function signature. - pub fn insert(&mut self, desc: FuncSign) { - self.0.entry(desc.name).or_default().push(desc) + pub fn insert(&mut self, sig: FuncSign) { + self.0.entry(sig.name).or_default().push(sig) } /// Returns a function signature with the same type, argument types and return type. @@ -359,8 +372,8 @@ pub enum FuncBuilder { /// It is designed to be used by `#[function]` macro. /// Users SHOULD NOT call this function. #[doc(hidden)] -pub unsafe fn _register(desc: FuncSign) { - FUNC_SIG_MAP_INIT.push(desc) +pub unsafe fn _register(sig: FuncSign) { + FUNCTION_REGISTRY_INIT.push(sig) } /// The global registry of function signatures on initialization. @@ -368,7 +381,7 @@ pub unsafe fn _register(desc: FuncSign) { /// `#[function]` macro will generate a `#[ctor]` function to register the signature into this /// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into /// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`. -static mut FUNC_SIG_MAP_INIT: Vec = Vec::new(); +static mut FUNCTION_REGISTRY_INIT: Vec = Vec::new(); #[cfg(test)] mod tests { @@ -383,7 +396,7 @@ mod tests { // convert FUNC_SIG_MAP to a more convenient map for testing let mut new_map: HashMap, Vec>> = HashMap::new(); - for (func, sigs) in &FUNC_SIG_MAP.0 { + for (func, sigs) in &FUNCTION_REGISTRY.0 { for sig in sigs { // validate the FUNC_SIG_MAP is consistent assert_eq!(func, &sig.name); diff --git a/src/expr/src/table_function/mod.rs b/src/expr/src/table_function/mod.rs index 12a5c8cfa8829..8fbfdb9327911 100644 --- a/src/expr/src/table_function/mod.rs +++ b/src/expr/src/table_function/mod.rs @@ -135,7 +135,7 @@ pub fn build( children: Vec, ) -> Result { let args = children.iter().map(|t| t.return_type()).collect_vec(); - let desc = crate::sig::FUNC_SIG_MAP + let desc = crate::sig::FUNCTION_REGISTRY .get(func, &args, &return_type) .ok_or_else(|| { ExprError::UnsupportedFunction(format!( diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 255199c07b430..3291e7ff7878f 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -16,7 +16,7 @@ use itertools::Itertools; use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::DataType; use risingwave_expr::agg::AggKind; -use risingwave_expr::sig::FUNC_SIG_MAP; +use risingwave_expr::sig::FUNCTION_REGISTRY; use super::{Expr, ExprImpl, Literal, OrderBy}; use crate::utils::Condition; @@ -88,7 +88,7 @@ impl AggCall { // Ordered-Set Aggregation (AggKind::Grouping, _) => Int32, // other functions are handled by signature map - _ => FUNC_SIG_MAP.get_return_type(agg_kind, &args)?, + _ => FUNCTION_REGISTRY.get_return_type(agg_kind, &args)?, }) } diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index d4b30e7056211..a1bf248ce40fd 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::iter::once; - use enum_as_inner::EnumAsInner; use fixedbitset::FixedBitSet; use futures::FutureExt; @@ -67,7 +65,7 @@ pub use session_timezone::SessionTimezone; pub use subquery::{Subquery, SubqueryKind}; pub use table_function::{TableFunction, TableFunctionType}; pub use type_inference::{ - align_types, cast_map_array, cast_ok, cast_sigs, func_sigs, infer_some_all, infer_type, + align_types, all_functions, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type, least_restrictive, CastContext, CastSig, FuncSign, }; pub use user_defined_function::UserDefinedFunction; @@ -201,7 +199,7 @@ impl ExprImpl { /// # Panics /// Panics if `input_ref >= input_col_num`. pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet { - collect_input_refs(input_col_num, once(self)) + collect_input_refs(input_col_num, [self]) } /// Check if the expression has no side effects and output is deterministic diff --git a/src/frontend/src/expr/table_function.rs b/src/frontend/src/expr/table_function.rs index fd3ff2aec7439..dfb028d605705 100644 --- a/src/frontend/src/expr/table_function.rs +++ b/src/frontend/src/expr/table_function.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use itertools::Itertools; use risingwave_common::types::DataType; -use risingwave_expr::sig::FUNC_SIG_MAP; +use risingwave_expr::sig::FUNCTION_REGISTRY; pub use risingwave_pb::expr::table_function::PbType as TableFunctionType; use risingwave_pb::expr::{ TableFunction as TableFunctionPb, UserDefinedTableFunction as UserDefinedTableFunctionPb, @@ -43,7 +43,7 @@ impl TableFunction { /// Create a `TableFunction` expr with the return type inferred from `func_type` and types of /// `inputs`. pub fn new(func_type: TableFunctionType, args: Vec) -> RwResult { - let return_type = FUNC_SIG_MAP.get_return_type( + let return_type = FUNCTION_REGISTRY.get_return_type( func_type, &args.iter().map(|c| c.return_type()).collect_vec(), )?; diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 637d34e0f1568..4339f56beac23 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -39,7 +39,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut [ExprImpl]) -> Result Some(e.return_type()), }) .collect_vec(); - let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; // add implicit casts to inputs for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) { @@ -80,7 +80,7 @@ pub fn infer_some_all( (!inputs[0].is_untyped()).then_some(inputs[0].return_type()), element_type.clone(), ]; - let sig = infer_type_name(&FUNC_SIG_MAP, final_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, final_type, &actuals)?; if sig.ret_type != DataType::Boolean.into() { return Err(ErrorCode::BindError(format!( "op SOME/ANY/ALL (array) requires operator to yield boolean, but got {}", @@ -273,7 +273,7 @@ fn infer_struct_cast_target_type( (NestedType::Infer(l), NestedType::Infer(r)) => { // Both sides are *unknown*, using the sig_map to infer the return type. let actuals = vec![None, None]; - let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?; + let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?; Ok(( sig.ret_type != l.into(), sig.ret_type != r.into(), @@ -562,7 +562,7 @@ fn infer_type_for_special( /// 5. Attempt to narrow down candidates by assuming all arguments are same type. This covers Rule /// 4f in `PostgreSQL`. See [`narrow_same_type`] for details. fn infer_type_name<'a>( - sig_map: &'a FuncSigMap, + sig_map: &'a FunctionRegistry, func_type: ExprType, inputs: &[Option], ) -> Result<&'a FuncSign> { @@ -1120,7 +1120,7 @@ mod tests { ), ]; for (desc, candidates, inputs, expected) in testcases { - let mut sig_map = FuncSigMap::default(); + let mut sig_map = FunctionRegistry::default(); for formals in candidates { sig_map.insert(FuncSign { // func_name does not affect the overload resolution logic diff --git a/src/frontend/src/expr/type_inference/mod.rs b/src/frontend/src/expr/type_inference/mod.rs index 1eab51341af18..67d5022bb8911 100644 --- a/src/frontend/src/expr/type_inference/mod.rs +++ b/src/frontend/src/expr/type_inference/mod.rs @@ -21,4 +21,4 @@ pub use cast::{ align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, least_restrictive, CastContext, CastSig, }; -pub use func::{func_sigs, infer_some_all, infer_type, FuncSign}; +pub use func::{all_functions, infer_some_all, infer_type, FuncSign}; diff --git a/src/tests/sqlsmith/src/sql_gen/expr.rs b/src/tests/sqlsmith/src/sql_gen/expr.rs index 68a18e8f8a30f..c91b25910f9f0 100644 --- a/src/tests/sqlsmith/src/sql_gen/expr.rs +++ b/src/tests/sqlsmith/src/sql_gen/expr.rs @@ -16,7 +16,8 @@ use itertools::Itertools; use rand::seq::SliceRandom; use rand::Rng; use risingwave_common::types::{DataType, DataTypeName, StructType}; -use risingwave_frontend::expr::{cast_sigs, func_sigs}; +use risingwave_expr::sig::cast::cast_sigs; +use risingwave_expr::sig::{aggregate_functions, scalar_functions}; use risingwave_sqlparser::ast::{Expr, Ident, OrderByExpr, Value}; use crate::sql_gen::types::data_type_to_ast_data_type; @@ -302,8 +303,7 @@ pub(crate) fn sql_null() -> Expr { // Add variadic function signatures. Can add these functions // to a FUNC_TABLE too. pub fn print_function_table() -> String { - let func_str = func_sigs() - .filter(|sign| sign.is_scalar()) + let func_str = scalar_functions() .map(|sign| { format!( "{}({}) -> {}", @@ -314,8 +314,7 @@ pub fn print_function_table() -> String { }) .join("\n"); - let agg_func_str = func_sigs() - .filter(|sign| sign.is_aggregate()) + let agg_func_str = aggregate_functions() .map(|sign| { format!( "{}({}) -> {}", diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs index ff7072e85b788..151e527b5ea58 100644 --- a/src/tests/sqlsmith/src/sql_gen/types.rs +++ b/src/tests/sqlsmith/src/sql_gen/types.rs @@ -21,7 +21,7 @@ use itertools::Itertools; use risingwave_common::types::{DataType, DataTypeName}; use risingwave_expr::agg::AggKind; use risingwave_expr::sig::cast::{cast_sigs, CastContext, CastSig as RwCastSig}; -use risingwave_expr::sig::{func_sigs, FuncSign}; +use risingwave_expr::sig::{aggregate_functions, scalar_functions, FuncSign}; use risingwave_frontend::expr::ExprType; use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField}; @@ -120,13 +120,11 @@ static FUNC_BAN_LIST: LazyLock> = LazyLock::new(|| { pub(crate) static FUNC_TABLE: LazyLock>> = LazyLock::new(|| { let mut funcs = HashMap::>::new(); - func_sigs() + scalar_functions() .filter(|func| { - func.is_scalar() - && func - .inputs_type - .iter() - .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz) + func.inputs_type + .iter() + .all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz) && !FUNC_BAN_LIST.contains(&func.name.as_scalar()) && !func.deprecated // deprecated functions are not accepted by frontend }) @@ -142,8 +140,7 @@ pub(crate) static FUNC_TABLE: LazyLock> /// Set of invariant functions // ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826 pub(crate) static INVARIANT_FUNC_SET: LazyLock> = LazyLock::new(|| { - func_sigs() - .filter(|sig| sig.is_scalar()) + scalar_functions() .map(|sig| sig.name.as_scalar()) .counts() .into_iter() @@ -157,10 +154,9 @@ pub(crate) static INVARIANT_FUNC_SET: LazyLock> = LazyLock::ne pub(crate) static AGG_FUNC_TABLE: LazyLock>> = LazyLock::new(|| { let mut funcs = HashMap::>::new(); - func_sigs() + aggregate_functions() .filter(|func| { - func.is_aggregate() - && func.inputs_type + func.inputs_type .iter() .all(|t| t != &DataType::Timestamptz.into()) // Ignored functions @@ -245,10 +241,9 @@ pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock< HashMap<(DataType, DataType), Vec>, > = LazyLock::new(|| { let mut funcs = HashMap::<(DataType, DataType), Vec>::new(); - func_sigs() + scalar_functions() .filter(|func| { - func.is_scalar() - && !FUNC_BAN_LIST.contains(&func.name.as_scalar()) + !FUNC_BAN_LIST.contains(&func.name.as_scalar()) && func.ret_type == DataType::Boolean.into() && func.inputs_type.len() == 2 && func