diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index 717a1cf23a829..4b4d1f1c39f7d 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -28,12 +28,12 @@ create function add_return(INT, INT) returns int language sql return $1 + $2; statement ok create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); -# Recursive definition can be accepted, but will be eventually rejected during runtime -statement ok +# Recursive definition can NOT be accepted at present due to semantic check +statement error failed to conduct semantic check, please see if you are calling non-existence functions create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)'; # Complex but error-prone definition, recursive & normal sql udfs interleaving -statement ok +statement error failed to conduct semantic check, please see if you are calling non-existence functions create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)'; # Recursive corner case @@ -46,7 +46,7 @@ create function add_sub_wrapper(INT, INT) returns int language sql as 'select ad # Create a valid recursive function # Please note we do NOT support actual running the recursive sql udf at present -statement ok +statement error failed to conduct semantic check, please see if you are calling non-existence functions create function fib(INT) returns int language sql as 'select case when $1 = 0 then 0 @@ -57,12 +57,12 @@ create function fib(INT) returns int end;'; # The execution will eventually exceed the pre-defined max stack depth -statement error function fib calling stack depth limit exceeded -select fib(100); +# statement error function fib calling stack depth limit exceeded +# select fib(100); # Currently create a materialized view with a recursive sql udf will be rejected -statement error function fib calling stack depth limit exceeded -create materialized view foo_mv as select fib(100); +# statement error function fib calling stack depth limit exceeded +# create materialized view foo_mv as select fib(100); statement ok create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$; @@ -77,6 +77,10 @@ create function print_add_one(INT) returns int language sql as 'select print($1 statement ok create function print_add_two(INT) returns int language sql as 'select print($1 + $1)'; +# Calling a non-existence function +statement error failed to conduct semantic check, please see if you are calling non-existence functions +create function non_exist(INT) returns int language sql as 'select yo(114514)'; + # Call the defined sql udf query I select add(1, -1); @@ -124,12 +128,12 @@ select foo(114514); foo(INT) # Rejected deep calling stack -statement error function recursive calling stack depth limit exceeded -select recursive(1, 1); +# statement error function recursive calling stack depth limit exceeded +# select recursive(1, 1); # Same as above -statement error function recursive calling stack depth limit exceeded -select recursive_non_recursive(1, 1); +# statement error function recursive calling stack depth limit exceeded +# select recursive_non_recursive(1, 1); query I select add_sub_wrapper(1, 1); @@ -168,12 +172,12 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc; 5 5 10 # Recursive sql udf with normal table -statement error function fib calling stack depth limit exceeded -select fib(c1) from t1; +# statement error function fib calling stack depth limit exceeded +# select fib(c1) from t1; # Recursive sql udf with materialized view -statement error function fib calling stack depth limit exceeded -create materialized view bar_mv as select fib(c1) from t1; +# statement error function fib calling stack depth limit exceeded +# create materialized view bar_mv as select fib(c1) from t1; # Invalid function body syntax statement error Expected an expression:, found: EOF at the end @@ -259,20 +263,20 @@ drop function call_regexp_replace; statement ok drop function add_sub_wrapper; -statement ok -drop function recursive; +# statement ok +# drop function recursive; statement ok drop function foo; -statement ok -drop function recursive_non_recursive; +# statement ok +# drop function recursive_non_recursive; statement ok drop function add_sub_types; -statement ok -drop function fib; +# statement ok +# drop function fib; statement ok drop function print; diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 772b1f7d0bd1d..bf090a87a7514 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -415,6 +415,10 @@ impl Binder { pub fn set_clause(&mut self, clause: Option) { self.context.clause = clause; } + + pub fn udf_context_mut(&mut self) -> &mut UdfContext { + &mut self.udf_context + } } /// The column name stored in [`BindContext`] for a column without an alias. diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index c0f80844351a9..2af3f5d9291b6 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use itertools::Itertools; use pgwire::pg_response::StatementType; use risingwave_common::catalog::FunctionId; @@ -25,8 +27,60 @@ use risingwave_sqlparser::parser::{Parser, ParserError}; use super::*; use crate::catalog::CatalogError; +use crate::expr::{ExprImpl, Literal}; use crate::{bind_data_type, Binder}; +/// Create a mock `udf_context`, which is used for semantic check +fn create_mock_udf_context(arg_types: Vec) -> HashMap { + (1..=arg_types.len()) + .map(|i| { + let mock_expr = + ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone()))); + (format!("${i}"), mock_expr.clone()) + }) + .collect() +} + +fn extract_udf_expression(ast: Vec) -> Result { + if ast.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "the query for sql udf should contain only one statement".to_string(), + ) + .into()); + } + + // Extract the expression out + let Statement::Query(query) = ast[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "invalid function definition, please recheck the syntax".to_string(), + ) + .into()); + }; + + let SetExpr::Select(select) = query.body else { + return Err(ErrorCode::InvalidInputSyntax( + "missing `select` body for sql udf expression, please recheck the syntax".to_string(), + ) + .into()); + }; + + if select.projection.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "`projection` should contain only one `SelectItem`".to_string(), + ) + .into()); + } + + let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "expect `UnnamedExpr` for `projection`".to_string(), + ) + .into()); + }; + + Ok(expr) +} + pub async fn handle_create_sql_function( handler_args: HandlerArgs, or_replace: bool, @@ -45,7 +99,8 @@ pub async fn handle_create_sql_function( } let language = "sql".to_string(); - // Just a basic sanity check for language + + // Just a basic sanity check for `language` if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") { return Err(ErrorCode::InvalidParameterValue( "`language` for sql udf must be `sql`".to_string(), @@ -149,6 +204,30 @@ pub async fn handle_create_sql_function( return Err(ErrorCode::InvalidInputSyntax(err).into()); } else { debug_assert!(parse_result.is_ok()); + + // Conduct semantic check (e.g., see if the inner calling functions exist, etc.) + let ast = parse_result.unwrap(); + let mut binder = Binder::new_for_system(session); + + binder + .udf_context_mut() + .update_context(create_mock_udf_context(arg_types.clone())); + + if let Ok(expr) = extract_udf_expression(ast) { + if let Err(e) = binder.bind_expr(expr) { + return Err(ErrorCode::InvalidInputSyntax( + format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}") + ) + .into()); + } + } else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to parse the input query and extract the udf expression, + please recheck the syntax" + .to_string(), + ) + .into()); + } } // Create the actual function, will be stored in function catalog