Skip to content

Commit

Permalink
fix(sql-udf): correctly binding arguments for inner calling sql udfs
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh committed Jan 12, 2024
1 parent 1afb0ec commit fbe0a41
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 30 deletions.
35 changes: 35 additions & 0 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ select fib(100);
statement error function fib calling stack depth limit exceeded
create materialized view foo_mv as select fib(100);

statement ok
create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$;

statement ok
create function print(INT) returns int language sql as 'select $1';

# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter
statement ok
create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';

statement ok
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';

# Call the defined sql udf
query I
select add(1, -1);
Expand Down Expand Up @@ -100,6 +113,11 @@ select call_regexp_replace();
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query T
select regexp_replace_wrapper('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥');
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query T
select foo(114514);
----
Expand All @@ -118,6 +136,11 @@ select add_sub_wrapper(1, 1);
----
114514

query III
select print_add_one(1), print_add_one(114513), print_add_two(2);
----
2 114514 4

# Create a mock table
statement ok
create table t1 (c1 INT, c2 INT);
Expand Down Expand Up @@ -251,6 +274,18 @@ drop function add_sub_types;
statement ok
drop function fib;

statement ok
drop function print;

statement ok
drop function print_add_one;

statement ok
drop function print_add_two;

statement ok
drop function regexp_replace_wrapper;

# Drop the mock table
statement ok
drop table t1;
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl Binder {
// If so, we will treat this bind as an special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if let Some(expr) = self.udf_context.get_expr(&column_name) {
return self.bind_expr(expr.clone());
return Ok(expr.clone());
}

match self
Expand Down
66 changes: 43 additions & 23 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,31 +258,47 @@ impl Binder {
// We can safely unwrap here
let ast = parse_result.unwrap();

let mut clean_flag = true;

// We need to check if the `udf_context` is empty first, consider the following example:
// - create function add(INT, INT) returns int language sql as 'select $1 + $2';
// - create function add_wrapper(INT, INT) returns int language sql as 'select add($1, $2)';
// - select add_wrapper(1, 1);
// When binding `add($1, $2)` in `add_wrapper`, the input args are [$1, $2] instead of
// the original [1, 1], thus we need to check `udf_context` to see if the input
// args already exist in the context. If so, we do NOT need to create the context again.
// Otherwise the current `udf_context` will be corrupted.
let mut clean_flag = false;

// We need to check if the `udf_context` is empty first,
// If so, we will clear the `udf_context` after binding.
// Since this is the root (top-most) binding for sql udf.
// Otherwise we need to restore the context later, or the
// original `udf_context` will be corrupted.
if self.udf_context.is_empty() {
// The actual inline logic for sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
self.udf_context.update_context(context);
} else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
.into());
clean_flag = true;
}

// Stash the current `udf_context`
let stashed_udf_context = self.udf_context.get_context();

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
let mut udf_context = HashMap::new();
for (c, e) in context {
// Note that we need to bind the args before actual delve in the function body
// This will update the context in the subsequent inner calling function
// e.g.,
// - create function print(INT) returns int language sql as 'select $1';
// - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';
// - select print_add_one(1); # The result should be 2 instead of 1.
// Without the pre-binding here, the ($1 + 1) will not be correctly populated,
// causing the result to always be 1.
let Ok(e) = self.bind_expr(e) else {
return Err(ErrorCode::BindError(format!(
"failed to bind the argument, please recheck the syntax"
))
.into());
};
udf_context.insert(c, e);
}
self.udf_context.update_context(udf_context);
} else {
// If the `udf_context` is not empty, this means the current binding
// function is not the root binding sql udf, thus we should NOT
// clean the context after binding.
clean_flag = false;
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
.into());
}

// Check for potential recursive calling
Expand All @@ -299,10 +315,14 @@ impl Binder {

if let Ok(expr) = extract_udf_expression(ast) {
let bind_result = self.bind_expr(expr);
// Clean the `udf_context` & `udf_recursive_context` after inlining,
// Clean the `udf_context` after inlining,
// which makes sure the subsequent binding will not be affected
if clean_flag {
self.udf_context.clear();
} else {
// Restore context information for subsequent binding
// Since this is not the root binding sql udf
self.udf_context.update_context(stashed_udf_context);
}
return bind_result;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ impl Binder {
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get_expr(&format!("${index}")) {
return self.bind_expr(expr.clone());
return Ok(expr.clone());
}

Ok(Parameter::new(index, self.param_types.clone()).into())
Expand Down
15 changes: 10 additions & 5 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::error::Result;
use risingwave_common::session_config::{ConfigMap, SearchPath};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::{Expr as AstExpr, Statement};
use risingwave_sqlparser::ast::Statement;

mod bind_context;
mod bind_param;
Expand Down Expand Up @@ -59,6 +59,7 @@ pub use values::BoundValues;
use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::expr::ExprImpl;
use crate::session::{AuthContext, SessionImpl};

pub type ShareId = usize;
Expand Down Expand Up @@ -122,9 +123,9 @@ pub struct Binder {

#[derive(Clone, Debug, Default)]
pub struct UdfContext {
/// The mapping from `sql udf parameters` to `ast expressions`
/// The mapping from `sql udf parameters` to a bound `ExprImpl` generated from `ast expressions`
/// Note: The expressions are constructed during runtime, correspond to the actual users' input
udf_param_context: HashMap<String, AstExpr>,
udf_param_context: HashMap<String, ExprImpl>,

/// The global counter that records the calling stack depth
/// of the current binding sql udf chain
Expand All @@ -151,7 +152,7 @@ impl UdfContext {
self.udf_param_context.is_empty()
}

pub fn update_context(&mut self, context: HashMap<String, AstExpr>) {
pub fn update_context(&mut self, context: HashMap<String, ExprImpl>) {
self.udf_param_context = context;
}

Expand All @@ -160,9 +161,13 @@ impl UdfContext {
self.udf_param_context.clear();
}

pub fn get_expr(&self, name: &str) -> Option<&AstExpr> {
pub fn get_expr(&self, name: &str) -> Option<&ExprImpl> {
self.udf_param_context.get(name)
}

pub fn get_context(&self) -> HashMap<String, ExprImpl> {
self.udf_param_context.clone()
}
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down

0 comments on commit fbe0a41

Please sign in to comment.