Skip to content

Commit

Permalink
feat(sql-udf): add semantic check when creating sql udf (risingwavela…
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Jan 22, 2024
1 parent 6414832 commit 657b8ab
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 23 deletions.
48 changes: 26 additions & 22 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ create function add_return(INT, INT) returns int language sql return $1 + $2;
statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

# Recursive definition can be accepted, but will be eventually rejected during runtime
statement ok
# Recursive definition can NOT be accepted at present due to semantic check
statement error failed to conduct semantic check, please see if you are calling non-existence functions
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';

# Complex but error-prone definition, recursive & normal sql udfs interleaving
statement ok
statement error failed to conduct semantic check, please see if you are calling non-existence functions
create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)';

# Recursive corner case
Expand All @@ -46,7 +46,7 @@ create function add_sub_wrapper(INT, INT) returns int language sql as 'select ad

# Create a valid recursive function
# Please note we do NOT support actual running the recursive sql udf at present
statement ok
statement error failed to conduct semantic check, please see if you are calling non-existence functions
create function fib(INT) returns int
language sql as 'select case
when $1 = 0 then 0
Expand All @@ -57,12 +57,12 @@ create function fib(INT) returns int
end;';

# The execution will eventually exceed the pre-defined max stack depth
statement error function fib calling stack depth limit exceeded
select fib(100);
# statement error function fib calling stack depth limit exceeded
# select fib(100);

# Currently create a materialized view with a recursive sql udf will be rejected
statement error function fib calling stack depth limit exceeded
create materialized view foo_mv as select fib(100);
# statement error function fib calling stack depth limit exceeded
# create materialized view foo_mv as select fib(100);

statement ok
create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$;
Expand All @@ -77,6 +77,10 @@ create function print_add_one(INT) returns int language sql as 'select print($1
statement ok
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';

# Calling a non-existence function
statement error failed to conduct semantic check, please see if you are calling non-existence functions
create function non_exist(INT) returns int language sql as 'select yo(114514)';

# Call the defined sql udf
query I
select add(1, -1);
Expand Down Expand Up @@ -124,12 +128,12 @@ select foo(114514);
foo(INT)

# Rejected deep calling stack
statement error function recursive calling stack depth limit exceeded
select recursive(1, 1);
# statement error function recursive calling stack depth limit exceeded
# select recursive(1, 1);

# Same as above
statement error function recursive calling stack depth limit exceeded
select recursive_non_recursive(1, 1);
# statement error function recursive calling stack depth limit exceeded
# select recursive_non_recursive(1, 1);

query I
select add_sub_wrapper(1, 1);
Expand Down Expand Up @@ -168,12 +172,12 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
5 5 10

# Recursive sql udf with normal table
statement error function fib calling stack depth limit exceeded
select fib(c1) from t1;
# statement error function fib calling stack depth limit exceeded
# select fib(c1) from t1;

# Recursive sql udf with materialized view
statement error function fib calling stack depth limit exceeded
create materialized view bar_mv as select fib(c1) from t1;
# statement error function fib calling stack depth limit exceeded
# create materialized view bar_mv as select fib(c1) from t1;

# Invalid function body syntax
statement error Expected an expression:, found: EOF at the end
Expand Down Expand Up @@ -259,20 +263,20 @@ drop function call_regexp_replace;
statement ok
drop function add_sub_wrapper;

statement ok
drop function recursive;
# statement ok
# drop function recursive;

statement ok
drop function foo;

statement ok
drop function recursive_non_recursive;
# statement ok
# drop function recursive_non_recursive;

statement ok
drop function add_sub_types;

statement ok
drop function fib;
# statement ok
# drop function fib;

statement ok
drop function print;
Expand Down
4 changes: 4 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ impl Binder {
pub fn set_clause(&mut self, clause: Option<Clause>) {
self.context.clause = clause;
}

pub fn udf_context_mut(&mut self) -> &mut UdfContext {
&mut self.udf_context
}
}

/// The column name stored in [`BindContext`] for a column without an alias.
Expand Down
81 changes: 80 additions & 1 deletion src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use itertools::Itertools;
use pgwire::pg_response::StatementType;
use risingwave_common::catalog::FunctionId;
Expand All @@ -25,8 +27,60 @@ use risingwave_sqlparser::parser::{Parser, ParserError};

use super::*;
use crate::catalog::CatalogError;
use crate::expr::{ExprImpl, Literal};
use crate::{bind_data_type, Binder};

/// Create a mock `udf_context`, which is used for semantic check
fn create_mock_udf_context(arg_types: Vec<DataType>) -> HashMap<String, ExprImpl> {
(1..=arg_types.len())
.map(|i| {
let mock_expr =
ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone())));
(format!("${i}"), mock_expr.clone())
})
.collect()
}

fn extract_udf_expression(ast: Vec<Statement>) -> Result<Expr> {
if ast.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"the query for sql udf should contain only one statement".to_string(),
)
.into());
}

// Extract the expression out
let Statement::Query(query) = ast[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"invalid function definition, please recheck the syntax".to_string(),
)
.into());
};

let SetExpr::Select(select) = query.body else {
return Err(ErrorCode::InvalidInputSyntax(
"missing `select` body for sql udf expression, please recheck the syntax".to_string(),
)
.into());
};

if select.projection.len() != 1 {
return Err(ErrorCode::InvalidInputSyntax(
"`projection` should contain only one `SelectItem`".to_string(),
)
.into());
}

let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else {
return Err(ErrorCode::InvalidInputSyntax(
"expect `UnnamedExpr` for `projection`".to_string(),
)
.into());
};

Ok(expr)
}

pub async fn handle_create_sql_function(
handler_args: HandlerArgs,
or_replace: bool,
Expand All @@ -45,7 +99,8 @@ pub async fn handle_create_sql_function(
}

let language = "sql".to_string();
// Just a basic sanity check for language

// Just a basic sanity check for `language`
if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") {
return Err(ErrorCode::InvalidParameterValue(
"`language` for sql udf must be `sql`".to_string(),
Expand Down Expand Up @@ -149,6 +204,30 @@ pub async fn handle_create_sql_function(
return Err(ErrorCode::InvalidInputSyntax(err).into());
} else {
debug_assert!(parse_result.is_ok());

// Conduct semantic check (e.g., see if the inner calling functions exist, etc.)
let ast = parse_result.unwrap();
let mut binder = Binder::new_for_system(session);

binder
.udf_context_mut()
.update_context(create_mock_udf_context(arg_types.clone()));

if let Ok(expr) = extract_udf_expression(ast) {
if let Err(e) = binder.bind_expr(expr) {
return Err(ErrorCode::InvalidInputSyntax(
format!("failed to conduct semantic check, please see if you are calling non-existence functions.\nDetailed error: {e}")
)
.into());
}
} else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to parse the input query and extract the udf expression,
please recheck the syntax"
.to_string(),
)
.into());
}
}

// Create the actual function, will be stored in function catalog
Expand Down

0 comments on commit 657b8ab

Please sign in to comment.