Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql-udf): correctly handle udf_binding_flag & udf_global_count #14906

Merged
merged 14 commits into from
Feb 4, 2024
12 changes: 10 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,13 @@ select add_named(a, b) from t3 order by a asc;
################################

# Mixed parameter with calling inner sql udfs
# statement ok
# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';
statement ok
create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

query I
select add_sub_mix_wrapper(1, 2, 3);
----
4

# Named sql udf with corner case
statement ok
Expand Down Expand Up @@ -404,6 +409,9 @@ drop function add_named_wrapper;
statement ok
drop function type_match;

statement ok
drop function add_sub_mix_wrapper;

statement ok
drop table t1;

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Binder {
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as a special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if self.udf_binding_flag {
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return Ok(expr.clone());
} else {
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,16 @@ impl Binder {
}

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
self.set_udf_binding_flag();
let bind_result = self.bind_expr(expr);
self.unset_udf_binding_flag();

// We should properly decrement global count after a successful binding
// Since the subsequent probe operation in `bind_column` or
// `bind_parameter` relies on global counting
self.udf_context.decr_global_count();

// Restore context information for subsequent binding
self.udf_context.update_context(stashed_udf_context);

return bind_result;
} else {
return Err(ErrorCode::InvalidInputSyntax(
Expand Down
8 changes: 5 additions & 3 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,13 @@ impl Binder {

fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
// Special check for sql udf
// Note: This is specific to anonymous sql udf, since the
// Note: This is specific to sql udf with unnamed parameters, since the
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
if self.udf_context.global_count() != 0 {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return Ok(expr.clone());
}
}

Ok(Parameter::new(index, self.param_types.clone()).into())
Expand Down
33 changes: 16 additions & 17 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,6 @@ pub struct Binder {

/// The sql udf context that will be used during binding phase
udf_context: UdfContext,

/// Udf binding flag, used to distinguish between
/// columns and named parameters during sql udf binding
udf_binding_flag: bool,
}

#[derive(Clone, Debug, Default)]
Expand Down Expand Up @@ -155,6 +151,10 @@ impl UdfContext {
self.udf_global_counter += 1;
}

/// Decrement the global counter (`udf_global_counter`) by one.
///
/// NOTE(review): this is a plain `-= 1`; if the counter is an unsigned integer,
/// calling this while it is already zero underflows (a panic when overflow checks
/// are enabled). Confirm that every call is paired with an earlier
/// `incr_global_count`.
pub fn decr_global_count(&mut self) {
self.udf_global_counter -= 1;
}

/// Returns whether `udf_param_context` is empty.
pub fn _is_empty(&self) -> bool {
self.udf_param_context.is_empty()
}
Expand Down Expand Up @@ -219,6 +219,8 @@ impl UdfContext {
Ok(expr)
}

/// Create the sql udf context
/// used per `bind_function` for sql udf & semantic check at definition time
pub fn create_udf_context(
args: &[FunctionArg],
catalog: &Arc<FunctionCatalog>,
Expand All @@ -228,9 +230,10 @@ impl UdfContext {
match current_arg {
FunctionArg::Unnamed(arg) => {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
return Err(ErrorCode::InvalidInputSyntax(
"expect `FunctionArgExpr` for unnamed argument".to_string(),
)
.into());
};
if catalog.arg_names[i].is_empty() {
ret.insert(format!("${}", i + 1), e.clone());
Expand All @@ -240,7 +243,12 @@ impl UdfContext {
ret.insert(catalog.arg_names[i].clone(), e.clone());
}
}
_ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()),
_ => {
return Err(ErrorCode::InvalidInputSyntax(
"expect unnamed argument when creating sql udf context".to_string(),
)
.into())
}
}
}
Ok(ret)
Expand Down Expand Up @@ -347,7 +355,6 @@ impl Binder {
included_relations: HashSet::new(),
param_types: ParameterTypes::new(param_types),
udf_context: UdfContext::new(),
udf_binding_flag: false,
}
}

Expand Down Expand Up @@ -497,14 +504,6 @@ impl Binder {
/// Returns a mutable reference to the binder's sql udf context (`udf_context`).
pub fn udf_context_mut(&mut self) -> &mut UdfContext {
&mut self.udf_context
}

/// Set `udf_binding_flag` to `true`.
pub fn set_udf_binding_flag(&mut self) {
self.udf_binding_flag = true;
}

/// Reset `udf_binding_flag` to `false`.
pub fn unset_udf_binding_flag(&mut self) {
self.udf_binding_flag = false;
}
}

/// The column name stored in [`BindContext`] for a column without an alias.
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ pub async fn handle_create_sql_function(
arg_names.clone(),
));

binder.set_udf_binding_flag();
// Need to set the initial global count to 1
// otherwise the context will not be probed during the semantic check
binder.udf_context_mut().incr_global_count();

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
match binder.bind_expr(expr) {
Expand All @@ -204,7 +206,7 @@ pub async fn handle_create_sql_function(
}
}
Err(e) => return Err(ErrorCode::InvalidInputSyntax(format!(
"failed to conduct semantic check, please see if you are calling non-existent functions: {}",
"failed to conduct semantic check, please see if you are calling non-existence functions or parameters\ndetailed error message: {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a temporary workaround; I would prefer a better hint display in the future. Related: #14853.

e.as_report()
))
.into()),
Expand All @@ -217,8 +219,6 @@ pub async fn handle_create_sql_function(
)
.into());
}

binder.unset_udf_binding_flag();
}

// Create the actual function, will be stored in function catalog
Expand Down
Loading