Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql-udf): support basic anonymous sql udf #14139

Merged
merged 28 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6e3f5d4
feat(sql-udf): support basic sql udf function
xzhseh Dec 21, 2023
bf57014
support anonymous sql udf
xzhseh Dec 26, 2023
16780b4
remove redundant code
xzhseh Dec 26, 2023
fa8fc7a
add basic test for anonymous sql udf
xzhseh Dec 26, 2023
5b5c82d
add more test cases for parser & sql_udf
xzhseh Dec 26, 2023
456d6fd
Merge branch 'main' into xzhseh/sql-udf
xzhseh Dec 27, 2023
94ebbda
feat(sql-udf): support basic anonymous sql udf return expression (#14…
xzhseh Dec 28, 2023
22ca504
Merge branch 'main' into xzhseh/sql-udf
xzhseh Dec 28, 2023
7084a6c
adjust test cases for some syntax when creating sql udfs
xzhseh Dec 28, 2023
fa0c44b
remove panic code && refactor some parts && add more comments
xzhseh Dec 29, 2023
2dde82e
add more test cases
xzhseh Dec 29, 2023
831a7b2
fix format && remove unnecessary file
xzhseh Dec 29, 2023
d5f6b7f
ignore language specification for general udf
xzhseh Dec 29, 2023
7e821f2
remove unused type checker
xzhseh Dec 29, 2023
8074636
forbid recursive definition
xzhseh Dec 29, 2023
44f8d46
refactor some parts
xzhseh Dec 29, 2023
a356566
fix format
xzhseh Dec 29, 2023
1c9b8ab
support nested sql udf calling
xzhseh Dec 29, 2023
2618a0b
fix check
xzhseh Dec 29, 2023
922c1e7
fix check
xzhseh Jan 1, 2024
cfb0a62
Merge branch 'main' into xzhseh/sql-udf
xzhseh Jan 1, 2024
9ccaa07
Merge branch 'main' into xzhseh/sql-udf
xzhseh Jan 2, 2024
c6d677b
update copyright
xzhseh Jan 2, 2024
88b00f1
update copyright
xzhseh Jan 2, 2024
53e53ff
remove unnecessary check & refactor
xzhseh Jan 2, 2024
d6c3d0d
feat(udf): add `body` field for udf body definition (#14300)
xzhseh Jan 2, 2024
199f04f
Merge branch 'main' into xzhseh/sql-udf
xzhseh Jan 2, 2024
d601ac8
Merge branch 'main' into xzhseh/sql-udf
xzhseh Jan 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions e2e_test/udf/sql_udf.slt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Can prepared statements work correctly? And are there any tests?

Copy link
Member

@BugenZhao BugenZhao Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be surprising but I couldn't find any e2e tests for the prepared statement itself. 😕

#14141

Copy link
Contributor Author

@xzhseh xzhseh Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can prepared statements work correctly?

I think so. Under the current implementation, the check of udf_context will be immediately cleared after bind_function, which ensures it will not interleave with the parameter bindings of prepare-related statements (e.g., those triggered by third-party drivers).

Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# End-to-end tests for anonymous SQL UDFs (`create function ... language sql`).
# The type string after `query` carries one letter per result column:
# `I` for integer columns and `R` for floating point columns.

statement ok
SET RW_IMPLICIT_FLUSH TO true;

# Create an anonymous SQL UDF whose body is a dollar-quoted (`$$ ... $$`) string
statement ok
create function add(INT, INT) returns int as $$select $1 + $2$$ language sql;

# Create an anonymous SQL UDF whose body is a single-quoted string
statement ok
create function sub(INT, INT) returns int as 'select $1 - $2' language sql;

# A SQL UDF that calls other SQL UDFs.
# Currently only the constant calling convention is supported.
statement ok
create function add_sub_binding() returns int as 'select add(1, 1) + sub(2, 2)' language sql;

# Call the defined SQL UDFs
query I
select add(1, -1);
----
0

query I
select sub(1, 1);
----
0

query I
select add_sub_binding();
----
2

query III
select add(1, -1), sub(1, 1), add_sub_binding();
----
0 0 2

# Create a mock table
statement ok
create table t1 (c1 INT, c2 INT);

# Insert some data into the mock table
statement ok
insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

# Call the SQL UDFs with column references as arguments (four integer columns)
query IIII
select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc;
----
0 1 1 2
0 2 2 4
0 3 3 6
0 4 4 8
0 5 5 10

# Invalid function body syntax (dangling `+`) must be rejected at creation time
statement error
create function add_error(INT, INT) returns int as $$select $1 + $2 +$$ language sql;

# SQL UDF with interleaved argument types
statement ok
create function add_sub(INT, FLOAT, INT) returns FLOAT as $$select -$1 + $2 - $3$$ language sql;

# Note: an EXPLICIT type cast is needed to call a SQL UDF with interleaved argument types
query R
select add_sub(1::INT, 5.1415926::FLOAT, 1::INT);
----
3.1415926

query IIR
select add(1, -1), sub(1, 1), add_sub(1::INT, 5.1415926::FLOAT, 1::INT);
----
0 0 3.1415926

# Create another mock table
statement ok
create table t2 (c1 INT, c2 FLOAT, c3 INT);

statement ok
insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02);

# Six result columns: c2 and the add_sub result are floating point
query IRIIIR
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

# Drop the functions
statement ok
drop function add;

statement ok
drop function sub;

statement ok
drop function add_sub_binding;

statement ok
drop function add_sub;

# Drop the mock tables
statement ok
drop table t1;

statement ok
drop table t2;
5 changes: 5 additions & 0 deletions src/frontend/src/binder/expr/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ impl Binder {
}
};

