diff --git a/e2e_test/udf/sql_udf.slt b/e2e_test/udf/sql_udf.slt new file mode 100644 index 0000000000000..8b89010f70a93 --- /dev/null +++ b/e2e_test/udf/sql_udf.slt @@ -0,0 +1,192 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +# Create an anonymous function with double dollar as clause +statement ok +create function add(INT, INT) returns int language sql as $$select $1 + $2$$; + +# Create an anonymous function with single quote as clause +statement ok +create function sub(INT, INT) returns int language sql as 'select $1 - $2'; + +# Create an anonymous function that calls other pre-defined sql udfs +statement ok +create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)'; + +# Create an anonymous function that calls built-in functions +# Note that double dollar signs should be used otherwise the parsing will fail, as illutrates below +statement ok +create function call_regexp_replace() returns varchar language sql as $$select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')$$; + +statement error Expected end of statement, found: 💩 +create function call_regexp_replace() returns varchar language sql as 'select regexp_replace('💩💩💩💩💩foo🤔️bar亲爱的😭baz这不是爱情❤️‍🔥', 'baz(...)', '这是🥵', 'ic')'; + +# Create an anonymous function with return expression +statement ok +create function add_return(INT, INT) returns int language sql return $1 + $2; + +statement ok +create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); + +# Recursive definition is forbidden +statement error recursive definition is forbidden, please recheck your function syntax +create function recursive(INT, INT) returns int language sql as 'select recursive($1, $2) + recursive($1, $2)'; + +# Create a wrapper function for `add` & `sub` +statement ok +create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; + +# Call the defined sql udf +query I 
+select add(1, -1); +---- +0 + +query I +select sub(1, 1); +---- +0 + +query I +select add_sub_binding(); +---- +2 + +query III +select add(1, -1), sub(1, 1), add_sub_binding(); +---- +0 0 2 + +query I +select add_return(1, 1); +---- +2 + +query I +select add_return_binding(); +---- +4 + +query T +select call_regexp_replace(); +---- +💩💩💩💩💩foo🤔️bar亲爱的😭这是🥵爱情❤️‍🔥 + +query I +select add_sub_wrapper(1, 1); +---- +114514 + +# Create a mock table +statement ok +create table t1 (c1 INT, c2 INT); + +# Insert some data into the mock table +statement ok +insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); + +query III +select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc; +---- +0 1 1 2 +0 2 2 4 +0 3 3 6 +0 4 4 8 +0 5 5 10 + +query I +select c1, c2, add_return(c1, c2) from t1 order by c1 asc; +---- +1 1 2 +2 2 4 +3 3 6 +4 4 8 +5 5 10 + +# Invalid function body syntax +statement error Expected an expression:, found: EOF at the end +create function add_error(INT, INT) returns int language sql as $$select $1 + $2 +$$; + +# Multiple type interleaving sql udf +statement ok +create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$; + +statement ok +create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3; + +# Note: need EXPLICIT type cast in order to call the multiple types interleaving sql udf +query I +select add_sub(1::INT, 5.1415926::FLOAT, 1::INT); +---- +3.1415926 + +# Without EXPLICIT type cast +statement error unsupported function: "add_sub" +select add_sub(1, 3.14, 2); + +# Same as above, need EXPLICIT type cast to make the binding works +query I +select add_sub_return(1::INT, 5.1415926::FLOAT, 1::INT); +---- +3.1415926 + +query III +select add(1, -1), sub(1, 1), add_sub(1::INT, 5.1415926::FLOAT, 1::INT); +---- +0 0 3.1415926 + +# Create another mock table +statement ok +create table t2 (c1 INT, c2 FLOAT, c3 INT); + +statement ok +insert into t2 values (1, 3.14, 2), (2, 
4.44, 5), (20, 10.30, 02); + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1::INT, c2::FLOAT, c3::INT) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +# Drop the functions +statement ok +drop function add; + +statement ok +drop function sub; + +statement ok +drop function add_sub_binding; + +statement ok +drop function add_sub; + +statement ok +drop function add_sub_return; + +statement ok +drop function add_return; + +statement ok +drop function add_return_binding; + +statement ok +drop function call_regexp_replace; + +statement ok +drop function add_sub_wrapper; + +# Drop the mock table +statement ok +drop table t1; + +statement ok +drop table t2; diff --git a/proto/catalog.proto b/proto/catalog.proto index 01a7893383232..ec7c68a3802ba 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -218,6 +218,7 @@ message Function { string language = 7; string link = 8; string identifier = 10; + optional string body = 14; oneof kind { ScalarFunction scalar = 11; diff --git a/src/frontend/src/binder/expr/column.rs b/src/frontend/src/binder/expr/column.rs index 16053208ec8d3..2f2a8d9335256 100644 --- a/src/frontend/src/binder/expr/column.rs +++ b/src/frontend/src/binder/expr/column.rs @@ -37,6 +37,18 @@ impl Binder { } }; + // Special check for sql udf + // Note: The check in `bind_column` is to inline the identifiers, + // which, in the context of sql udf, will NOT be perceived as normal + // columns, but the actual named input parameters. + // Thus, we need to figure out if the current "column name" corresponds + // to the name of the defined sql udf parameters stored in `udf_context`. 
+ // If so, we will treat this bind as a special bind, the actual expression + // stored in `udf_context` will then be bound instead of binding the non-existing column. + if let Some(expr) = self.udf_context.get(&column_name) { + return self.bind_expr(expr.clone()); + } + match self .context .get_column_binding_indices(&table_name, &column_name) diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c8558b0756b5d..7244a8527f857 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use std::iter::once; use std::str::FromStr; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use bk_tree::{metrics, BKTree}; use itertools::Itertools; @@ -30,13 +30,15 @@ use risingwave_expr::window_function::{ Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind, }; use risingwave_sqlparser::ast::{ - self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion, - WindowFrameUnits, WindowSpec, + self, Expr as AstExpr, Function, FunctionArg, FunctionArgExpr, Ident, SelectItem, SetExpr, + Statement, WindowFrameBound, WindowFrameExclusion, WindowFrameUnits, WindowSpec, }; +use risingwave_sqlparser::parser::ParserError; use thiserror_ext::AsReport; use crate::binder::bind_context::Clause; use crate::binder::{Binder, BoundQuery, BoundSetExpr}; +use crate::catalog::function_catalog::FunctionCatalog; use crate::expr::{ AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy, Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction, @@ -117,6 +119,9 @@ impl Binder { return self.bind_array_transform(f); } + // Used later in sql udf expression evaluation + let args = f.args.clone(); + let inputs = f .args .into_iter() @@ -149,6 +154,73 @@ impl Binder { return Ok(TableFunction::new(function_type, inputs)?.into()); } + /// TODO: add
name related logic + /// NOTE: need to think of a way to prevent naming conflict + /// e.g., when existing column names conflict with parameter names in sql udf + fn create_udf_context( + args: &[FunctionArg], + _catalog: &Arc, + ) -> Result> { + let mut ret: HashMap = HashMap::new(); + for (i, current_arg) in args.iter().enumerate() { + if let FunctionArg::Unnamed(arg) = current_arg { + let FunctionArgExpr::Expr(e) = arg else { + return Err( + ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into() + ); + }; + // if catalog.arg_names.is_some() { + // todo!() + // } + ret.insert(format!("${}", i + 1), e.clone()); + continue; + } + return Err(ErrorCode::InvalidInputSyntax("invalid syntax".to_string()).into()); + } + Ok(ret) + } + + fn extract_udf_expression(ast: Vec) -> Result { + if ast.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "the query for sql udf should contain only one statement".to_string(), + ) + .into()); + } + + // Extract the expression out + let Statement::Query(query) = ast[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "invalid function definition, please recheck the syntax".to_string(), + ) + .into()); + }; + + let SetExpr::Select(select) = query.body else { + return Err(ErrorCode::InvalidInputSyntax( + "missing `select` body for sql udf expression, please recheck the syntax" + .to_string(), + ) + .into()); + }; + + if select.projection.len() != 1 { + return Err(ErrorCode::InvalidInputSyntax( + "`projection` should contain only one `SelectItem`".to_string(), + ) + .into()); + } + + let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { + return Err(ErrorCode::InvalidInputSyntax( + "expect `UnnamedExpr` for `projection`".to_string(), + ) + .into()); + }; + + Ok(expr) + } + // user defined function // TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422 if let Ok(schema) = self.first_valid_schema() @@ -158,13 +230,89 @@ impl Binder { ) { use 
crate::catalog::function_catalog::FunctionKind::*; - match &func.kind { - Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()), - Table { .. } => { - self.ensure_table_function_allowed()?; - return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); + if func.language == "sql" { + if func.body.is_none() { + return Err(ErrorCode::InvalidInputSyntax( + "`body` must exist for sql udf".to_string(), + ) + .into()); + } + // This represents the current user defined function is `language sql` + let parse_result = risingwave_sqlparser::parser::Parser::parse_sql( + func.body.as_ref().unwrap().as_str(), + ); + if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = + parse_result + { + // Here we just return the original parse error message + return Err(ErrorCode::InvalidInputSyntax(err).into()); + } + debug_assert!(parse_result.is_ok()); + + // We can safely unwrap here + let ast = parse_result.unwrap(); + + let mut clean_flag = true; + + // We need to check if the `udf_context` is empty first, consider the following example: + // - create function add(INT, INT) returns int language sql as 'select $1 + $2'; + // - create function add_wrapper(INT, INT) returns int language sql as 'select add($1, $2)'; + // - select add_wrapper(1, 1); + // When binding `add($1, $2)` in `add_wrapper`, the input args are [$1, $2] instead of + // the original [1, 1], thus we need to check `udf_context` to see if the input + // args already exist in the context. If so, we do NOT need to create the context again. + // Otherwise the current `udf_context` will be corrupted. 
+ if self.udf_context.is_empty() { + // The actual inline logic for sql udf + if let Ok(context) = create_udf_context(&args, &Arc::clone(func)) { + self.udf_context = context; + } else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() + ) + .into()); + } + } else { + // If the `udf_context` is not empty, this means the current binding + // function is not the root binding sql udf, thus we should NOT + // clean the context after binding. + clean_flag = false; + } + + if let Ok(expr) = extract_udf_expression(ast) { + let bind_result = self.bind_expr(expr); + // Clean the `udf_context` after inlining, + // which makes sure the subsequent binding will not be affected + if clean_flag { + self.udf_context.clear(); + } + return bind_result; + } else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to parse the input query and extract the udf expression, + please recheck the syntax" + .to_string(), + ) + .into()); + } + } else { + // Note that `language` may be empty for external udf + if !func.language.is_empty() { + debug_assert!( + func.language == "python" || func.language == "java", + "only `python` and `java` are currently supported for general udf" + ); + } + match &func.kind { + Scalar { .. } => { + return Ok(UserDefinedFunction::new(func.clone(), inputs).into()) + } + Table { .. } => { + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); + } + Aggregate => todo!("support UDAF"), } - Aggregate => todo!("support UDAF"), } } @@ -1213,7 +1361,7 @@ impl Binder { static FUNCTIONS_BKTREE: LazyLock> = LazyLock::new(|| { let mut tree = BKTree::new(metrics::Levenshtein); - // TODO: Also hint other functinos, e,g, Agg or UDF. + // TODO: Also hint other functions, e.g., Agg or UDF.
for k in HANDLES.keys() { tree.add(*k); } diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 179469e545c2b..cacd2d80dcfe4 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -378,6 +378,14 @@ impl Binder { } fn bind_parameter(&mut self, index: u64) -> Result { + // Special check for sql udf + // Note: This is specific to anonymous sql udf, since the + // parameters will be parsed and treated as `Parameter`. + // For detailed explanation, consider checking `bind_column`. + if let Some(expr) = self.udf_context.get(&format!("${index}")) { + return self.bind_expr(expr.clone()); + } + Ok(Parameter::new(index, self.param_types.clone()).into()) } diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index be0f441c0743a..6ba891aa6b513 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -21,7 +21,7 @@ use risingwave_common::error::Result; use risingwave_common::session_config::{ConfigMap, SearchPath}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{Expr as AstExpr, Statement}; mod bind_context; mod bind_param; @@ -115,6 +115,10 @@ pub struct Binder { included_relations: HashSet, param_types: ParameterTypes, + + /// The mapping from sql udf parameters to ast expressions + /// Note: The expressions are constructed during runtime, correspond to the actual users' input + udf_context: HashMap, } /// `ParameterTypes` is used to record the types of the parameters during binding. 
It works @@ -216,6 +220,7 @@ impl Binder { shared_views: HashMap::new(), included_relations: HashSet::new(), param_types: ParameterTypes::new(param_types), + udf_context: HashMap::new(), } } diff --git a/src/frontend/src/catalog/function_catalog.rs b/src/frontend/src/catalog/function_catalog.rs index 7197821b33ce6..d0f037bcb47b5 100644 --- a/src/frontend/src/catalog/function_catalog.rs +++ b/src/frontend/src/catalog/function_catalog.rs @@ -30,6 +30,7 @@ pub struct FunctionCatalog { pub return_type: DataType, pub language: String, pub identifier: String, + pub body: Option, pub link: String, } @@ -63,6 +64,7 @@ impl From<&PbFunction> for FunctionCatalog { return_type: prost.return_type.as_ref().expect("no return type").into(), language: prost.language.clone(), identifier: prost.identifier.clone(), + body: prost.body.clone(), link: prost.link.clone(), } } diff --git a/src/frontend/src/catalog/root_catalog.rs b/src/frontend/src/catalog/root_catalog.rs index bd411577476ba..9d2045f5dc61f 100644 --- a/src/frontend/src/catalog/root_catalog.rs +++ b/src/frontend/src/catalog/root_catalog.rs @@ -88,7 +88,7 @@ impl<'a> SchemaPath<'a> { /// - catalog (root catalog) /// - database catalog /// - schema catalog -/// - function catalog +/// - function catalog (i.e., user defined function) /// - table/sink/source/index/view catalog /// - column catalog pub struct Catalog { diff --git a/src/frontend/src/expr/user_defined_function.rs b/src/frontend/src/expr/user_defined_function.rs index abd39fdbbc0c4..165774d1acb4b 100644 --- a/src/frontend/src/expr/user_defined_function.rs +++ b/src/frontend/src/expr/user_defined_function.rs @@ -55,6 +55,8 @@ impl UserDefinedFunction { return_type, language: udf.get_language().clone(), identifier: udf.get_identifier().clone(), + // TODO: Ensure if we need `body` here + body: None, link: udf.get_link().clone(), }; diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index c623b67eecb60..4557b71223b98 
100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -44,6 +44,7 @@ pub async fn handle_create_function( if temporary { bail_not_implemented!("CREATE TEMPORARY FUNCTION"); } + // e.g., `language [ python / java / ...etc]` let language = match params.language { Some(lang) => { let lang = lang.real_value().to_lowercase(); @@ -161,6 +162,7 @@ pub async fn handle_create_function( return_type: Some(return_type.into()), language, identifier, + body: None, link, owner: session.user_id(), }; diff --git a/src/frontend/src/handler/create_sql_function.rs b/src/frontend/src/handler/create_sql_function.rs new file mode 100644 index 0000000000000..834e0bec3135d --- /dev/null +++ b/src/frontend/src/handler/create_sql_function.rs @@ -0,0 +1,181 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use itertools::Itertools; +use pgwire::pg_response::StatementType; +use risingwave_common::catalog::FunctionId; +use risingwave_common::types::DataType; +use risingwave_pb::catalog::function::{Kind, ScalarFunction, TableFunction}; +use risingwave_pb::catalog::Function; +use risingwave_sqlparser::ast::{ + CreateFunctionBody, FunctionDefinition, ObjectName, OperateFunctionArg, +}; +use risingwave_sqlparser::parser::{Parser, ParserError}; + +use super::*; +use crate::catalog::CatalogError; +use crate::{bind_data_type, Binder}; + +pub async fn handle_create_sql_function( + handler_args: HandlerArgs, + or_replace: bool, + temporary: bool, + name: ObjectName, + args: Option>, + returns: Option, + params: CreateFunctionBody, +) -> Result { + if or_replace { + bail_not_implemented!("CREATE OR REPLACE FUNCTION"); + } + + if temporary { + bail_not_implemented!("CREATE TEMPORARY FUNCTION"); + } + + let language = "sql".to_string(); + // Just a basic sanity check for language + if !matches!(params.language, Some(lang) if lang.real_value().to_lowercase() == "sql") { + return Err(ErrorCode::InvalidParameterValue( + "`language` for sql udf must be `sql`".to_string(), + ) + .into()); + } + + // SQL udf function supports both single quote (i.e., as 'select $1 + $2') + // and double dollar (i.e., as $$select $1 + $2$$) for as clause + let body = match &params.as_ { + Some(FunctionDefinition::SingleQuotedDef(s)) => s.clone(), + Some(FunctionDefinition::DoubleDollarDef(s)) => s.clone(), + None => { + if params.return_.is_none() { + return Err(ErrorCode::InvalidParameterValue( + "AS or RETURN must be specified".to_string(), + ) + .into()); + } + // Otherwise this is a return expression + // Note: this is a current work around, and we are assuming return sql udf + // will NOT involve complex syntax, so just reuse the logic for select definition + format!("select {}", &params.return_.unwrap().to_string()) + } + }; + + // We do NOT allow recursive calling inside sql udf + // Since there does
not exist the base case for this definition + if body.contains(format!("{}(", name.real_value()).as_str()) { + return Err(ErrorCode::InvalidInputSyntax( + "recursive definition is forbidden, please recheck your function syntax".to_string(), + ) + .into()); + } + + // Sanity check for link, this must be none with sql udf function + if let Some(CreateFunctionUsing::Link(_)) = params.using { + return Err(ErrorCode::InvalidParameterValue( + "USING must NOT be specified with sql udf function".to_string(), + ) + .into()); + }; + + // Get return type for the current sql udf function + let return_type; + let kind = match returns { + Some(CreateFunctionReturns::Value(data_type)) => { + return_type = bind_data_type(&data_type)?; + Kind::Scalar(ScalarFunction {}) + } + Some(CreateFunctionReturns::Table(columns)) => { + if columns.len() == 1 { + // return type is the original type for single column + return_type = bind_data_type(&columns[0].data_type)?; + } else { + // return type is a struct for multiple columns + let datatypes = columns + .iter() + .map(|c| bind_data_type(&c.data_type)) + .collect::>>()?; + let names = columns + .iter() + .map(|c| c.name.real_value()) + .collect::>(); + return_type = DataType::new_struct(datatypes, names); + } + Kind::Table(TableFunction {}) + } + None => { + return Err(ErrorCode::InvalidParameterValue( + "return type must be specified".to_string(), + ) + .into()) + } + }; + + let mut arg_types = vec![]; + for arg in args.unwrap_or_default() { + arg_types.push(bind_data_type(&arg.data_type)?); + } + + // resolve database and schema id + let session = &handler_args.session; + let db_name = session.database(); + let (schema_name, function_name) = Binder::resolve_schema_qualified_name(db_name, name)?; + let (database_id, schema_id) = session.get_database_and_schema_id_for_create(schema_name)?; + + // check if function exists + if (session.env().catalog_reader().read_guard()) + .get_schema_by_id(&database_id, &schema_id)? 
+ .get_function_by_name_args(&function_name, &arg_types) + .is_some() + { + let name = format!( + "{function_name}({})", + arg_types.iter().map(|t| t.to_string()).join(",") + ); + return Err(CatalogError::Duplicated("function", name).into()); + } + + // Parse function body here + // Note that the parsing here is just basic syntax / semantic check, the result will NOT be stored + // e.g., The provided function body contains invalid syntax, return type mismatch, ..., etc. + let parse_result = Parser::parse_sql(body.as_str()); + if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = parse_result + { + // Here we just return the original parse error message + return Err(ErrorCode::InvalidInputSyntax(err).into()); + } else { + debug_assert!(parse_result.is_ok()); + } + + // Create the actual function, will be stored in function catalog + let function = Function { + id: FunctionId::placeholder().0, + schema_id, + database_id, + name: function_name, + kind: Some(kind), + arg_types: arg_types.into_iter().map(|t| t.into()).collect(), + return_type: Some(return_type.into()), + language, + identifier: "".to_string(), + body: Some(body), + link: "".to_string(), + owner: session.user_id(), + }; + + let catalog_writer = session.catalog_writer()?; + catalog_writer.create_function(function).await?; + + Ok(PgResponse::empty_result(StatementType::CREATE_FUNCTION)) +} diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index d2ec467df424f..495d9ebd3729a 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -53,6 +53,7 @@ pub mod create_mv; pub mod create_schema; pub mod create_sink; pub mod create_source; +pub mod create_sql_function; pub mod create_table; pub mod create_table_as; pub mod create_user; @@ -205,16 +206,39 @@ pub async fn handle( returns, params, } => { - create_function::handle_create_function( - handler_args, - or_replace, - temporary, - name, - args, - returns, - params, - ) - 
.await + // For general udf, `language` clause could be ignored + // refer: https://github.com/risingwavelabs/risingwave/pull/10608 + if params.language.is_none() + || !params + .language + .as_ref() + .unwrap() + .real_value() + .eq_ignore_ascii_case("sql") + { + // User defined function with external source (e.g., language [ python / java ]) + create_function::handle_create_function( + handler_args, + or_replace, + temporary, + name, + args, + returns, + params, + ) + .await + } else { + create_sql_function::handle_create_sql_function( + handler_args, + or_replace, + temporary, + name, + args, + returns, + params, + ) + .await + } } Statement::CreateTable { name, diff --git a/src/meta/model_v2/migration/src/m20230908_072257_init.rs b/src/meta/model_v2/migration/src/m20230908_072257_init.rs index 5b3d55ab83bfb..bc9ce2b08c32b 100644 --- a/src/meta/model_v2/migration/src/m20230908_072257_init.rs +++ b/src/meta/model_v2/migration/src/m20230908_072257_init.rs @@ -708,6 +708,7 @@ impl MigrationTrait for Migration { .col(ColumnDef::new(Function::Language).string().not_null()) .col(ColumnDef::new(Function::Link).string().not_null()) .col(ColumnDef::new(Function::Identifier).string().not_null()) + .col(ColumnDef::new(Function::Body).string()) .col(ColumnDef::new(Function::Kind).string().not_null()) .foreign_key( &mut ForeignKey::create() @@ -1099,6 +1100,7 @@ enum Function { Language, Link, Identifier, + Body, Kind, } diff --git a/src/meta/model_v2/src/function.rs b/src/meta/model_v2/src/function.rs index 71391a3cc27b0..5976685893afb 100644 --- a/src/meta/model_v2/src/function.rs +++ b/src/meta/model_v2/src/function.rs @@ -41,6 +41,7 @@ pub struct Model { pub language: String, pub link: String, pub identifier: String, + pub body: Option, pub kind: FunctionKind, } @@ -94,6 +95,7 @@ impl From for ActiveModel { language: Set(function.language), link: Set(function.link), identifier: Set(function.identifier), + body: Set(function.body), kind: 
Set(function.kind.unwrap().into()), } } diff --git a/src/meta/src/controller/mod.rs b/src/meta/src/controller/mod.rs index d6d6891a12e8d..037f9e3417163 100644 --- a/src/meta/src/controller/mod.rs +++ b/src/meta/src/controller/mod.rs @@ -277,6 +277,7 @@ impl From> for PbFunction { language: value.0.language, link: value.0.link, identifier: value.0.identifier, + body: value.0.body, kind: Some(value.0.kind.into()), } } diff --git a/src/sqlparser/README.md b/src/sqlparser/README.md index 20a5ac0c8e6ab..2a829102710ba 100644 --- a/src/sqlparser/README.md +++ b/src/sqlparser/README.md @@ -4,5 +4,5 @@ This parser is a fork of . ## Add a new test case -1. Copy an item in the yaml file and edit the `input` to the sql you want to test -2. Run `./risedev do-apply-parser-test` to regenerate the `formatted_sql` whicih is the expected output \ No newline at end of file +1. Copy an item in the yaml file and edit the `input` to the sql you want to test. +2. Run `./risedev do-apply-parser-test` to regenerate the `formatted_sql` which is the expected output. 
\ No newline at end of file diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 2d97834ad23b5..99e2c185fdcff 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -768,6 +768,53 @@ fn parse_create_function() { } ); + let sql = "CREATE FUNCTION sub(INT, INT) RETURNS INT LANGUAGE SQL AS $$select $1 - $2;$$"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new_unchecked("sub")]), + args: Some(vec![ + OperateFunctionArg::unnamed(DataType::Int), + OperateFunctionArg::unnamed(DataType::Int), + ]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some("SQL".into()), + as_: Some(FunctionDefinition::DoubleDollarDef( + "select $1 - $2;".into() + )), + ..Default::default() + } + }, + ); + + // Anonymous return sql udf parsing test + let sql = "CREATE FUNCTION return_test(INT, INT) RETURNS INT LANGUAGE SQL RETURN $1 + $2"; + assert_eq!( + verified_stmt(sql), + Statement::CreateFunction { + or_replace: false, + temporary: false, + name: ObjectName(vec![Ident::new_unchecked("return_test")]), + args: Some(vec![ + OperateFunctionArg::unnamed(DataType::Int), + OperateFunctionArg::unnamed(DataType::Int), + ]), + returns: Some(CreateFunctionReturns::Value(DataType::Int)), + params: CreateFunctionBody { + language: Some("SQL".into()), + return_: Some(Expr::BinaryOp { + left: Box::new(Expr::Parameter { index: 1 }), + op: BinaryOperator::Plus, + right: Box::new(Expr::Parameter { index: 2 }), + }), + ..Default::default() + } + }, + ); + let sql = "CREATE OR REPLACE FUNCTION add(a INT, IN b INT = 1) RETURNS INT LANGUAGE SQL IMMUTABLE RETURN a + b"; assert_eq!( verified_stmt(sql),