diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index f2e3b2bfd5288..bfae0f7045888 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -209,8 +209,13 @@ select add_named(a, b) from t3 order by a asc; ################################ # Mixed parameter with calling inner sql udfs -# statement ok -# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)'; +statement ok +create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)'; + +query I +select add_sub_mix_wrapper(1, 2, 3); +---- +4 # Named sql udf with corner case statement ok @@ -404,6 +409,9 @@ drop function add_named_wrapper; statement ok drop function type_match; +statement ok +drop function add_sub_mix_wrapper; + statement ok drop table t1; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 9fb17c6e43520..c2fd536c1ccc4 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -45,7 +45,7 @@ impl Binder { // to the name of the defined sql udf parameters stored in `udf_context`. // If so, we will treat this bind as an special bind, the actual expression // stored in `udf_context` will then be bound instead of binding the non-existing column. - if self.udf_binding_flag { + if self.udf_context.global_count() != 0 { if let Some(expr) = self.udf_context.get_expr(&column_name) { return Ok(expr.clone()); } else { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 3d2aba906a7db..bb6e4ee14c335 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -240,12 +240,16 @@ impl Binder { } if let Ok(expr) = UdfContext::extract_udf_expression(ast) { - self.set_udf_binding_flag(); let bind_result = self.bind_expr(expr); - self.unset_udf_binding_flag(); + + // We should properly decrement global count after a successful binding + // Since the subsequent probe operation in `bind_column` or + // `bind_parameter` relies on global counting + self.udf_context.decr_global_count(); // Restore context information for subsequent binding self.udf_context.update_context(stashed_udf_context); + return bind_result; } else { return Err(ErrorCode::InvalidInputSyntax( diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 7da1518ceb552..a50eec922143b 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -386,11 +386,13 @@ impl Binder { fn bind_parameter(&mut self, index: u64) -> Result { // Special check for sql udf - // Note: This is specific to anonymous sql udf, since the + // Note: This is specific to sql udf with unnamed parameters, since the // parameters will be parsed and treated as `Parameter`. // For detailed explanation, consider checking `bind_column`. - if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { - return Ok(expr.clone()); + if self.udf_context.global_count() != 0 { + if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { + return Ok(expr.clone()); + } } Ok(Parameter::new(index, self.param_types.clone()).into()) diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index e726d7dc01d4c..f1c7d97c57fa2 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -122,10 +122,6 @@ pub struct Binder { /// The sql udf context that will be used during binding phase udf_context: UdfContext, - - /// Udf binding flag, used to distinguish between - /// columns and named parameters during sql udf binding - udf_binding_flag: bool, } #[derive(Clone, Debug, Default)] @@ -155,6 +151,10 @@ impl UdfContext { self.udf_global_counter += 1; } + pub fn decr_global_count(&mut self) { + self.udf_global_counter -= 1; + } + pub fn _is_empty(&self) -> bool { self.udf_param_context.is_empty() } @@ -219,6 +219,8 @@ impl UdfContext { Ok(expr) } + /// Create the sql udf context + /// used per `bind_function` for sql udf & semantic check at definition time pub fn create_udf_context( args: &[FunctionArg], catalog: &Arc, @@ -228,9 +230,10 @@ impl UdfContext { match current_arg { FunctionArg::Unnamed(arg) => { let FunctionArgExpr::Expr(e) = arg else { - return Err( - ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into() - ); + return Err(ErrorCode::InvalidInputSyntax( + "expect `FunctionArgExpr` for unnamed argument".to_string(), + ) + .into()); }; if catalog.arg_names[i].is_empty() { ret.insert(format!("${}", i + 1), e.clone()); @@ -240,7 +243,12 @@ impl UdfContext { ret.insert(catalog.arg_names[i].clone(), e.clone()); } } - _ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()), + _ => { + return Err(ErrorCode::InvalidInputSyntax( + "expect unnamed argument when creating sql udf context".to_string(), + ) + .into()) + } } } Ok(ret) @@ -347,7 +355,6 @@ impl Binder { included_relations: HashSet::new(), param_types: ParameterTypes::new(param_types), udf_context: UdfContext::new(), - udf_binding_flag: false, } } @@ -497,14 +504,6 @@ impl Binder { pub fn udf_context_mut(&mut self) -> &mut UdfContext { &mut self.udf_context } - - pub fn set_udf_binding_flag(&mut self) { - self.udf_binding_flag = true; - } - - pub fn unset_udf_binding_flag(&mut self) { - self.udf_binding_flag = false; - } } /// The column name stored in [`BindContext`] for a column without an alias. diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index de24027723bbf..311664735603f 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -188,7 +188,9 @@ pub async fn handle_create_sql_function( arg_names.clone(), )); - binder.set_udf_binding_flag(); + // Need to set the initial global count to 1 + // otherwise the context will not be probed during the semantic check + binder.udf_context_mut().incr_global_count(); if let Ok(expr) = UdfContext::extract_udf_expression(ast) { match binder.bind_expr(expr) { @@ -204,7 +206,7 @@ pub async fn handle_create_sql_function( } } Err(e) => return Err(ErrorCode::InvalidInputSyntax(format!( - "failed to conduct semantic check, please see if you are calling non-existent functions: {}", + "failed to conduct semantic check, please see if you are calling non-existence functions or parameters\ndetailed error message: {}", e.as_report() )) .into()), @@ -217,8 +219,6 @@ pub async fn handle_create_sql_function( ) .into()); } - - binder.unset_udf_binding_flag(); } // Create the actual function, will be stored in function catalog