Skip to content

Commit

Permalink
feat(sql-udf): deep calling stack (recursion) prevention for sql udf (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Jan 12, 2024
1 parent 7aabd3b commit 1afb0ec
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 18 deletions.
65 changes: 63 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,42 @@ create function add_return(INT, INT) returns int language sql return $1 + $2;
statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

# Recursive definition is forbidden
statement error recursive definition is forbidden, please recheck your function syntax
# A recursive definition is accepted at creation time, but will eventually be rejected at runtime
statement ok
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';

# Complex but error-prone definition, recursive & normal sql udfs interleaving
statement ok
create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)';

# Recursive corner case
statement ok
create function foo(INT) returns varchar language sql as $$select 'foo(INT)'$$;

# Create a wrapper function for `add` & `sub`
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

# Create a valid recursive function
# Please note that we do NOT support actually running a recursive sql udf at present
statement ok
create function fib(INT) returns int
language sql as 'select case
when $1 = 0 then 0
when $1 = 1 then 1
when $1 = 2 then 1
when $1 = 3 then 2
else fib($1 - 1) + fib($1 - 2)
end;';

# The execution will eventually exceed the pre-defined max stack depth
statement error function fib calling stack depth limit exceeded
select fib(100);

# Currently, creating a materialized view with a recursive sql udf will be rejected
statement error function fib calling stack depth limit exceeded
create materialized view foo_mv as select fib(100);

# Call the defined sql udf
query I
select add(1, -1);
Expand Down Expand Up @@ -72,6 +100,19 @@ select call_regexp_replace();
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query T
select foo(114514);
----
foo(INT)

# Rejected deep calling stack
statement error function recursive calling stack depth limit exceeded
select recursive(1, 1);

# Same as above
statement error function recursive calling stack depth limit exceeded
select recursive_non_recursive(1, 1);

query I
select add_sub_wrapper(1, 1);
----
Expand Down Expand Up @@ -103,6 +144,14 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
4 4 8
5 5 10

# Recursive sql udf with normal table
statement error function fib calling stack depth limit exceeded
select fib(c1) from t1;

# Recursive sql udf with materialized view
statement error function fib calling stack depth limit exceeded
create materialized view bar_mv as select fib(c1) from t1;

# Invalid function body syntax
statement error Expected an expression:, found: EOF at the end
create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$;
Expand Down Expand Up @@ -187,9 +236,21 @@ drop function call_regexp_replace;
statement ok
drop function add_sub_wrapper;

statement ok
drop function recursive;

statement ok
drop function foo;

statement ok
drop function recursive_non_recursive;

statement ok
drop function add_sub_types;

statement ok
drop function fib;

