Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(udf): allow aggregate: prefixed sql udf #18203

Merged
merged 3 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions e2e_test/udf/bug_fixes/18202_sql_udf_aggregate_prefix.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# https://github.com/risingwavelabs/risingwave/issues/18202

statement ok
create function as2 ( int[] ) returns bigint language sql as 'select array_sum($1)';

query I
select aggregate:as2(a) from (values (1), (2)) t(a);
----
3
46 changes: 18 additions & 28 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl Binder {
InputRef::new(i, DataType::List(Box::new(expr.return_type()))).into()
})
.collect_vec();
let scalar: ExprImpl = if let Ok(schema) = self.first_valid_schema()
let scalar_func_expr = if let Ok(schema) = self.first_valid_schema()
&& let Some(func) = schema.get_function_by_name_inputs(&func_name, &mut array_args)
{
if !func.kind.is_scalar() {
Expand All @@ -172,13 +172,17 @@ impl Binder {
)
.into());
}
UserDefinedFunction::new(func.clone(), array_args).into()
if func.language == "sql" {
self.bind_sql_udf(func.clone(), array_args)?
} else {
UserDefinedFunction::new(func.clone(), array_args).into()
}
} else {
self.bind_builtin_scalar_function(&func_name, array_args, arg_list.variadic)?
};

// now this is either an aggregate/window function call
Some(AggKind::WrapScalar(scalar.to_expr_proto()))
Some(AggKind::WrapScalar(scalar_func_expr.to_expr_proto()))
} else {
None
};
Expand Down Expand Up @@ -214,7 +218,7 @@ impl Binder {
over.is_some(),
format!("`OVER` is not allowed in {} call", name)
);
return self.bind_sql_udf(func, arg_list.args);
return self.bind_sql_udf(func, args);
}

// now `func` is a non-SQL user-defined scalar/aggregate/table function
Expand All @@ -228,6 +232,7 @@ impl Binder {
} else if let Some(ref udf) = udf
&& udf.kind.is_aggregate()
{
assert_ne!(udf.language, "sql", "SQL UDAF is not supported yet");
Some(AggKind::UserDefined(udf.as_ref().into()))
} else if let Ok(kind) = AggKind::from_str(&func_name) {
Some(kind)
Expand Down Expand Up @@ -451,7 +456,7 @@ impl Binder {
fn bind_sql_udf(
&mut self,
func: Arc<FunctionCatalog>,
args: Vec<FunctionArg>,
args: Vec<ExprImpl>,
) -> Result<ExprImpl> {
if func.body.is_none() {
return Err(
Expand Down Expand Up @@ -483,30 +488,15 @@ impl Binder {

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
let Ok(context) = UdfContext::create_udf_context(&args, &func) else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
.into());
};

let mut udf_context = HashMap::new();
for (c, e) in context {
// Note that we need to bind the args before actually delving into the function body
// This will update the context in the subsequent inner calling function
// e.g.,
// - create function print(INT) returns int language sql as 'select $1';
// - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';
// - select print_add_one(1); # The result should be 2 instead of 1.
// Without the pre-binding here, the ($1 + 1) will not be correctly populated,
// causing the result to always be 1.
let Ok(e) = self.bind_expr(e) else {
return Err(ErrorCode::BindError(
"failed to bind the argument, please recheck the syntax".to_string(),
)
.into());
};
udf_context.insert(c, e);
for (i, arg) in args.into_iter().enumerate() {
if func.arg_names[i].is_empty() {
// unnamed argument, use `$1`, `$2` as the name
udf_context.insert(format!("${}", i + 1), arg);
} else {
// named argument
udf_context.insert(func.arg_names[i].clone(), arg);
}
}
self.udf_context.update_context(udf_context);

Expand Down
40 changes: 1 addition & 39 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ use parking_lot::RwLock;
use risingwave_common::session_config::{SearchPath, SessionConfig};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::{
Expr as AstExpr, FunctionArg, FunctionArgExpr, SelectItem, SetExpr, Statement,
};
use risingwave_sqlparser::ast::{Expr as AstExpr, SelectItem, SetExpr, Statement};

use crate::error::Result;

Expand Down Expand Up @@ -63,7 +61,6 @@ pub use update::BoundUpdate;
pub use values::BoundValues;

use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::function_catalog::FunctionCatalog;
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::error::ErrorCode;
Expand Down Expand Up @@ -223,41 +220,6 @@ impl UdfContext {

Ok(expr)
}

/// Create the sql udf context
/// used per `bind_function` for sql udf & semantic check at definition time.
///
/// Builds a map from parameter key to the (still unbound) AST expression of the
/// argument at the same position in `args`:
/// - if `catalog.arg_names[i]` is empty, the key is the positional form `$1`, `$2`, ...
/// - otherwise the key is the declared name `catalog.arg_names[i]`.
///
/// # Errors
///
/// Returns `ErrorCode::InvalidInputSyntax` when an argument is not
/// `FunctionArg::Unnamed`, or when an unnamed argument does not wrap a
/// `FunctionArgExpr::Expr`.
///
/// NOTE(review): `catalog.arg_names[i]` is indexed without a bounds check, so this
/// presumably relies on the caller having already matched `args` against the
/// function signature — confirm at the call sites.
pub fn create_udf_context(
args: &[FunctionArg],
catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
match current_arg {
FunctionArg::Unnamed(arg) => {
// Only a plain expression can be recorded; any other `FunctionArgExpr`
// variant is rejected.
let FunctionArgExpr::Expr(e) = arg else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `FunctionArgExpr` for unnamed argument".to_string(),
)
.into());
};
if catalog.arg_names[i].is_empty() {
// Parameter declared without a name: key it by its 1-based position.
ret.insert(format!("${}", i + 1), e.clone());
} else {
// The i-th call-site argument corresponds to the i-th declared
// parameter, so the catalog can be indexed directly with `i`.
ret.insert(catalog.arg_names[i].clone(), e.clone());
}
}
// Every other `FunctionArg` variant is rejected here.
_ => {
return Err(ErrorCode::InvalidInputSyntax(
"expect unnamed argument when creating sql udf context".to_string(),
)
.into())
}
}
}
Ok(ret)
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding prepared statements.
Expand Down
Loading