Skip to content

Commit

Permalink
rename function registry
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 committed Sep 11, 2023
1 parent 5f7b17e commit 31fccb7
Show file tree
Hide file tree
Showing 12 changed files with 71 additions and 68 deletions.
10 changes: 3 additions & 7 deletions src/expr/benches/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::*;
use risingwave_expr::agg::{build as build_agg, AggArgs, AggCall, AggKind};
use risingwave_expr::expr::*;
use risingwave_expr::sig::func_sigs;
use risingwave_expr::sig::{aggregate_functions, scalar_functions};
use risingwave_expr::ExprError;
use risingwave_pb::expr::expr_node::PbType;

Expand Down Expand Up @@ -262,9 +262,7 @@ fn bench_expr(c: &mut Criterion) {
.iter(|| extract.eval(&input))
});

let sigs = func_sigs()
.filter(|s| s.is_scalar())
.sorted_by_cached_key(|sig| format!("{sig:?}"));
let sigs = scalar_functions().sorted_by_cached_key(|sig| format!("{sig:?}"));
'sig: for sig in sigs {
if (sig.inputs_type.iter())
.chain([&sig.ret_type])
Expand Down Expand Up @@ -340,9 +338,7 @@ fn bench_expr(c: &mut Criterion) {
});
}

let sigs = func_sigs()
.filter(|s| s.is_aggregate())
.sorted_by_cached_key(|sig| format!("{sig:?}"));
let sigs = aggregate_functions().sorted_by_cached_key(|sig| format!("{sig:?}"));
for sig in sigs {
if matches!(
sig.name.as_aggregate(),
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub type BoxedAggregateFunction = Box<dyn AggregateFunction>;
/// NOTE: This function ignores argument indices, `column_orders`, `filter` and `distinct` in
/// `AggCall`. Such operations should be done in batch or streaming executors.
pub fn build(agg: &AggCall) -> Result<BoxedAggregateFunction> {
let desc = crate::sig::FUNC_SIG_MAP
let desc = crate::sig::FUNCTION_REGISTRY
.get(agg.kind, agg.args.arg_types(), &agg.return_type)
.ok_or_else(|| {
ExprError::UnsupportedFunction(format!(
Expand Down
20 changes: 11 additions & 9 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use super::expr_vnode::VnodeExpression;
use crate::expr::{
BoxedExpression, Expression, InputRefExpression, LiteralExpression, TryFromExprNodeBoxed,
};
use crate::sig::FUNC_SIG_MAP;
use crate::sig::FUNCTION_REGISTRY;
use crate::{bail, ExprError, Result};

/// Build an expression from protobuf.
Expand Down Expand Up @@ -82,14 +82,16 @@ pub fn build_func(
}

let args = children.iter().map(|c| c.return_type()).collect_vec();
let desc = FUNC_SIG_MAP.get(func, &args, &ret_type).ok_or_else(|| {
ExprError::UnsupportedFunction(format!(
"{}({}) -> {}",
func.as_str_name().to_ascii_lowercase(),
args.iter().format(", "),
ret_type,
))
})?;
let desc = FUNCTION_REGISTRY
.get(func, &args, &ret_type)
.ok_or_else(|| {
ExprError::UnsupportedFunction(format!(
"{}({}) -> {}",
func.as_str_name().to_ascii_lowercase(),
args.iter().format(", "),
ret_type,
))
})?;
desc.build_scalar(ret_type, children)
}

Expand Down
45 changes: 29 additions & 16 deletions src/expr/src/sig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,40 @@ use crate::ExprError;

pub mod cast;

pub static FUNC_SIG_MAP: LazyLock<FuncSigMap> = LazyLock::new(|| unsafe {
let mut map = FuncSigMap::default();
tracing::info!("{} function signatures loaded.", FUNC_SIG_MAP_INIT.len());
for desc in FUNC_SIG_MAP_INIT.drain(..) {
map.insert(desc);
/// The global registry of all function signatures.
pub static FUNCTION_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(|| unsafe {
// SAFETY: this function is called after all `#[ctor]` functions are called.
let mut map = FunctionRegistry::default();
tracing::info!("found {} functions", FUNCTION_REGISTRY_INIT.len());
for sig in FUNCTION_REGISTRY_INIT.drain(..) {
map.insert(sig);
}
map
});

/// The table of function signatures.
pub fn func_sigs() -> impl Iterator<Item = &'static FuncSign> {
FUNC_SIG_MAP.0.values().flatten()
/// Returns an iterator of all function signatures.
pub fn all_functions() -> impl Iterator<Item = &'static FuncSign> {
FUNCTION_REGISTRY.0.values().flatten()
}

/// Returns an iterator of all scalar functions.
pub fn scalar_functions() -> impl Iterator<Item = &'static FuncSign> {
all_functions().filter(|d| d.is_scalar())
}

/// Returns an iterator of all aggregate functions.
pub fn aggregate_functions() -> impl Iterator<Item = &'static FuncSign> {
all_functions().filter(|d| d.is_aggregate())
}

/// A set of function signatures.
#[derive(Default, Clone, Debug)]
pub struct FuncSigMap(HashMap<FuncName, Vec<FuncSign>>);
pub struct FunctionRegistry(HashMap<FuncName, Vec<FuncSign>>);

impl FuncSigMap {
impl FunctionRegistry {
/// Inserts a function signature.
pub fn insert(&mut self, desc: FuncSign) {
self.0.entry(desc.name).or_default().push(desc)
pub fn insert(&mut self, sig: FuncSign) {
self.0.entry(sig.name).or_default().push(sig)
}

/// Returns a function signature with the same type, argument types and return type.
Expand Down Expand Up @@ -359,16 +372,16 @@ pub enum FuncBuilder {
/// It is designed to be used by `#[function]` macro.
/// Users SHOULD NOT call this function.
#[doc(hidden)]
pub unsafe fn _register(desc: FuncSign) {
FUNC_SIG_MAP_INIT.push(desc)
pub unsafe fn _register(sig: FuncSign) {
FUNCTION_REGISTRY_INIT.push(sig)
}

/// The global registry of function signatures on initialization.
///
/// `#[function]` macro will generate a `#[ctor]` function to register the signature into this
/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into
/// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`.
static mut FUNC_SIG_MAP_INIT: Vec<FuncSign> = Vec::new();
static mut FUNCTION_REGISTRY_INIT: Vec<FuncSign> = Vec::new();

#[cfg(test)]
mod tests {
Expand All @@ -383,7 +396,7 @@ mod tests {
// convert FUNC_SIG_MAP to a more convenient map for testing
let mut new_map: HashMap<FuncName, HashMap<Vec<SigDataType>, Vec<FuncSign>>> =
HashMap::new();
for (func, sigs) in &FUNC_SIG_MAP.0 {
for (func, sigs) in &FUNCTION_REGISTRY.0 {
for sig in sigs {
// validate the FUNC_SIG_MAP is consistent
assert_eq!(func, &sig.name);
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub fn build(
children: Vec<BoxedExpression>,
) -> Result<BoxedTableFunction> {
let args = children.iter().map(|t| t.return_type()).collect_vec();
let desc = crate::sig::FUNC_SIG_MAP
let desc = crate::sig::FUNCTION_REGISTRY
.get(func, &args, &return_type)
.ok_or_else(|| {
ExprError::UnsupportedFunction(format!(
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use itertools::Itertools;
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::DataType;
use risingwave_expr::agg::AggKind;
use risingwave_expr::sig::FUNC_SIG_MAP;
use risingwave_expr::sig::FUNCTION_REGISTRY;

use super::{Expr, ExprImpl, Literal, OrderBy};
use crate::utils::Condition;
Expand Down Expand Up @@ -88,7 +88,7 @@ impl AggCall {
// Ordered-Set Aggregation
(AggKind::Grouping, _) => Int32,
// other functions are handled by signature map
_ => FUNC_SIG_MAP.get_return_type(agg_kind, &args)?,
_ => FUNCTION_REGISTRY.get_return_type(agg_kind, &args)?,
})
}

Expand Down
6 changes: 2 additions & 4 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::iter::once;

use enum_as_inner::EnumAsInner;
use fixedbitset::FixedBitSet;
use futures::FutureExt;
Expand Down Expand Up @@ -67,7 +65,7 @@ pub use session_timezone::SessionTimezone;
pub use subquery::{Subquery, SubqueryKind};
pub use table_function::{TableFunction, TableFunctionType};
pub use type_inference::{
align_types, cast_map_array, cast_ok, cast_sigs, func_sigs, infer_some_all, infer_type,
align_types, all_functions, cast_map_array, cast_ok, cast_sigs, infer_some_all, infer_type,
least_restrictive, CastContext, CastSig, FuncSign,
};
pub use user_defined_function::UserDefinedFunction;
Expand Down Expand Up @@ -201,7 +199,7 @@ impl ExprImpl {
/// # Panics
/// Panics if `input_ref >= input_col_num`.
pub fn collect_input_refs(&self, input_col_num: usize) -> FixedBitSet {
collect_input_refs(input_col_num, once(self))
collect_input_refs(input_col_num, [self])
}

/// Check if the expression has no side effects and output is deterministic
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::types::DataType;
use risingwave_expr::sig::FUNC_SIG_MAP;
use risingwave_expr::sig::FUNCTION_REGISTRY;
pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
use risingwave_pb::expr::{
TableFunction as TableFunctionPb, UserDefinedTableFunction as UserDefinedTableFunctionPb,
Expand All @@ -43,7 +43,7 @@ impl TableFunction {
/// Create a `TableFunction` expr with the return type inferred from `func_type` and types of
/// `inputs`.
pub fn new(func_type: TableFunctionType, args: Vec<ExprImpl>) -> RwResult<Self> {
let return_type = FUNC_SIG_MAP.get_return_type(
let return_type = FUNCTION_REGISTRY.get_return_type(
func_type,
&args.iter().map(|c| c.return_type()).collect_vec(),
)?;
Expand Down
10 changes: 5 additions & 5 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn infer_type(func_type: ExprType, inputs: &mut [ExprImpl]) -> Result<DataTy
false => Some(e.return_type()),
})
.collect_vec();
let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?;
let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?;

// add implicit casts to inputs
for (expr, t) in inputs.iter_mut().zip_eq_fast(&sig.inputs_type) {
Expand Down Expand Up @@ -80,7 +80,7 @@ pub fn infer_some_all(
(!inputs[0].is_untyped()).then_some(inputs[0].return_type()),
element_type.clone(),
];
let sig = infer_type_name(&FUNC_SIG_MAP, final_type, &actuals)?;
let sig = infer_type_name(&FUNCTION_REGISTRY, final_type, &actuals)?;
if sig.ret_type != DataType::Boolean.into() {
return Err(ErrorCode::BindError(format!(
"op SOME/ANY/ALL (array) requires operator to yield boolean, but got {}",
Expand Down Expand Up @@ -273,7 +273,7 @@ fn infer_struct_cast_target_type(
(NestedType::Infer(l), NestedType::Infer(r)) => {
// Both sides are *unknown*, using the sig_map to infer the return type.
let actuals = vec![None, None];
let sig = infer_type_name(&FUNC_SIG_MAP, func_type, &actuals)?;
let sig = infer_type_name(&FUNCTION_REGISTRY, func_type, &actuals)?;
Ok((
sig.ret_type != l.into(),
sig.ret_type != r.into(),
Expand Down Expand Up @@ -562,7 +562,7 @@ fn infer_type_for_special(
/// 5. Attempt to narrow down candidates by assuming all arguments are same type. This covers Rule
/// 4f in `PostgreSQL`. See [`narrow_same_type`] for details.
fn infer_type_name<'a>(
sig_map: &'a FuncSigMap,
sig_map: &'a FunctionRegistry,
func_type: ExprType,
inputs: &[Option<DataType>],
) -> Result<&'a FuncSign> {
Expand Down Expand Up @@ -1120,7 +1120,7 @@ mod tests {
),
];
for (desc, candidates, inputs, expected) in testcases {
let mut sig_map = FuncSigMap::default();
let mut sig_map = FunctionRegistry::default();
for formals in candidates {
sig_map.insert(FuncSign {
// func_name does not affect the overload resolution logic
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/expr/type_inference/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ pub use cast::{
align_types, cast_map_array, cast_ok, cast_ok_base, cast_sigs, least_restrictive, CastContext,
CastSig,
};
pub use func::{func_sigs, infer_some_all, infer_type, FuncSign};
pub use func::{all_functions, infer_some_all, infer_type, FuncSign};
9 changes: 4 additions & 5 deletions src/tests/sqlsmith/src/sql_gen/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::Rng;
use risingwave_common::types::{DataType, DataTypeName, StructType};
use risingwave_frontend::expr::{cast_sigs, func_sigs};
use risingwave_expr::sig::cast::cast_sigs;
use risingwave_expr::sig::{aggregate_functions, scalar_functions};
use risingwave_sqlparser::ast::{Expr, Ident, OrderByExpr, Value};

use crate::sql_gen::types::data_type_to_ast_data_type;
Expand Down Expand Up @@ -302,8 +303,7 @@ pub(crate) fn sql_null() -> Expr {
// Add variadic function signatures. Can add these functions
// to a FUNC_TABLE too.
pub fn print_function_table() -> String {
let func_str = func_sigs()
.filter(|sign| sign.is_scalar())
let func_str = scalar_functions()
.map(|sign| {
format!(
"{}({}) -> {}",
Expand All @@ -314,8 +314,7 @@ pub fn print_function_table() -> String {
})
.join("\n");

let agg_func_str = func_sigs()
.filter(|sign| sign.is_aggregate())
let agg_func_str = aggregate_functions()
.map(|sign| {
format!(
"{}({}) -> {}",
Expand Down
25 changes: 10 additions & 15 deletions src/tests/sqlsmith/src/sql_gen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use itertools::Itertools;
use risingwave_common::types::{DataType, DataTypeName};
use risingwave_expr::agg::AggKind;
use risingwave_expr::sig::cast::{cast_sigs, CastContext, CastSig as RwCastSig};
use risingwave_expr::sig::{func_sigs, FuncSign};
use risingwave_expr::sig::{aggregate_functions, scalar_functions, FuncSign};
use risingwave_frontend::expr::ExprType;
use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField};

Expand Down Expand Up @@ -120,13 +120,11 @@ static FUNC_BAN_LIST: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
LazyLock::new(|| {
let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
func_sigs()
scalar_functions()
.filter(|func| {
func.is_scalar()
&& func
.inputs_type
.iter()
.all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
func.inputs_type
.iter()
.all(|t| t.is_exact() && t.as_exact() != &DataType::Timestamptz)
&& !FUNC_BAN_LIST.contains(&func.name.as_scalar())
&& !func.deprecated // deprecated functions are not accepted by frontend
})
Expand All @@ -142,8 +140,7 @@ pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>
/// Set of invariant functions
// ENABLE: https://github.com/risingwavelabs/risingwave/issues/5826
pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::new(|| {
func_sigs()
.filter(|sig| sig.is_scalar())
scalar_functions()
.map(|sig| sig.name.as_scalar())
.counts()
.into_iter()
Expand All @@ -157,10 +154,9 @@ pub(crate) static INVARIANT_FUNC_SET: LazyLock<HashSet<ExprType>> = LazyLock::ne
pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<&'static FuncSign>>> =
LazyLock::new(|| {
let mut funcs = HashMap::<DataType, Vec<&'static FuncSign>>::new();
func_sigs()
aggregate_functions()
.filter(|func| {
func.is_aggregate()
&& func.inputs_type
func.inputs_type
.iter()
.all(|t| t != &DataType::Timestamptz.into())
// Ignored functions
Expand Down Expand Up @@ -245,10 +241,9 @@ pub(crate) static BINARY_INEQUALITY_OP_TABLE: LazyLock<
HashMap<(DataType, DataType), Vec<BinaryOperator>>,
> = LazyLock::new(|| {
let mut funcs = HashMap::<(DataType, DataType), Vec<BinaryOperator>>::new();
func_sigs()
scalar_functions()
.filter(|func| {
func.is_scalar()
&& !FUNC_BAN_LIST.contains(&func.name.as_scalar())
!FUNC_BAN_LIST.contains(&func.name.as_scalar())
&& func.ret_type == DataType::Boolean.into()
&& func.inputs_type.len() == 2
&& func
Expand Down

0 comments on commit 31fccb7

Please sign in to comment.