Skip to content

Commit

Permalink
feat(sql-udf): support named sql udf (#14806)
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhseh authored Jan 30, 2024
1 parent 92bfe50 commit 66adcdd
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 21 deletions.
93 changes: 91 additions & 2 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,35 @@ create function sub(INT, INT) returns int language sql as 'select $1 - $2';
statement ok
create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)';

# Create a named sql udf
statement ok
create function add_named(a INT, b INT) returns int language sql as 'select a + b';

# Create another named sql udf
statement ok
create function sub_named(a INT, b INT) returns int language sql as 'select a - b';

# Mixed parameter with named / anonymous parameters
statement ok
create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3';

# Mixed parameter with calling inner sql udfs
# statement ok
# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

# Named sql udf with corner case
statement ok
create function corner_case(INT, a INT, INT) returns varchar language sql as $$select '$1 + a + $3'$$;

# Named sql udf with invalid parameter in body definition
# Will be rejected at creation time
statement error failed to find named parameter aa
create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a';

# Call anonymous sql udf inside named sql udf
statement ok
create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add(a, b)';

# Create an anonymous function that calls built-in functions
# Note that double dollar signs should be used otherwise the parsing will fail, as illutrates below
statement ok
Expand Down Expand Up @@ -94,7 +123,7 @@ select type_match(114514);
----
$1 + 114514 + $1

# Call the defined sql udf
# Call the defined anonymous sql udfs
query I
select add(1, -1);
----
Expand All @@ -105,6 +134,32 @@ select sub(1, 1);
----
0

# Call the defined named sql udfs
query I
select add_named(1, -1);
----
0

query I
select sub_named(1, 1);
----
0

query I
select add_sub_mix(1, 2, 3);
----
2

query T
select corner_case(1, 2, 3);
----
$1 + a + $3

query I
select add_named_wrapper(1, -1);
----
0

query I
select add_sub_binding();
----
Expand Down Expand Up @@ -158,14 +213,30 @@ select print_add_one(1), print_add_one(114513), print_add_two(2);
----
2 114514 4

# Create a mock table
# Create a mock table for anonymous sql udf
statement ok
create table t1 (c1 INT, c2 INT);

# Create a mock table for named sql udf
statement ok
create table t3 (a INT, b INT);

# Insert some data into the mock table
statement ok
insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

statement ok
insert into t3 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

query I
select add_named(a, b) from t3 order by a asc;
----
2
4
6
8
10

query III
select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc;
----
Expand Down Expand Up @@ -303,6 +374,21 @@ drop function print_add_two;
statement ok
drop function regexp_replace_wrapper;

statement ok
drop function corner_case;

statement ok
drop function add_named;

statement ok
drop function sub_named;

statement ok
drop function add_sub_mix;

statement ok
drop function add_named_wrapper;

statement ok
drop function type_match;

Expand All @@ -312,3 +398,6 @@ drop table t1;

statement ok
drop table t2;

statement ok
drop table t3;
2 changes: 2 additions & 0 deletions src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ impl Binder {
// The reason that we directly return error here,
// is because during a valid sql udf binding,
// there will not exist any column identifiers
// And invalid cases should already be caught
// during semantic check phase
return Err(ErrorCode::BindError(format!(
"failed to find named parameter {column_name}"
))
Expand Down
31 changes: 17 additions & 14 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,26 +219,29 @@ impl UdfContext {
Ok(expr)
}

/// TODO: add name related logic
/// NOTE: need to think of a way to prevent naming conflict
/// e.g., when existing column names conflict with parameter names in sql udf
pub fn create_udf_context(
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
catalog: &Arc<FunctionCatalog>,
) -> Result<HashMap<String, AstExpr>> {
let mut ret: HashMap<String, AstExpr> = HashMap::new();
for (i, current_arg) in args.iter().enumerate() {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
};
// if catalog.arg_names.is_some() {
// todo!()
// }
ret.insert(format!("${}", i + 1), e.clone());
continue;
match current_arg {
FunctionArg::Unnamed(arg) => {
let FunctionArgExpr::Expr(e) = arg else {
return Err(
ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()
);
};
if catalog.arg_names[i].is_empty() {
ret.insert(format!("${}", i + 1), e.clone());
} else {
// The index mapping here is accurate
// So that we could directly use the index
ret.insert(catalog.arg_names[i].clone(), e.clone());
}
}
_ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()),
}
return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into());
}
Ok(ret)
}
Expand Down
27 changes: 22 additions & 5 deletions src/frontend/src/handler/create_sql_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,24 @@ use crate::expr::{Expr, ExprImpl, Literal};
use crate::{bind_data_type, Binder};

/// Create a mock `udf_context`, which is used for semantic check
fn create_mock_udf_context(arg_types: Vec<DataType>) -> HashMap<String, ExprImpl> {
(1..=arg_types.len())
fn create_mock_udf_context(
arg_types: Vec<DataType>,
arg_names: Vec<String>,
) -> HashMap<String, ExprImpl> {
let mut ret: HashMap<String, ExprImpl> = (1..=arg_types.len())
.map(|i| {
let mock_expr =
ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone())));
(format!("${i}"), mock_expr.clone())
(format!("${i}"), mock_expr)
})
.collect()
.collect();

for (i, arg_name) in arg_names.into_iter().enumerate() {
let mock_expr = ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i].clone())));
ret.insert(arg_name, mock_expr);
}

ret
}

pub async fn handle_create_sql_function(
Expand Down Expand Up @@ -173,7 +183,12 @@ pub async fn handle_create_sql_function(

binder
.udf_context_mut()
.update_context(create_mock_udf_context(arg_types.clone()));
.update_context(create_mock_udf_context(
arg_types.clone(),
arg_names.clone(),
));

binder.set_udf_binding_flag();

if let Ok(expr) = UdfContext::extract_udf_expression(ast) {
match binder.bind_expr(expr) {
Expand Down Expand Up @@ -202,6 +217,8 @@ pub async fn handle_create_sql_function(
)
.into());
}

binder.unset_udf_binding_flag();
}

// Create the actual function, will be stored in function catalog
Expand Down

0 comments on commit 66adcdd

Please sign in to comment.