# Drop the mock table
statement ok
drop table t1;
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Binder {
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as a special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if let Some(expr) = self.udf_context.get(&column_name) {
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return self.bind_expr(expr.clone());
}

Expand Down
24 changes: 22 additions & 2 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ pub const SYS_FUNCTION_WITHOUT_ARGS: &[&str] = &[
"current_timestamp",
];

/// The maximum calling stack depth allowed for a single chain of sql udf bindings.
///
/// `bind_function` compares the global counter stored in `udf_context` against
/// this limit and returns a bind error once the counter has reached it.
/// To reduce the chance that the currently running rw thread
/// is killed by the os, the allowed depth is set to `16`.
const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16;

impl Binder {
pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
let function_name = match f.name.0.as_slice() {
Expand Down Expand Up @@ -235,6 +241,7 @@ impl Binder {
)
.into());
}

// At this point the current user defined function is known to be `language sql`
let parse_result = risingwave_sqlparser::parser::Parser::parse_sql(
func.body.as_ref().unwrap().as_str(),
Expand All @@ -245,6 +252,7 @@ impl Binder {
// Here we just return the original parse error message
return Err(ErrorCode::InvalidInputSyntax(err).into());
}

debug_assert!(parse_result.is_ok());

// We can safely unwrap here
Expand All @@ -263,7 +271,7 @@ impl Binder {
if self.udf_context.is_empty() {
// The actual inline logic for sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
self.udf_context = context;
self.udf_context.update_context(context);
} else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
Expand All @@ -277,9 +285,21 @@ impl Binder {
clean_flag = false;
}

// Check for potential recursive calling
if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH {
return Err(ErrorCode::BindError(format!(
"function {} calling stack depth limit exceeded",
&function_name
))
.into());
} else {
// Update the status for the global counter
self.udf_context.incr_global_count();
}

if let Ok(expr) = extract_udf_expression(ast) {
let bind_result = self.bind_expr(expr);
// Clean the `udf_context` after inlining,
// Clean the `udf_context` (parameter mapping & calling depth counter) after inlining,
// which makes sure the subsequent binding will not be affected
if clean_flag {
self.udf_context.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ impl Binder {
// Note: This is specific to anonymous sql udf, since the
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get(&format!("${index}")) {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return self.bind_expr(expr.clone());
}

Expand Down
50 changes: 47 additions & 3 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,53 @@ pub struct Binder {

param_types: ParameterTypes,

/// The mapping from sql udf parameters to ast expressions
/// The sql udf context that will be used during binding phase
udf_context: UdfContext,
}

#[derive(Clone, Debug, Default)]
pub struct UdfContext {
/// The mapping from `sql udf parameters` to `ast expressions`
/// Note: The expressions are constructed during runtime, correspond to the actual users' input
udf_context: HashMap<String, AstExpr>,
udf_param_context: HashMap<String, AstExpr>,

/// The global counter that records the calling stack depth
/// of the current binding sql udf chain
udf_global_counter: u32,
}

impl UdfContext {
    /// Creates an empty context: no parameter bindings and a calling depth of `0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when no sql udf parameter is currently bound.
    pub fn is_empty(&self) -> bool {
        self.udf_param_context.is_empty()
    }

    /// Looks up the ast expression bound to the parameter called `name`,
    /// or `None` if there is no such parameter.
    pub fn get_expr(&self, name: &str) -> Option<&AstExpr> {
        self.udf_param_context.get(name)
    }

    /// Replaces the whole parameter mapping with `context`.
    /// The calling depth counter is left untouched.
    pub fn update_context(&mut self, context: HashMap<String, AstExpr>) {
        self.udf_param_context = context;
    }

    /// Returns the calling stack depth recorded so far.
    pub fn global_count(&self) -> u32 {
        self.udf_global_counter
    }

    /// Records one more level of sql udf calling depth.
    pub fn incr_global_count(&mut self) {
        self.udf_global_counter += 1;
    }

    /// Resets the context: drops every parameter binding and
    /// sets the calling depth counter back to `0`.
    pub fn clear(&mut self) {
        self.udf_param_context.clear();
        self.udf_global_counter = 0;
    }
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down Expand Up @@ -220,7 +264,7 @@ impl Binder {
shared_views: HashMap::new(),
included_relations: HashSet::new(),
param_types: ParameterTypes::new(param_types),
udf_context: HashMap::new(),
udf_context: UdfContext::new(),
}
}

Expand Down
9 changes: 0 additions & 9 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,6 @@ pub async fn handle_create_sql_function(
}
};

// We do NOT allow recursive calling inside sql udf
// Since there does not exist the base case for this definition
if body.contains(format!("{}(", name.real_value()).as_str()) {
return Err(ErrorCode::InvalidInputSyntax(
"recursive definition is forbidden, please recheck your function syntax".to_string(),
)
.into());
}

// Sanity check for link, this must be none with sql udf function
if let Some(CreateFunctionUsing::Link(_)) = params.using {
return Err(ErrorCode::InvalidParameterValue(
Expand Down

0 comments on commit 1afb0ec

Please sign in to comment.