From 66adcdd15d199f1a4cd69afcd0e0d3de8f785b88 Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Tue, 30 Jan 2024 15:17:01 -0500 Subject: [PATCH] feat(sql-udf): support named sql udf (#14806) --- e2e_test/udf/sql_udf.slt | 93 ++++++++++++++++++- src/frontend/src/binder/expr/column.rs | 2 + src/frontend/src/binder/mod.rs | 31 ++++--- .../src/handler/create_sql_function.rs | 27 +++++- 4 files changed, 132 insertions(+), 21 deletions(-) diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt index 79cb40f2e54cc..758ec43ca53fc 100644 --- a/e2e_test/udf/sql_udf.slt +++ b/e2e_test/udf/sql_udf.slt @@ -13,6 +13,35 @@ create function sub(INT, INT) returns int language sql as 'select $1 - $2'; statement ok create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)'; +# Create a named sql udf +statement ok +create function add_named(a INT, b INT) returns int language sql as 'select a + b'; + +# Create another named sql udf +statement ok +create function sub_named(a INT, b INT) returns int language sql as 'select a - b'; + +# Mixed parameter with named / anonymous parameters +statement ok +create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3'; + +# Mixed parameter with calling inner sql udfs +# statement ok +# create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)'; + +# Named sql udf with corner case +statement ok +create function corner_case(INT, a INT, INT) returns varchar language sql as $$select '$1 + a + $3'$$; + +# Named sql udf with invalid parameter in body definition +# Will be rejected at creation time +statement error failed to find named parameter aa +create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a'; + +# Call anonymous sql udf inside named sql udf +statement ok +create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add(a, b)'; + # Create an anonymous function that calls built-in functions # Note that double dollar signs should be used otherwise the parsing will fail, as illutrates below statement ok @@ -94,7 +123,7 @@ select type_match(114514); ---- $1 + 114514 + $1 -# Call the defined sql udf +# Call the defined anonymous sql udfs query I select add(1, -1); ---- @@ -105,6 +134,32 @@ select sub(1, 1); ---- 0 +# Call the defined named sql udfs +query I +select add_named(1, -1); +---- +0 + +query I +select sub_named(1, 1); +---- +0 + +query I +select add_sub_mix(1, 2, 3); +---- +2 + +query T +select corner_case(1, 2, 3); +---- +$1 + a + $3 + +query I +select add_named_wrapper(1, -1); +---- +0 + query I select add_sub_binding(); ---- @@ -158,14 +213,30 @@ select print_add_one(1), print_add_one(114513), print_add_two(2); ---- 2 114514 4 -# Create a mock table +# Create a mock table for anonymous sql udf statement ok create table t1 (c1 INT, c2 INT); +# Create a mock table for named sql udf +statement ok +create table t3 (a INT, b INT); + # Insert some data into the mock table statement ok insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); +statement ok +insert into t3 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); + +query I +select add_named(a, b) from t3 order by a asc; +---- +2 +4 +6 +8 +10 + query III select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc; ---- @@ -303,6 +374,21 @@ drop function print_add_two; statement ok drop function regexp_replace_wrapper; +statement ok +drop function corner_case; + +statement ok +drop function add_named; + +statement ok +drop function sub_named; + +statement ok +drop function add_sub_mix; + +statement ok +drop function add_named_wrapper; + statement ok drop function type_match; @@ -312,3 +398,6 @@ drop table t1; statement ok drop table t2; + +statement ok +drop table t3; \ No newline at end of file diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index dbee0b0708cb6..9fb17c6e43520 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -52,6 +52,8 @@ impl Binder { // The reason that we directly return error here, // is because during a valid sql udf binding, // there will not exist any column identifiers + // And invalid cases should already be caught + // during semantic check phase return Err(ErrorCode::BindError(format!( "failed to find named parameter {column_name}" )) diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 978074e7455e8..e726d7dc01d4c 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -219,26 +219,29 @@ impl UdfContext { Ok(expr) } - /// TODO: add name related logic - /// NOTE: need to think of a way to prevent naming conflict - /// e.g., when existing column names conflict with parameter names in sql udf pub fn create_udf_context( args: &[FunctionArg], - _catalog: &Arc, + catalog: &Arc, ) -> Result> { let mut ret: HashMap = HashMap::new(); for (i, current_arg) in args.iter().enumerate() { - if let FunctionArg::Unnamed(arg) = current_arg { - let FunctionArgExpr::Expr(e) = arg else { - return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); - }; - // if catalog.arg_names.is_some() { - // todo!() - // } - ret.insert(format!("${}", i + 1), e.clone()); - continue; + match current_arg { + FunctionArg::Unnamed(arg) => { + let FunctionArgExpr::Expr(e) = arg else { + return Err( + ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into() + ); + }; + if catalog.arg_names[i].is_empty() { + ret.insert(format!("${}", i + 1), e.clone()); + } else { + // The index mapping here is accurate + // So that we could directly use the index + ret.insert(catalog.arg_names[i].clone(), e.clone()); + } + } + _ => return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()), } - return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); } Ok(ret) } diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs index ed77cd5e61189..4eaa78f82533e 100644 --- a/src/frontend/src/handler/create_sql_function.rs +++ b/src/frontend/src/handler/create_sql_function.rs @@ -33,14 +33,24 @@ use crate::expr::{Expr, ExprImpl, Literal}; use crate::{bind_data_type, Binder}; /// Create a mock `udf_context`, which is used for semantic check -fn create_mock_udf_context(arg_types: Vec) -> HashMap { - (1..=arg_types.len()) +fn create_mock_udf_context( + arg_types: Vec, + arg_names: Vec, +) -> HashMap { + let mut ret: HashMap = (1..=arg_types.len()) .map(|i| { let mock_expr = ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i - 1].clone()))); - (format!("${i}"), mock_expr.clone()) + (format!("${i}"), mock_expr) }) - .collect() + .collect(); + + for (i, arg_name) in arg_names.into_iter().enumerate() { + let mock_expr = ExprImpl::Literal(Box::new(Literal::new(None, arg_types[i].clone()))); + ret.insert(arg_name, mock_expr); + } + + ret } pub async fn handle_create_sql_function( @@ -173,7 +183,12 @@ pub async fn handle_create_sql_function( binder .udf_context_mut() - .update_context(create_mock_udf_context(arg_types.clone())); + .update_context(create_mock_udf_context( + arg_types.clone(), + arg_names.clone(), + )); + + binder.set_udf_binding_flag(); if let Ok(expr) = UdfContext::extract_udf_expression(ast) { match binder.bind_expr(expr) { @@ -202,6 +217,8 @@ pub async fn handle_create_sql_function( ) .into()); } + + binder.unset_udf_binding_flag(); } // Create the actual function, will be stored in function catalog