Skip to content

Commit

Permalink
feat(sql-udf): support basic anonymous sql udf (#14139)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
xzhseh and github-actions[bot] authored Jan 5, 2024
1 parent c749f7b commit 75735e2
Show file tree
Hide file tree
Showing 17 changed files with 653 additions and 24 deletions.
192 changes: 192 additions & 0 deletions e2e_test/udf/sql_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

# Create an anonymous function with double dollar as clause
statement ok
create function add(INT, INT) returns int language sql as $$select $1 + $2$$;

# Create an anonymous function with single quote as clause
statement ok
create function sub(INT, INT) returns int language sql as 'select $1 - $2';

# Create an anonymous function that calls other pre-defined sql udfs
statement ok
create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)';

# Create an anonymous function that calls built-in functions
# Note that double dollar signs should be used otherwise the parsing will fail, as illutrates below
statement ok
create function call_regexp_replace() returns varchar language sql as $$select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')$$;

statement error Expected end of statement, found: 💩
create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')';

# Create an anonymous function with return expression
statement ok
create function add_return(INT, INT) returns int language sql return $1 + $2;

statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

# Recursive definition is forbidden
statement error recursive definition is forbidden, please recheck your function syntax
create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)';

# Create a wrapper function for `add` & `sub`
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

# Call the defined sql udf
query I
select add(1, -1);
----
0

query I
select sub(1, 1);
----
0

query I
select add_sub_binding();
----
2

query III
select add(1, -1), sub(1, 1), add_sub_binding();
----
0 0 2

query I
select add_return(1, 1);
----
2

query I
select add_return_binding();
----
4

query T
select call_regexp_replace();
----
💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥

query I
select add_sub_wrapper(1, 1);
----
114514

# Create a mock table
statement ok
create table t1 (c1 INT, c2 INT);

# Insert some data into the mock table
statement ok
insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

query III
select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc;
----
0 1 1 2
0 2 2 4
0 3 3 6
0 4 4 8
0 5 5 10

query I
select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
----
1 1 2
2 2 4
3 3 6
4 4 8
5 5 10

# Invalid function body syntax
statement error Expected an expression:, found: EOF at the end
create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$;

# Multiple type interleaving sql udf
statement ok
create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$;

statement ok
create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3;

# Note: need EXPLICIT type cast in order to call the multiple types interleaving sql udf
query I
select add_sub(1::INT, 5.1415926::FLOAT, 1::INT);
----
3.1415926

# Without EXPLICIT type cast
statement error unsupported function: "add_sub"
select add_sub(1, 3.14, 2);

# Same as above, need EXPLICIT type cast to make the binding works
query I
select add_sub_return(1::INT, 5.1415926::FLOAT, 1::INT);
----
3.1415926

query III
select add(1, -1), sub(1, 1), add_sub(1::INT, 5.1415926::FLOAT, 1::INT);
----
0 0 3.1415926

# Create another mock table
statement ok
create table t2 (c1 INT, c2 FLOAT, c3 INT);

statement ok
insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02);

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

# Drop the functions
statement ok
drop function add;

statement ok
drop function sub;

statement ok
drop function add_sub_binding;

statement ok
drop function add_sub;

statement ok
drop function add_sub_return;

statement ok
drop function add_return;

statement ok
drop function add_return_binding;

statement ok
drop function call_regexp_replace;

statement ok
drop function add_sub_wrapper;

# Drop the mock table
statement ok
drop table t1;

statement ok
drop table t2;
1 change: 1 addition & 0 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ message Function {
string language = 7;
string link = 8;
string identifier = 10;
optional string body = 14;

oneof kind {
ScalarFunction scalar = 11;
Expand Down
12 changes: 12 additions & 0 deletions src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ impl Binder {
}
};

// Special check for sql udf
// Note: The check in `bind_column` is to inline the identifiers,
// which, in the context of sql udf, will NOT be perceived as normal
// columns, but the actual named input parameters.
// Thus, we need to figure out if the current "column name" corresponds
// to the name of the defined sql udf parameters stored in `udf_context`.
// If so, we will treat this bind as an special bind, the actual expression
// stored in `udf_context` will then be bound instead of binding the non-existing column.
if let Some(expr) = self.udf_context.get(&column_name) {
return self.bind_expr(expr.clone());
}

match self
.context
.get_column_binding_indices(&table_name, &column_name)
Expand Down
Loading

0 comments on commit 75735e2

Please sign in to comment.