Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql-udf): deep calling stack (recursion) prevention for sql udf #14392

Merged
merged 12 commits into from
Jan 12, 2024
65 changes: 63 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,42 @@ create function add_return(INT, INT) returns int language sql return $1 + $2;
statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

# Recursive definition is forbidden
statement error recursive definition is forbidden, please recheck your function syntax
# A recursive definition is accepted at creation time, but will eventually be rejected at runtime
statement ok
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';

# Complex but error-prone definition, recursive & normal sql udfs interleaving
statement ok
create function recursive_non_recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + sub($1, $2)';

# Recursive corner case
statement ok
create function foo(INT) returns varchar language sql as $$select 'foo(INT)'$$;

# Create a wrapper function for `add` & `sub`
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

# Create a valid recursive function
# Please note that actually running a recursive sql udf is NOT supported at present
statement ok
create function fib(INT) returns int
language sql as 'select case
when $1 = 0 then 0
when $1 = 1 then 1
when $1 = 2 then 1
when $1 = 3 then 2
else fib($1 - 1) + fib($1 - 2)
end;';

# The execution will eventually exceed the pre-defined max stack depth
statement error function fib calling stack depth limit exceeded
select fib(100);

# Currently, creating a materialized view with a recursive sql udf will be rejected
statement error function fib calling stack depth limit exceeded
create materialized view foo_mv as select fib(100);

# Call the defined sql udf
query I
select add(1, -1);
Expand Down Expand Up @@ -72,6 +100,19 @@ select call_regexp_replace();
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query T
select foo(114514);
----
foo(INT)

# Rejected deep calling stack
statement error function recursive calling stack depth limit exceeded
select recursive(1, 1);

# Same as above
statement error function recursive calling stack depth limit exceeded
select recursive_non_recursive(1, 1);
xzhseh marked this conversation as resolved.
Show resolved Hide resolved

query I
select add_sub_wrapper(1, 1);
----
Expand Down Expand Up @@ -103,6 +144,14 @@ select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
4 4 8
5 5 10

# Recursive sql udf with normal table
statement error function fib calling stack depth limit exceeded
select fib(c1) from t1;

# Recursive sql udf with materialized view
statement error function fib calling stack depth limit exceeded
create materialized view bar_mv as select fib(c1) from t1;

# Invalid function body syntax
statement error Expected an expression:, found: EOF at the end
create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$;
Expand Down Expand Up @@ -187,9 +236,21 @@ drop function call_regexp_replace;
statement ok
drop function add_sub_wrapper;

statement ok
drop function recursive;

statement ok
drop function foo;

statement ok
drop function recursive_non_recursive;

statement ok
drop function add_sub_types;

statement ok
drop function fib;

