diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index e1100834c9bb..02fc23b2d7f0 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -28,14 +28,42 @@ create function add_return(INT, INT) returns int language sql return $1 + $2; statement ok create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); -# Recursive definition is forbidden -statement error recursive definition is forbidden, please recheck your function syntax +# A recursive definition is accepted at creation time, but will eventually be rejected when the function is called +statement ok create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)'; +# Complex but error-prone definition that interleaves recursive & normal sql udfs +statement ok +create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)'; + +# Recursive corner case +statement ok +create function foo(INT) returns varchar language sql as $$select 'foo(INT)'$$; + # Create a wrapper function for `add` & `sub` statement ok create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; +# Create a valid recursive function +# Please note that actually running a recursive sql udf is NOT supported at present +statement ok +create function fib(INT) returns int + language sql as 'select case + when $1 = 0 then 0 + when $1 = 1 then 1 + when $1 = 2 then 1 + when $1 = 3 then 2 + else fib($1 - 1) + fib($1 - 2) + end;'; + +# The execution will eventually exceed the pre-defined max stack depth +statement error function fib calling stack depth limit exceeded +select fib(100); + +# Currently, creating a materialized view with a recursive sql udf will be rejected +statement error function fib calling stack depth limit exceeded +create materialized view foo_mv as select fib(100); + # Call the defined sql udf query I select add(1, -1); @@ -72,6 +100,19 @@ 
select call_regexp_replace(); ---- 💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 +query T +select foo(114514); +---- +foo(INT) + +# Rejected deep calling stack +statement error function recursive calling stack depth limit exceeded +select recursive(1, 1); + +# Same as above +statement error function recursive calling stack depth limit exceeded +select recursive_non_recursive(1, 1); + query I select add_sub_wrapper(1, 1); ---- @@ -103,6 +144,14 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc; 4 4 8 5 5 10 +# Recursive sql udf with normal table +statement error function fib calling stack depth limit exceeded +select fib(c1) from t1; + +# Recursive sql udf with materialized view +statement error function fib calling stack depth limit exceeded +create materialized view bar_mv as select fib(c1) from t1; + # Invalid function body syntax statement error Expected an expression:, found: EOF at the end create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$; @@ -187,9 +236,21 @@ drop function call_regexp_replace; statement ok drop function add_sub_wrapper; +statement ok +drop function recursive; + +statement ok +drop function foo; + +statement ok +drop function recursive_non_recursive; + statement ok drop function add_sub_types; +statement ok +drop function fib; + # Drop the mock table statement ok drop table t1; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 2f2a8d933525..cac4f7eccd62 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -45,7 +45,7 @@ impl Binder { // to the name of the defined sql udf parameters stored in `udf_context`. // If so, we will treat this bind as an special bind, the actual expression // stored in `udf_context` will then be bound instead of binding the non-existing column. 
- if let Some(expr) = self.udf_context.get(&column_name) { + if let Some(expr) = self.udf_context.get_expr(&column_name) { return self.bind_expr(expr.clone()); } diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index b92b7e832f81..de4f0f4aa633 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -55,6 +55,12 @@ pub const SYS_FUNCTION_WITHOUT_ARGS: &[&str] = &[ "current_timestamp", ]; +/// The maximum calling depth allowed by the global counter in `udf_context`. +/// To reduce the chance that the currently running rw thread +/// is killed by the OS, the allowed depth of the calling +/// stack is currently set to `16`. +const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16; + impl Binder { pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { let function_name = match f.name.0.as_slice() { @@ -235,6 +241,7 @@ impl Binder { ) .into()); } + // This represents the current user defined function is `language sql` let parse_result = risingwave_sqlparser::parser::Parser::parse_sql( func.body.as_ref().unwrap().as_str(), @@ -245,6 +252,7 @@ impl Binder { // Here we just return the original parse error message return Err(ErrorCode::InvalidInputSyntax(err).into()); } + debug_assert!(parse_result.is_ok()); // We can safely unwrap here @@ -263,7 +271,7 @@ impl Binder { if self.udf_context.is_empty() { // The actual inline logic for sql udf if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { - self.udf_context = context; + self.udf_context.update_context(context); } else { return Err(ErrorCode::InvalidInputSyntax( "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() @@ -277,9 +285,21 @@ impl Binder { clean_flag = false; } + // Check for potential recursive calls + if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH { + return Err(ErrorCode::BindError(format!( + "function {} calling stack depth limit exceeded", 
+ &function_name + )) + .into()); + } else { + // Update the status for the global counter + self.udf_context.incr_global_count(); + } + if let Ok(expr) = extract_udf_expression(ast) { let bind_result = self.bind_expr(expr); - // Clean the `udf_context` after inlining, + // Clean the `udf_context` (parameter mapping & global counter) after inlining, // which makes sure the subsequent binding will not be affected if clean_flag { self.udf_context.clear(); diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index cacd2d80dcfe..1b3dcd5dd051 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -382,7 +382,7 @@ impl Binder { // Note: This is specific to anonymous sql udf, since the // parameters will be parsed and treated as `Parameter`. // For detailed explanation, consider checking `bind_column`. - if let Some(expr) = self.udf_context.get(&format!("${index}")) { + if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) { return self.bind_expr(expr.clone()); } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 6ba891aa6b51..51b53d23a2e3 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -116,9 +116,53 @@ pub struct Binder { param_types: ParameterTypes, - /// The mapping from sql udf parameters to ast expressions + /// The sql udf context that will be used during the binding phase + udf_context: UdfContext, +} + +#[derive(Clone, Debug, Default)] +pub struct UdfContext { + /// The mapping from `sql udf parameters` to `ast expressions` /// Note: The expressions are constructed during runtime, correspond to the actual users' input - udf_context: HashMap, + udf_param_context: HashMap, + + /// The global counter that records the calling stack depth + /// of the sql udf chain currently being bound + udf_global_counter: u32, +} + +impl UdfContext { + pub fn new() -> Self { + Self { + udf_param_context: HashMap::new(), + udf_global_counter: 0, + } + } 
+ + pub fn global_count(&self) -> u32 { + self.udf_global_counter + } + + pub fn incr_global_count(&mut self) { + self.udf_global_counter += 1; + } + + pub fn is_empty(&self) -> bool { + self.udf_param_context.is_empty() + } + + pub fn update_context(&mut self, context: HashMap) { + self.udf_param_context = context; + } + + pub fn clear(&mut self) { + self.udf_global_counter = 0; + self.udf_param_context.clear(); + } + + pub fn get_expr(&self, name: &str) -> Option<&AstExpr> { + self.udf_param_context.get(name) + } } /// `ParameterTypes` is used to record the types of the parameters during binding. It works @@ -220,7 +264,7 @@ impl Binder { shared_views: HashMap::new(), included_relations: HashSet::new(), param_types: ParameterTypes::new(param_types), - udf_context: HashMap::new(), + udf_context: UdfContext::new(), } } diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index 834e0bec3135..bbe504d779bf 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -72,15 +72,6 @@ pub async fn handle_create_sql_function( } }; - // We do NOT allow recursive calling inside sql udf - // Since there does not exist the base case for this definition - if body.contains(format!("{}(", name.real_value()).as_str()) { - return Err(ErrorCode::InvalidInputSyntax( - "recursive definition is forbidden, please recheck your function syntax".to_string(), - ) - .into()); - } - // Sanity check for link, this must be none with sql udf function if let Some(CreateFunctionUsing::Link(_)) = params.using { return Err(ErrorCode::InvalidParameterValue(