diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index e1100834c9bbd..a75c283934b41 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -18,6 +18,9 @@ create function add_sub_binding() returns int language sql as 'select add(1, 1) statement ok create function call_regexp_replace() returns varchar language sql as $$select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')$$; +statement ok +create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$; + statement error Expected end of statement, found: 💩 create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')'; @@ -36,6 +39,16 @@ create function recursive(INT, INT) returns int language sql as 'select recursiv statement ok create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; +statement ok +create function print(INT) returns int language sql as 'select $1'; + +# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter +statement ok +create function print_add_one(INT) returns int language sql as 'select print($1 + 1)'; + +statement ok +create function print_add_two(INT) returns int language sql as 'select print($1 + $1)'; + # Call the defined sql udf query I select add(1, -1); @@ -72,11 +85,21 @@ select call_regexp_replace(); ---- 💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 +query T +select regexp_replace_wrapper('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥'); +---- +💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 + query I select add_sub_wrapper(1, 1); ---- 114514 +query III +select print_add_one(1), print_add_one(114513), print_add_two(2); +---- +2 114514 4 + # Create a mock table statement ok create table t1 (c1 INT, c2 INT); @@ -187,6 +210,18 @@ drop function call_regexp_replace; statement ok drop function add_sub_wrapper; +statement ok 
+drop function print; + +statement ok +drop function print_add_one; + +statement ok +drop function print_add_two; + +statement ok +drop function regexp_replace_wrapper; + statement ok drop function add_sub_types; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 2f2a8d9335256..5a55952ad270f 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -46,7 +46,7 @@ impl Binder { // If so, we will treat this bind as an special bind, the actual expression // stored in `udf_context` will then be bound instead of binding the non-existing column. if let Some(expr) = self.udf_context.get(&column_name) { - return self.bind_expr(expr.clone()); + return Ok(expr.clone()); } match self diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index a0545b81b17d6..06505aa4e5854 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -245,36 +245,53 @@ impl Binder { // Here we just return the original parse error message return Err(ErrorCode::InvalidInputSyntax(err).into()); } + debug_assert!(parse_result.is_ok()); // We can safely unwrap here let ast = parse_result.unwrap(); - let mut clean_flag = true; + let mut clean_flag = false; - // We need to check if the `udf_context` is empty first, consider the following example: - // - create function add(INT, INT) returns int language sql as 'select $1 + $2'; - // - create function add_wrapper(INT, INT) returns int language sql as 'select add($1, $2)'; - // - select add_wrapper(1, 1); - // When binding `add($1, $2)` in `add_wrapper`, the input args are [$1, $2] instead of - // the original [1, 1], thus we need to check `udf_context` to see if the input - // args already exist in the context. If so, we do NOT need to create the context again. - // Otherwise the current `udf_context` will be corrupted. 
+ // We need to check if the `udf_context` is empty first, + // If so, we will clear the `udf_context` after binding. + // Since this is the root (top-most) binding for sql udf. + // Otherwise we need to restore the context later, or the + // original `udf_context` will be corrupted. if self.udf_context.is_empty() { - // The actual inline logic for sql udf - if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { - self.udf_context = context; - } else { - return Err(ErrorCode::InvalidInputSyntax( - "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() - ) - .into()); + clean_flag = true; + } + + // Stash the current `udf_context` + let prev_context = self.udf_context.clone(); + + // The actual inline logic for sql udf + // Note that we will always create new udf context for each sql udf + if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { + let mut udf_context = HashMap::new(); + for (c, e) in context { + // Note that we need to bind the args before actually delving into the function body + // This will update the context in the subsequent inner calling function + // e.g., + // - create function print(INT) returns int language sql as 'select $1'; + // - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)'; + // - select print_add_one(1); # The result should be 2 instead of 1. + // Without the pre-binding here, the ($1 + 1) will not be correctly populated, + // causing the result to always be 1. + let Ok(e) = self.bind_expr(e) else { + return Err(ErrorCode::BindError(format!( + "failed to bind the argument, please recheck your syntax" + )) + .into()); + }; + udf_context.insert(c, e); } + self.udf_context = udf_context; } else { - // If the `udf_context` is not empty, this means the current binding - // function is not the root binding sql udf, thus we should NOT - // clean the context after binding. 
- clean_flag = false; + return Err(ErrorCode::InvalidInputSyntax( + "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() + ) + .into()); } if let Ok(expr) = extract_udf_expression(ast) { @@ -283,6 +300,9 @@ impl Binder { // which makes sure the subsequent binding will not be affected if clean_flag { self.udf_context.clear(); + } else { + // Restore context information for subsequent binding + self.udf_context = prev_context; } return bind_result; } else { diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index cacd2d80dcfe4..89e038e628097 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -383,7 +383,7 @@ impl Binder { // parameters will be parsed and treated as `Parameter`. // For detailed explanation, consider checking `bind_column`. if let Some(expr) = self.udf_context.get(&format!("${index}")) { - return self.bind_expr(expr.clone()); + return Ok(expr.clone()); } Ok(Parameter::new(index, self.param_types.clone()).into()) diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 6ba891aa6b513..75e3a6001b87f 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -21,7 +21,7 @@ use risingwave_common::error::Result; use risingwave_common::session_config::{ConfigMap, SearchPath}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_sqlparser::ast::{Expr as AstExpr, Statement}; +use risingwave_sqlparser::ast::Statement; mod bind_context; mod bind_param; @@ -59,6 +59,7 @@ pub use values::BoundValues; use crate::catalog::catalog_service::CatalogReadGuard; use crate::catalog::schema_catalog::SchemaCatalog; use crate::catalog::{CatalogResult, TableId, ViewId}; +use crate::expr::ExprImpl; use crate::session::{AuthContext, SessionImpl}; pub type ShareId = usize; @@ -116,9 +117,9 @@ pub struct Binder { param_types: ParameterTypes, - /// The 
mapping from sql udf parameters to ast expressions + /// The mapping from `sql udf parameters` to `ExprImpl` generated from ast expressions /// Note: The expressions are constructed during runtime, correspond to the actual users' input - udf_context: HashMap<String, AstExpr>, + udf_context: HashMap<String, ExprImpl>, } /// `ParameterTypes` is used to record the types of the parameters during binding. It works