Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql-udf): correctly handle udf_binding_flag & udf_global_count #14906

Merged
merged 14 commits into from
Feb 4, 2024
12 changes: 10 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ statement ok
create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3';

# Mixed parameter with calling inner sql udfs
# statement ok
# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';
statement ok
create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

# Named sql udf with corner case
statement ok
Expand Down Expand Up @@ -150,6 +150,11 @@ select add_sub_mix(1, 2, 3);
----
2

query I
select add_sub_mix_wrapper(1, 2, 3);
----
4

query T
select corner_case(1, 2, 3);
----
Expand Down Expand Up @@ -392,6 +397,9 @@ drop function add_named_wrapper;
statement ok
drop function type_match;

statement ok
drop function add_sub_mix_wrapper;

# Drop the mock table
statement ok
drop table t1;
Expand Down
16 changes: 15 additions & 1 deletion src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,24 @@ impl Binder {
if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
self.set_udf_binding_flag();
let bind_result = self.bind_expr(expr);
self.unset_udf_binding_flag();

// Decrement the global count once the binding of this sql udf has finished
// (regardless of whether it succeeded), since the subsequent
// `unset flag` check relies on the global count

// Only the top-most sql udf binding should unset the flag.
// Otherwise the subsequent binding may not be able to
// find the corresponding context. Consider the following example:
// e.g., `create function add_wrapper(a INT, b INT) returns int language sql as 'select add(a, b) + a';`
// The inner `add` should not unset the flag; otherwise the trailing `a` will be treated as
// a normal column, which is invalid in this context.
if self.udf_context.global_count() == 0 {
self.unset_udf_binding_flag();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like udf_binding_flag can be inferred by udf_global_counter != 0. Can we remove udf_binding_flag now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, let me do a tiny refactor.


// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);

return bind_result;
} else {
return Err(ErrorCode::InvalidInputSyntax(
Expand Down
20 changes: 16 additions & 4 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ impl UdfContext {
self.udf_global_counter += 1;
}

/// Decrement the global sql udf counter by one.
///
/// This is the counterpart of `incr_global_count`.
/// NOTE(review): there is no underflow guard here — presumably every call is
/// paired with an earlier `incr_global_count`; confirm against the callers.
pub fn decr_global_count(&mut self) {
self.udf_global_counter -= 1;
}

/// Return `true` when `udf_param_context` holds no entries.
///
/// NOTE(review): the leading underscore suggests this helper is currently
/// unused — confirm before relying on it.
pub fn _is_empty(&self) -> bool {
self.udf_param_context.is_empty()
}
Expand Down Expand Up @@ -219,6 +223,8 @@ impl UdfContext {
Ok(expr)
}

/// Create the sql udf context
/// used per `bind_function` for sql udf & semantic check at definition time
pub fn create_udf_context(
args: &[FunctionArg],
catalog: &Arc<FunctionCatalog>,
Expand All @@ -228,9 +234,10 @@ impl UdfContext {
match current_arg {
FunctionArg::Unnamed(arg) => {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
return Err(ErrorCode::InvalidInputSyntax(
"expect `FunctionArgExpr` for unnamed argument".to_string(),
)
.into());
};
if catalog.arg_names[i].is_empty() {
ret.insert(format!("${}", i + 1), e.clone());
Expand All @@ -240,7 +247,12 @@ impl UdfContext {
ret.insert(catalog.arg_names[i].clone(), e.clone());
}
}
_ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()),
_ => {
return Err(ErrorCode::InvalidInputSyntax(
"expect unnamed argument when creating sql udf context".to_string(),
)
.into())
}
}
}
Ok(ret)
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ pub async fn handle_create_sql_function(

binder.set_udf_binding_flag();

// Need to set the initial global count to 1
// otherwise the flag will be unset during the semantic check
binder.udf_context_mut().incr_global_count();

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
match binder.bind_expr(expr) {
Ok(expr) => {
Expand All @@ -204,7 +208,7 @@ pub async fn handle_create_sql_function(
}
}
Err(e) => return Err(ErrorCode::InvalidInputSyntax(format!(
"failed to conduct semantic check, please see if you are calling non-existence functions: {}",
"failed to conduct semantic check, please see if you are calling non-existence functions or parameters\ndetailed error message: {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a workaround for now; I would prefer a better hint display in the future. Related: #14853.

e.as_report()
))
.into()),
Expand Down
Loading