# Drop the mock table
statement ok
drop table t1;
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Binder {
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as a special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if let Some(expr) = self.udf_context.get(&column_name) {
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return self.bind_expr(expr.clone());
}

Expand Down
24 changes: 22 additions & 2 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ pub const SYS_FUNCTION_WITHOUT_ARGS: &[&str] = &[
"current_timestamp",
];

/// The maximum calling stack depth allowed for a chain of sql udf bindings,
/// checked against the global counter in `udf_context`.
/// To reduce the chance that the currently running rw thread
/// is killed by the os, the allowed calling stack depth
/// is currently set to `16`.
const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16;

impl Binder {
pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
let function_name = match f.name.0.as_slice() {
Expand Down Expand Up @@ -235,6 +241,7 @@ impl Binder {
)
.into());
}

// This represents the current user defined function is `language sql`
let parse_result = risingwave_sqlparser::parser::Parser::parse_sql(
func.body.as_ref().unwrap().as_str(),
Expand All @@ -245,6 +252,7 @@ impl Binder {
// Here we just return the original parse error message
return Err(ErrorCode::InvalidInputSyntax(err).into());
}

debug_assert!(parse_result.is_ok());

// We can safely unwrap here
Expand All @@ -263,7 +271,7 @@ impl Binder {
if self.udf_context.is_empty() {
// The actual inline logic for sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
self.udf_context = context;
self.udf_context.update_context(context);
} else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
Expand All @@ -277,9 +285,21 @@ impl Binder {
clean_flag = false;
}

// Check for potential recursive calling
if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH {
return Err(ErrorCode::BindError(format!(
"function {} calling stack depth limit exceeded",
&function_name
))
.into());
} else {
// Update the status for the global counter
self.udf_context.incr_global_count();
Comment on lines +296 to +297
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems not right that this counter never decreases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The counter will be reset to 0 after successfully binding a specific sql udf (i.e., when cleaning the udf_context).

// Clean the `udf_context` & `udf_recursive_context` after inlining,
// which makes sure the subsequent binding will not be affected
if clean_flag {
    self.udf_context.clear();
}

}

if let Ok(expr) = extract_udf_expression(ast) {
let bind_result = self.bind_expr(expr);
// Clean the `udf_context` after inlining,
// Clean the `udf_context` & `udf_recursive_context` after inlining,
// which makes sure the subsequent binding will not be affected
if clean_flag {
self.udf_context.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ impl Binder {
// Note: This is specific to anonymous sql udf, since the
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get(&format!("${index}")) {
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return self.bind_expr(expr.clone());
}

Expand Down
50 changes: 47 additions & 3 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,53 @@ pub struct Binder {

param_types: ParameterTypes,

/// The mapping from sql udf parameters to ast expressions
/// The sql udf context that will be used during binding phase
udf_context: UdfContext,
}

/// State used while binding sql udfs: the parameter-to-expression mapping
/// plus a counter that tracks the calling stack depth.
#[derive(Clone, Debug)]
pub struct UdfContext {
/// The mapping from `sql udf parameters` to `ast expressions`
/// Note: The expressions are constructed during runtime and correspond to the actual user input
// NOTE(review): this text is a rendered diff; the `udf_context` line below looks like the
// removed pre-change field and `udf_param_context` its replacement (only the latter is
// initialized in `new`) — confirm against the merged file.
udf_context: HashMap<String, AstExpr>,
udf_param_context: HashMap<String, AstExpr>,

/// The global counter that records the calling stack depth
/// of the current binding sql udf chain
udf_global_counter: u32,
}

impl UdfContext {
/// Creates an empty context: no parameter mappings and a depth counter of 0.
pub fn new() -> Self {
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
Self {
udf_param_context: HashMap::new(),
udf_global_counter: 0,
}
}

/// Returns the current value of the calling stack depth counter.
pub fn global_count(&self) -> u32 {
self.udf_global_counter
}

/// Increments the calling stack depth counter by one.
pub fn incr_global_count(&mut self) {
self.udf_global_counter += 1;
}

/// Returns `true` when no parameter mapping is stored.
/// The depth counter is not taken into account.
pub fn is_empty(&self) -> bool {
self.udf_param_context.is_empty()
}

/// Replaces the stored parameter mapping with `context`.
/// The depth counter is left untouched.
pub fn update_context(&mut self, context: HashMap<String, AstExpr>) {
self.udf_param_context = context;
}

/// Resets the depth counter to 0 and removes every parameter mapping.
pub fn clear(&mut self) {
self.udf_global_counter = 0;
self.udf_param_context.clear();
}

/// Looks up the ast expression mapped to the parameter `name`, if any.
pub fn get_expr(&self, name: &str) -> Option<&AstExpr> {
self.udf_param_context.get(name)
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down Expand Up @@ -220,7 +264,7 @@ impl Binder {
shared_views: HashMap::new(),
included_relations: HashSet::new(),
param_types: ParameterTypes::new(param_types),
udf_context: HashMap::new(),
udf_context: UdfContext::new(),
}
}

Expand Down
9 changes: 0 additions & 9 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,6 @@ pub async fn handle_create_sql_function(
}
};

// We do NOT allow recursive calling inside sql udf
// Since there does not exist the base case for this definition
if body.contains(format!("{}(", name.real_value()).as_str()) {
return Err(ErrorCode::InvalidInputSyntax(
"recursive definition is forbidden, please recheck your function syntax".to_string(),
)
.into());
}

// Sanity check for link, this must be none with sql udf function
if let Some(CreateFunctionUsing::Link(_)) = params.using {
return Err(ErrorCode::InvalidParameterValue(
Expand Down
Loading