// Special check for sql udf
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
if self.udf_context.contains_key(&column_name) {
return self.bind_expr(self.udf_context.get(&column_name).unwrap().clone());
}
xzhseh marked this conversation as resolved.
Show resolved Hide resolved

match self
.context
.get_column_binding_indices(&table_name, &column_name)
Expand Down
80 changes: 70 additions & 10 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::HashMap;
use std::iter::once;
use std::str::FromStr;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};

use bk_tree::{metrics, BKTree};
use itertools::Itertools;
Expand All @@ -30,13 +30,14 @@ use risingwave_expr::window_function::{
Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind,
};
use risingwave_sqlparser::ast::{
self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion,
WindowFrameUnits, WindowSpec,
self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr,
Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec,
};
use thiserror_ext::AsReport;

use crate::binder::bind_context::Clause;
use crate::binder::{Binder, BoundQuery, BoundSetExpr};
use crate::catalog::function_catalog::FunctionCatalog;
use crate::expr::{
AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy,
Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction,
Expand Down Expand Up @@ -117,6 +118,9 @@ impl Binder {
return self.bind_array_transform(f);
}

// Used later in sql udf expression evaluation
let args = f.args.clone();

let inputs = f
.args
.into_iter()
Expand Down Expand Up @@ -149,6 +153,53 @@ impl Binder {
return Ok(TableFunction::new(function_type, inputs)?.into());
}

/// TODO: add name related logic
fn create_udf_context(
binder: &mut Binder,
args: &[FunctionArg],
_catalog: &Arc<FunctionCatalog>,
) {
binder.udf_context = args
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.enumerate()
.map(|(i, current_arg)| {
if let FunctionArg::Unnamed(arg) = current_arg {
let FunctionArgExpr::Expr(e) = arg else {
panic!("invalid syntax");
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
};
// if catalog.arg_names.is_some() {
// panic!("invalid syntax");
// }
return ("$".to_string() + &(i + 1).to_string(), e.clone());
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
}
panic!("invalid syntax");
})
.collect()
}

fn extract_udf_expression(ast: Vec<Statement>) -> AstExpr {
// Extract the expression out
let Statement::Query(query) = ast
.into_iter()
.exactly_one()
.expect("sql udf should contain only one statement")
else {
unreachable!("sql udf should contain a query statement");
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
};
let SetExpr::Select(select) = query.body else {
panic!("Invalid syntax");
};
let projection = select.projection;
let SelectItem::UnnamedExpr(expr) = projection
.into_iter()
.exactly_one()
.expect("`projection` should contain only one `SelectItem`")
else {
unreachable!("`projection` should contain only one `SelectItem`");
};
expr
}

// user defined function
// TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422
if let Ok(schema) = self.first_valid_schema()
Expand All @@ -158,13 +209,22 @@ impl Binder {
)
{
use crate::catalog::function_catalog::FunctionKind::*;
match &func.kind {
Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()),
Table { .. } => {
self.ensure_table_function_allowed()?;
return Ok(TableFunction::new_user_defined(func.clone(), inputs).into());
if func.language == "sql" {
// This represents the current user defined function is `language sql`
let ast = risingwave_sqlparser::parser::Parser::parse_sql(func.identifier.as_str()).unwrap();
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
// The actual inline logic
create_udf_context(self, &args, &Arc::clone(func));
return self.bind_expr(extract_udf_expression(ast));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we should create a new binder to do the inline. Or at least we should clean the udf_context after inlining, otherwise the following binding will be affected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the latter solution, creating a new binder to do the inline may add extra cost and is hard to integrate with the on-going binding process in the main routine 🤔️?

} else {
debug_assert!(func.language == "python" || func.language == "java", "only `python` and `java` are currently supported for general udf");
match &func.kind {
Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()),
Table { .. } => {
self.ensure_table_function_allowed()?;
return Ok(TableFunction::new_user_defined(func.clone(), inputs).into());
}
Aggregate => todo!("support UDAF"),
}
Aggregate => todo!("support UDAF"),
}
}

Expand Down Expand Up @@ -1216,7 +1276,7 @@ impl Binder {
static FUNCTIONS_BKTREE: LazyLock<BKTree<&str>> = LazyLock::new(|| {
let mut tree = BKTree::new(metrics::Levenshtein);

// TODO: Also hint other functinos, e,g, Agg or UDF.
// TODO: Also hint other functions, e.g., Agg or UDF.
for k in HANDLES.keys() {
tree.add(*k);
}
Expand Down
6 changes: 6 additions & 0 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,12 @@ impl Binder {
}

fn bind_parameter(&mut self, index: u64) -> Result<ExprImpl> {
// Special check for sql udf
let name = "$".to_string() + &index.to_string();
if self.udf_context.contains_key(&name) {
return self.bind_expr(self.udf_context.get(&name).unwrap().clone());
}
xzhseh marked this conversation as resolved.
Show resolved Hide resolved

Ok(Parameter::new(index, self.param_types.clone()).into())
}

Expand Down
5 changes: 4 additions & 1 deletion src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::error::Result;
use risingwave_common::session_config::{ConfigMap, SearchPath};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{Expr as AstExpr, Statement};

mod bind_context;
mod bind_param;
Expand Down Expand Up @@ -115,6 +115,8 @@ pub struct Binder {
included_relations: HashSet<TableId>,

param_types: ParameterTypes,

udf_context: HashMap<String, AstExpr>,
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
}

/// `ParameterTypes` is used to record the types of the parameters during binding. It works
Expand Down Expand Up @@ -216,6 +218,7 @@ impl Binder {
shared_views: HashMap::new(),
included_relations: HashSet::new(),
param_types: ParameterTypes::new(param_types),
udf_context: HashMap::new(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/catalog/root_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl<'a> SchemaPath<'a> {
/// - catalog (root catalog)
/// - database catalog
/// - schema catalog
/// - function catalog
/// - function catalog (e.g., user defined function)
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
/// - table/sink/source/index/view catalog
/// - column catalog
pub struct Catalog {
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/catalog/schema_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ impl SchemaCatalog {
name: &str,
args: &[DataType],
) -> Option<&Arc<FunctionCatalog>> {
println!("Current args: {:#?}", args);
xzhseh marked this conversation as resolved.
Show resolved Hide resolved
self.function_by_name.get(name)?.get(args)
}

Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub async fn handle_create_function(
if temporary {
bail_not_implemented!("CREATE TEMPORARY FUNCTION");
}
// e.g., `language [ python / java / ...etc]`
let language = match params.language {
Some(lang) => {
let lang = lang.real_value().to_lowercase();
Expand Down
Loading
Loading