Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql-udf): correctly binding arguments for inner calling sql udfs #14422

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ create function add_sub_binding() returns int language sql as 'select add(1, 1)
statement ok
create function call_regexp_replace() returns varchar language sql as $$select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')$$;

statement ok
create function regexp_replace_wrapper(varchar) returns varchar language sql as $$select regexp_replace($1, 'baz(...)', '这是🥵', 'ic')$$;

statement error Expected end of statement, found: 💩
create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')';

Expand All @@ -36,6 +39,16 @@ create function recursive(INT, INT) returns int language sql as 'select recursiv
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

statement ok
create function print(INT) returns int language sql as 'select $1';

# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter
statement ok
create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';

statement ok
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';

# Call the defined sql udf
query I
select add(1, -1);
Expand Down Expand Up @@ -72,11 +85,21 @@ select call_regexp_replace();
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query T
select regexp_replace_wrapper('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥');
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query I
select add_sub_wrapper(1, 1);
----
114514

query III
select print_add_one(1), print_add_one(114513), print_add_two(2);
----
2 114514 4

# Create a mock table
statement ok
create table t1 (c1 INT, c2 INT);
Expand Down Expand Up @@ -187,6 +210,18 @@ drop function call_regexp_replace;
statement ok
drop function add_sub_wrapper;

statement ok
drop function print;

statement ok
drop function print_add_one;

statement ok
drop function print_add_two;

statement ok
drop function regexp_replace_wrapper;

statement ok
drop function add_sub_types;

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl Binder {
// If so, we will treat this bind as an special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if let Some(expr) = self.udf_context.get(&column_name) {
return self.bind_expr(expr.clone());
return Ok(expr.clone());
}

match self
Expand Down
62 changes: 41 additions & 21 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,36 +245,53 @@ impl Binder {
// Here we just return the original parse error message
return Err(ErrorCode::InvalidInputSyntax(err).into());
}

debug_assert!(parse_result.is_ok());

// We can safely unwrap here
let ast = parse_result.unwrap();

let mut clean_flag = true;
let mut clean_flag = false;

// We need to check if the `udf_context` is empty first, consider the following example:
// - create function add(INT, INT) returns int language sql as 'select $1 + $2';
// - create function add_wrapper(INT, INT) returns int language sql as 'select add($1, $2)';
// - select add_wrapper(1, 1);
// When binding `add($1, $2)` in `add_wrapper`, the input args are [$1, $2] instead of
// the original [1, 1], thus we need to check `udf_context` to see if the input
// args already exist in the context. If so, we do NOT need to create the context again.
// Otherwise the current `udf_context` will be corrupted.
// We need to check if the `udf_context` is empty first,
// If so, we will clear the `udf_context` after binding.
// Since this is the root (top-most) binding for sql udf.
// Otherwise we need to restore the context later, or the
// original `udf_context` will be corrupted.
if self.udf_context.is_empty() {
// The actual inline logic for sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
self.udf_context = context;
} else {
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
.into());
clean_flag = true;
}

// Stash the current `udf_context`
let prev_context = self.udf_context.clone();

// The actual inline logic for sql udf
// Note that we will always create new udf context for each sql udf
if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) {
let mut udf_context = HashMap::new();
for (c, e) in context {
// Note that we need to bind the args before actual delve in the function body
// This will update the context in the subsequent inner calling function
// e.g.,
// - create function print(INT) returns int language sql as 'select $1';
// - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';
// - select print_add_one(1); # The result should be 2 instead of 1.
// Without the pre-binding here, the ($1 + 1) will not be correctly populated,
// causing the result to always be 1.
let Ok(e) = self.bind_expr(e) else {
return Err(ErrorCode::BindError(format!(
"failed to bind the argument, please recheck your syntax"
))
.into());
};
udf_context.insert(c, e);
}
self.udf_context = udf_context;
} else {
// If the `udf_context` is not empty, this means the current binding
// function is not the root binding sql udf, thus we should NOT
// clean the context after binding.
clean_flag = false;
return Err(ErrorCode::InvalidInputSyntax(
"failed to create the `udf_context`, please recheck your function definition and syntax".to_string()
)
.into());
}

if let Ok(expr) = extract_udf_expression(ast) {
Expand All @@ -283,6 +300,9 @@ impl Binder {
// which makes sure the subsequent binding will not be affected
if clean_flag {
self.udf_context.clear();
} else {
// Restore context information for subsequent binding
self.udf_context = prev_context;
}
return bind_result;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ impl Binder {
// parameters will be parsed and treated as `Parameter`.
// For detailed explanation, consider checking `bind_column`.
if let Some(expr) = self.udf_context.get(&format!("${index}")) {
return self.bind_expr(expr.clone());
return Ok(expr.clone());
}

Ok(Parameter::new(index, self.param_types.clone()).into())
Expand Down
7 changes: 4 additions & 3 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::error::Result;
use risingwave_common::session_config::{ConfigMap, SearchPath};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::{Expr as AstExpr, Statement};
use risingwave_sqlparser::ast::Statement;

mod bind_context;
mod bind_param;
Expand Down Expand Up @@ -59,6 +59,7 @@ pub use values::BoundValues;
use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::expr::ExprImpl;
use crate::session::{AuthContext, SessionImpl};

pub type ShareId = usize;
Expand Down Expand Up @@ -116,9 +117,9 @@ pub struct Binder {

param_types: ParameterTypes,

/// The mapping from sql udf parameters to ast expressions
/// The mapping from `sql udf parameters` to `ExprImpl` generated from ast expressions
/// Note: The expressions are constructed during runtime, correspond to the actual users' input
udf_context: HashMap<String, AstExpr>,
udf_context: HashMap<String, ExprImpl>,
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down
Loading