diff --git a/src/expr/core/src/aggregate/def.rs b/src/expr/core/src/aggregate/def.rs index b050f8039e1c6..59e02af81fa9e 100644 --- a/src/expr/core/src/aggregate/def.rs +++ b/src/expr/core/src/aggregate/def.rs @@ -239,6 +239,7 @@ impl Display for AggKind { } } +/// `FromStr` for builtin aggregate functions. impl FromStr for AggKind { type Err = (); diff --git a/src/expr/core/src/window_function/kind.rs b/src/expr/core/src/window_function/kind.rs index 04b320f8ce9f2..3042facb5cffc 100644 --- a/src/expr/core/src/window_function/kind.rs +++ b/src/expr/core/src/window_function/kind.rs @@ -19,7 +19,7 @@ use crate::aggregate::AggKind; use crate::Result; /// Kind of window functions. -#[derive(Debug, Display, FromStr, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Display, FromStr /* for builtin */, Clone, PartialEq, Eq, Hash)] #[display(style = "snake_case")] pub enum WindowFuncKind { // General-purpose window functions. diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index 6a8a00eaa970b..3c6f0d6133616 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -801,7 +801,7 @@ Failed to bind expression: abs(a) FILTER (WHERE a > 0) Caused by: - Invalid input syntax: DISTINCT, ORDER BY or FILTER is only allowed in aggregation functions, but `abs` is not an aggregation function + Invalid input syntax: `FILTER` is not allowed in scalar/table function call - name: prune column before filter sql: | create table t(v1 int, v2 int); diff --git a/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml b/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml index fe98be90cc664..6ab86ec9e20fe 100644 --- a/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml +++ b/src/frontend/planner_test/tests/testdata/output/over_window_function.yaml @@ -62,7 +62,7 @@ Failed to bind expression: lag(x) Caused by: - Invalid input syntax: Window function `lag` must have OVER clause + function lag(integer) does not exist, do you mean log - id: lag with empty over clause sql: | create table t(x int); diff --git a/src/frontend/src/binder/expr/function/aggregate.rs b/src/frontend/src/binder/expr/function/aggregate.rs index d6410616c1d9d..a9067205f77b0 100644 --- a/src/frontend/src/binder/expr/function/aggregate.rs +++ b/src/frontend/src/binder/expr/function/aggregate.rs @@ -16,7 +16,7 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_expr::aggregate::{agg_kinds, AggKind, PbAggKind}; -use risingwave_sqlparser::ast::{Function, FunctionArgExpr}; +use risingwave_sqlparser::ast::{self, FunctionArgExpr}; use crate::binder::Clause; use crate::error::{ErrorCode, Result}; @@ -48,21 +48,22 @@ impl Binder { pub(super) fn bind_aggregate_function( &mut self, - f: Function, kind: AggKind, + distinct: bool, + args: Vec, + order_by: Vec, + within_group: Option>, + filter: Option>, ) -> Result { self.ensure_aggregate_allowed()?; - let distinct = f.arg_list.distinct; - let filter_expr = f.filter.clone(); - let (direct_args, args, order_by) = if matches!(kind, agg_kinds::ordered_set!()) { - self.bind_ordered_set_agg(f, kind.clone())? + self.bind_ordered_set_agg(&kind, distinct, args, order_by, within_group)? } else { - self.bind_normal_agg(f, kind.clone())? + self.bind_normal_agg(&kind, distinct, args, order_by, within_group)? }; - let filter = match filter_expr { + let filter = match filter { Some(filter) => { let mut clause = Some(Clause::Filter); std::mem::swap(&mut self.context.clause, &mut clause); @@ -96,8 +97,11 @@ impl Binder { fn bind_ordered_set_agg( &mut self, - f: Function, - kind: AggKind, + kind: &AggKind, + distinct: bool, + args: Vec, + order_by: Vec, + within_group: Option>, ) -> Result<(Vec, Vec, OrderBy)> { // Syntax: // aggregate_name ( [ expression [ , ... ] ] ) WITHIN GROUP ( order_by_clause ) [ FILTER @@ -105,44 +109,38 @@ impl Binder { assert!(matches!(kind, agg_kinds::ordered_set!())); - if !f.arg_list.order_by.is_empty() { + if !order_by.is_empty() { return Err(ErrorCode::InvalidInputSyntax(format!( - "ORDER BY is not allowed for ordered-set aggregation `{}`", + "`ORDER BY` is not allowed for ordered-set aggregation `{}`", kind )) .into()); } - if f.arg_list.distinct { + if distinct { return Err(ErrorCode::InvalidInputSyntax(format!( - "DISTINCT is not allowed for ordered-set aggregation `{}`", + "`DISTINCT` is not allowed for ordered-set aggregation `{}`", kind )) .into()); } - let within_group = *f.within_group.ok_or_else(|| { + let within_group = *within_group.ok_or_else(|| { ErrorCode::InvalidInputSyntax(format!( - "WITHIN GROUP is expected for ordered-set aggregation `{}`", + "`WITHIN GROUP` is expected for ordered-set aggregation `{}`", kind )) })?; - let mut direct_args: Vec<_> = f - .arg_list - .args - .into_iter() - .map(|arg| self.bind_function_arg(arg)) - .flatten_ok() - .try_collect()?; + let mut direct_args = args; let mut args = self.bind_function_expr_arg(FunctionArgExpr::Expr(within_group.expr.clone()))?; let order_by = OrderBy::new(vec![self.bind_order_by_expr(within_group)?]); // check signature and do implicit cast - match (&kind, direct_args.len(), args.as_mut_slice()) { + match (kind, direct_args.len(), args.as_mut_slice()) { (AggKind::Builtin(PbAggKind::PercentileCont | PbAggKind::PercentileDisc), 1, [arg]) => { let fraction = &mut direct_args[0]; - decimal_to_float64(fraction, &kind)?; + decimal_to_float64(fraction, kind)?; if matches!(&kind, AggKind::Builtin(PbAggKind::PercentileCont)) { arg.cast_implicit_mut(DataType::Float64).map_err(|_| { ErrorCode::InvalidInputSyntax(format!( @@ -155,11 +153,11 @@ impl Binder { (AggKind::Builtin(PbAggKind::Mode), 0, [_arg]) => {} (AggKind::Builtin(PbAggKind::ApproxPercentile), 1..=2, [_percentile_col]) => { let percentile = &mut direct_args[0]; - decimal_to_float64(percentile, &kind)?; + decimal_to_float64(percentile, kind)?; match direct_args.len() { 2 => { let relative_error = &mut direct_args[1]; - decimal_to_float64(relative_error, &kind)?; + decimal_to_float64(relative_error, kind)?; } 1 => { let relative_error: ExprImpl = Literal::new( @@ -198,8 +196,11 @@ impl Binder { fn bind_normal_agg( &mut self, - f: Function, - kind: AggKind, + kind: &AggKind, + distinct: bool, + args: Vec, + order_by: Vec, + within_group: Option>, ) -> Result<(Vec, Vec, OrderBy)> { // Syntax: // aggregate_name (expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE @@ -212,30 +213,22 @@ impl Binder { assert!(!matches!(kind, agg_kinds::ordered_set!())); - if f.within_group.is_some() { + if within_group.is_some() { return Err(ErrorCode::InvalidInputSyntax(format!( - "WITHIN GROUP is not allowed for non-ordered-set aggregation `{}`", + "`WITHIN GROUP` is not allowed for non-ordered-set aggregation `{}`", kind )) .into()); } - let args: Vec<_> = f - .arg_list - .args - .iter() - .map(|arg| self.bind_function_arg(arg.clone())) - .flatten_ok() - .try_collect()?; let order_by = OrderBy::new( - f.arg_list - .order_by + order_by .into_iter() .map(|e| self.bind_order_by_expr(e)) .try_collect()?, ); - if f.arg_list.distinct { + if distinct { if matches!( kind, AggKind::Builtin(PbAggKind::ApproxCountDistinct) diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 898a98c57dea6..2793ee54b85d6 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -27,6 +27,7 @@ use risingwave_sqlparser::parser::ParserError; use crate::binder::bind_context::Clause; use crate::binder::{Binder, UdfContext}; +use crate::catalog::function_catalog::FunctionCatalog; use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{ Expr, ExprImpl, ExprType, FunctionCallWithLambda, InputRef, TableFunction, TableFunctionType, @@ -60,13 +61,27 @@ pub(super) fn is_sys_function_without_args(ident: &Ident) -> bool { /// stack is set to `16`. const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16; -impl Binder { - pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { - if f.arg_list.ignore_nulls { - bail_not_implemented!("IGNORE NULLS is not supported yet"); +macro_rules! reject_syntax { + ($pred:expr, $msg:expr) => { + if $pred { + return Err(ErrorCode::InvalidInputSyntax($msg.to_string()).into()); } + }; +} - let function_name = match f.name.0.as_slice() { +impl Binder { + pub(in crate::binder) fn bind_function( + &mut self, + Function { + scalar_as_agg, + name, + arg_list, + within_group, + filter, + over, + }: Function, + ) -> Result { + let func_name = match name.0.as_slice() { [name] => name.real_value(), [schema, name] => { let schema_name = schema.real_value(); @@ -95,7 +110,7 @@ impl Binder { ); } } - _ => bail_not_implemented!(issue = 112, "qualified function {}", f.name), + _ => bail_not_implemented!(issue = 112, "qualified function {}", name), }; // FIXME: This is a hack to support [Bytebase queries](https://github.com/TennyZhuang/bytebase/blob/4a26f7c62b80e86e58ad2f77063138dc2f420623/backend/plugin/db/pg/sync.go#L549). @@ -104,25 +119,44 @@ impl Binder { // retrieve object comment, however we don't support casting a non-literal expression to // regclass. We just hack the `obj_description` and `col_description` here, to disable it to // bind its arguments. - if function_name == "obj_description" || function_name == "col_description" { + if func_name == "obj_description" || func_name == "col_description" { return Ok(ExprImpl::literal_varchar("".to_string())); } - if function_name == "array_transform" { + + // special binding logic for `array_transform` + if func_name == "array_transform" { // For type inference, we need to bind the array type first. - return self.bind_array_transform(f); + reject_syntax!( + scalar_as_agg, + "`AGGREGATE:` prefix is not allowed for `array_transform`" + ); + reject_syntax!(!arg_list.is_args_only(), "keywords like `DISTINCT`, `ORDER BY` are not allowed in `array_transform` argument list"); + reject_syntax!( + within_group.is_some(), + "`WITHIN GROUP` is not allowed in `array_transform` call" + ); + reject_syntax!( + filter.is_some(), + "`FILTER` is not allowed in `array_transform` call" + ); + reject_syntax!( + over.is_some(), + "`OVER` is not allowed in `array_transform` call" + ); + return self.bind_array_transform(arg_list.args); } - let mut inputs: Vec<_> = f - .arg_list + let mut args: Vec<_> = arg_list .args .iter() .map(|arg| self.bind_function_arg(arg.clone())) .flatten_ok() .try_collect()?; - // `aggregate:` on a scalar function - if f.scalar_as_agg { - let mut scalar_inputs = inputs + let wrapped_agg_kind = if scalar_as_agg { + // Let's firstly try to apply the `AGGREGATE:` prefix. + // We will reject functions that are not able to be wrapped as aggregate function. + let mut array_args = args .iter() .enumerate() .map(|(i, expr)| { @@ -130,203 +164,201 @@ impl Binder { }) .collect_vec(); let scalar: ExprImpl = if let Ok(schema) = self.first_valid_schema() - && let Some(func) = - schema.get_function_by_name_inputs(&function_name, &mut scalar_inputs) + && let Some(func) = schema.get_function_by_name_inputs(&func_name, &mut array_args) { if !func.kind.is_scalar() { return Err(ErrorCode::InvalidInputSyntax( - "expect a scalar function after `aggregate:`".to_string(), + "expect a scalar function after `AGGREGATE:`".to_string(), ) .into()); } - UserDefinedFunction::new(func.clone(), scalar_inputs).into() + UserDefinedFunction::new(func.clone(), array_args).into() } else { - self.bind_builtin_scalar_function( - &function_name, - scalar_inputs, - f.arg_list.variadic, - )? + self.bind_builtin_scalar_function(&func_name, array_args, arg_list.variadic)? }; - return self.bind_aggregate_function(f, AggKind::WrapScalar(scalar.to_expr_proto())); - } - // user defined function - // TODO: resolve schema name https://github.com/risingwavelabs/risingwave/issues/12422 - if let Ok(schema) = self.first_valid_schema() - && let Some(func) = schema.get_function_by_name_inputs(&function_name, &mut inputs) - { - use crate::catalog::function_catalog::FunctionKind::*; + // now this is either an aggregate/window function call + Some(AggKind::WrapScalar(scalar.to_expr_proto())) + } else { + None + }; + let udf = if wrapped_agg_kind.is_none() + && let Ok(schema) = self.first_valid_schema() + && let Some(func) = schema + .get_function_by_name_inputs(&func_name, &mut args) + .cloned() + { if func.language == "sql" { - if func.body.is_none() { - return Err(ErrorCode::InvalidInputSyntax( - "`body` must exist for sql udf".to_string(), - ) - .into()); - } - - // This represents the current user defined function is `language sql` - let parse_result = risingwave_sqlparser::parser::Parser::parse_sql( - func.body.as_ref().unwrap().as_str(), + let name = format!("SQL user-defined function `{}`", func.name); + reject_syntax!( + scalar_as_agg, + format!("`AGGREGATE:` prefix is not allowed for {}", name) ); - if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = - parse_result - { - // Here we just return the original parse error message - return Err(ErrorCode::InvalidInputSyntax(err).into()); - } - - debug_assert!(parse_result.is_ok()); - - // We can safely unwrap here - let ast = parse_result.unwrap(); - - // Stash the current `udf_context` - // Note that the `udf_context` may be empty, - // if the current binding is the root (top-most) sql udf. - // In this case the empty context will be stashed - // and restored later, no need to maintain other flags. - let stashed_udf_context = self.udf_context.get_context(); - - // The actual inline logic for sql udf - // Note that we will always create new udf context for each sql udf - let Ok(context) = - UdfContext::create_udf_context(&f.arg_list.args, &Arc::clone(func)) - else { - return Err(ErrorCode::InvalidInputSyntax( - "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() + reject_syntax!( + !arg_list.is_args_only(), + format!( + "keywords like `DISTINCT`, `ORDER BY` are not allowed in {} argument list", + name ) - .into()); - }; - - let mut udf_context = HashMap::new(); - for (c, e) in context { - // Note that we need to bind the args before actual delve in the function body - // This will update the context in the subsequent inner calling function - // e.g., - // - create function print(INT) returns int language sql as 'select $1'; - // - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)'; - // - select print_add_one(1); # The result should be 2 instead of 1. - // Without the pre-binding here, the ($1 + 1) will not be correctly populated, - // causing the result to always be 1. - let Ok(e) = self.bind_expr(e) else { - return Err(ErrorCode::BindError( - "failed to bind the argument, please recheck the syntax".to_string(), - ) - .into()); - }; - udf_context.insert(c, e); - } - self.udf_context.update_context(udf_context); - - // Check for potential recursive calling - if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH { - return Err(ErrorCode::BindError(format!( - "function {} calling stack depth limit exceeded", - &function_name - )) - .into()); - } else { - // Update the status for the global counter - self.udf_context.incr_global_count(); - } + ); + reject_syntax!( + within_group.is_some(), + format!("`WITHIN GROUP` is not allowed in {} call", name) + ); + reject_syntax!( + filter.is_some(), + format!("`FILTER` is not allowed in {} call", name) + ); + reject_syntax!( + over.is_some(), + format!("`OVER` is not allowed in {} call", name) + ); + return self.bind_sql_udf(func, arg_list.args); + } - if let Ok(expr) = UdfContext::extract_udf_expression(ast) { - let bind_result = self.bind_expr(expr); + // now `func` is a non-SQL user-defined scalar/aggregate/table function + Some(func) + } else { + None + }; - // We should properly decrement global count after a successful binding - // Since the subsequent probe operation in `bind_column` or - // `bind_parameter` relies on global counting - self.udf_context.decr_global_count(); + let agg_kind = if let Some(wrapped_agg_kind) = wrapped_agg_kind { + Some(wrapped_agg_kind) + } else if let Some(ref udf) = udf + && udf.kind.is_aggregate() + { + Some(AggKind::UserDefined(udf.as_ref().into())) + } else if let Ok(kind) = AggKind::from_str(&func_name) { + Some(kind) + } else { + None + }; - // Restore context information for subsequent binding - self.udf_context.update_context(stashed_udf_context); + // try to bind it as a window function call + if let Some(over) = over { + reject_syntax!( + arg_list.distinct, + "`DISTINCT` is not allowed in window function call" + ); + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in window function call" + ); + reject_syntax!( + !arg_list.order_by.is_empty(), + "`ORDER BY` is not allowed in window function call argument list" + ); + reject_syntax!( + within_group.is_some(), + "`WITHIN GROUP` is not allowed in window function call" + ); - return bind_result; - } else { - return Err(ErrorCode::InvalidInputSyntax( - "failed to parse the input query and extract the udf expression, - please recheck the syntax" - .to_string(), - ) - .into()); - } + let kind = if let Some(agg_kind) = agg_kind { + // aggregate as window function + WindowFuncKind::Aggregate(agg_kind) + } else if let Ok(kind) = WindowFuncKind::from_str(&func_name) { + kind } else { - match &func.kind { - Scalar { .. } => { - return Ok(UserDefinedFunction::new(func.clone(), inputs).into()) - } - Table { .. } => { - self.ensure_table_function_allowed()?; - return Ok(TableFunction::new_user_defined(func.clone(), inputs).into()); - } - Aggregate => { - return self.bind_aggregate_function( - f, - AggKind::UserDefined(func.as_ref().into()), - ); - } - } - } - } - - // agg calls - if f.over.is_none() - && let Ok(kind) = function_name.parse() - { - return self.bind_aggregate_function(f, AggKind::Builtin(kind)); + bail_not_implemented!(issue = 8961, "Unrecognized window function: {}", func_name); + }; + return self.bind_window_function(kind, args, arg_list.ignore_nulls, filter, over); } - if f.arg_list.distinct || !f.arg_list.order_by.is_empty() || f.filter.is_some() { - return Err(ErrorCode::InvalidInputSyntax(format!( - "DISTINCT, ORDER BY or FILTER is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name - ) - ) - .into()); + // now it's a aggregate/scalar/table function call + reject_syntax!( + arg_list.ignore_nulls, + "`IGNORE NULLS` is not allowed in aggregate/scalar/table function call" + ); + + // try to bind it as an aggregate function call + if let Some(agg_kind) = agg_kind { + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in aggregate function call" + ); + return self.bind_aggregate_function( + agg_kind, + arg_list.distinct, + args, + arg_list.order_by, + within_group, + filter, + ); } - // window function - let window_func_kind = WindowFuncKind::from_str(function_name.as_str()); - if let Ok(kind) = window_func_kind { - if let Some(window_spec) = f.over { - return self.bind_window_function(kind, inputs, window_spec); + // now it's a scalar/table function call + reject_syntax!( + arg_list.distinct, + "`DISTINCT` is not allowed in scalar/table function call" + ); + reject_syntax!( + !arg_list.order_by.is_empty(), + "`ORDER BY` is not allowed in scalar/table function call" + ); + reject_syntax!( + within_group.is_some(), + "`WITHIN GROUP` is not allowed in scalar/table function call" + ); + reject_syntax!( + filter.is_some(), + "`FILTER` is not allowed in scalar/table function call" + ); + + // try to bind it as a table function call + { + // `file_scan` table function + if func_name.eq_ignore_ascii_case("file_scan") { + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in table function call" + ); + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_file_scan(args)?.into()); + } + // UDTF + if let Some(ref udf) = udf + && udf.kind.is_table() + { + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in table function call" + ); + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new_user_defined(udf.clone(), args).into()); + } + // builtin table function + if let Ok(function_type) = TableFunctionType::from_str(&func_name) { + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in table function call" + ); + self.ensure_table_function_allowed()?; + return Ok(TableFunction::new(function_type, args)?.into()); } - return Err(ErrorCode::InvalidInputSyntax(format!( - "Window function `{}` must have OVER clause", - function_name - )) - .into()); - } else if f.over.is_some() { - bail_not_implemented!( - issue = 8961, - "Unrecognized window function: {}", - function_name - ); } - // file_scan table function - if function_name.eq_ignore_ascii_case("file_scan") { - self.ensure_table_function_allowed()?; - return Ok(TableFunction::new_file_scan(inputs)?.into()); - } - // table function - if let Ok(function_type) = TableFunctionType::from_str(function_name.as_str()) { - self.ensure_table_function_allowed()?; - return Ok(TableFunction::new(function_type, inputs)?.into()); + // try to bind it as a scalar function call + if let Some(ref udf) = udf { + assert!(udf.kind.is_scalar()); + reject_syntax!( + arg_list.variadic, + "`VARIADIC` is not allowed in user-defined function call" + ); + return Ok(UserDefinedFunction::new(udf.clone(), args).into()); } - self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.arg_list.variadic) + self.bind_builtin_scalar_function(&func_name, args, arg_list.variadic) } - fn bind_array_transform(&mut self, f: Function) -> Result { - let [array, lambda] = - <[FunctionArg; 2]>::try_from(f.arg_list.args).map_err(|args| -> RwError { - ErrorCode::BindError(format!( - "`array_transform` expect two inputs `array` and `lambda`, but {} were given", - args.len() - )) - .into() - })?; + fn bind_array_transform(&mut self, args: Vec) -> Result { + let [array, lambda] = <[FunctionArg; 2]>::try_from(args).map_err(|args| -> RwError { + ErrorCode::BindError(format!( + "`array_transform` expect two inputs `array` and `lambda`, but {} were given", + args.len() + )) + .into() + })?; let bound_array = self.bind_function_arg(array)?; let [bound_array] = <[ExprImpl; 1]>::try_from(bound_array).map_err(|bound_array| -> RwError { @@ -416,6 +448,102 @@ impl Binder { Ok(()) } + fn bind_sql_udf( + &mut self, + func: Arc, + args: Vec, + ) -> Result { + if func.body.is_none() { + return Err( + ErrorCode::InvalidInputSyntax("`body` must exist for sql udf".to_string()).into(), + ); + } + + // This represents the current user defined function is `language sql` + let parse_result = + risingwave_sqlparser::parser::Parser::parse_sql(func.body.as_ref().unwrap().as_str()); + if let Err(ParserError::ParserError(err)) | Err(ParserError::TokenizerError(err)) = + parse_result + { + // Here we just return the original parse error message + return Err(ErrorCode::InvalidInputSyntax(err).into()); + } + + debug_assert!(parse_result.is_ok()); + + // We can safely unwrap here + let ast = parse_result.unwrap(); + + // Stash the current `udf_context` + // Note that the `udf_context` may be empty, + // if the current binding is the root (top-most) sql udf. + // In this case the empty context will be stashed + // and restored later, no need to maintain other flags. + let stashed_udf_context = self.udf_context.get_context(); + + // The actual inline logic for sql udf + // Note that we will always create new udf context for each sql udf + let Ok(context) = UdfContext::create_udf_context(&args, &func) else { + return Err(ErrorCode::InvalidInputSyntax( + "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() + ) + .into()); + }; + + let mut udf_context = HashMap::new(); + for (c, e) in context { + // Note that we need to bind the args before actual delve in the function body + // This will update the context in the subsequent inner calling function + // e.g., + // - create function print(INT) returns int language sql as 'select $1'; + // - create function print_add_one(INT) returns int language sql as 'select print($1 + 1)'; + // - select print_add_one(1); # The result should be 2 instead of 1. + // Without the pre-binding here, the ($1 + 1) will not be correctly populated, + // causing the result to always be 1. + let Ok(e) = self.bind_expr(e) else { + return Err(ErrorCode::BindError( + "failed to bind the argument, please recheck the syntax".to_string(), + ) + .into()); + }; + udf_context.insert(c, e); + } + self.udf_context.update_context(udf_context); + + // Check for potential recursive calling + if self.udf_context.global_count() >= SQL_UDF_MAX_CALLING_DEPTH { + return Err(ErrorCode::BindError(format!( + "function {} calling stack depth limit exceeded", + func.name + )) + .into()); + } else { + // Update the status for the global counter + self.udf_context.incr_global_count(); + } + + if let Ok(expr) = UdfContext::extract_udf_expression(ast) { + let bind_result = self.bind_expr(expr); + + // We should properly decrement global count after a successful binding + // Since the subsequent probe operation in `bind_column` or + // `bind_parameter` relies on global counting + self.udf_context.decr_global_count(); + + // Restore context information for subsequent binding + self.udf_context.update_context(stashed_udf_context); + + return bind_result; + } + + Err(ErrorCode::InvalidInputSyntax( + "failed to parse the input query and extract the udf expression, + please recheck the syntax" + .to_string(), + ) + .into()) + } + pub(in crate::binder) fn bind_function_expr_arg( &mut self, arg_expr: FunctionArgExpr, diff --git a/src/frontend/src/binder/expr/function/window.rs b/src/frontend/src/binder/expr/function/window.rs index 3124d4717dd82..03288cbf9e240 100644 --- a/src/frontend/src/binder/expr/function/window.rs +++ b/src/frontend/src/binder/expr/function/window.rs @@ -57,7 +57,9 @@ impl Binder { pub(super) fn bind_window_function( &mut self, kind: WindowFuncKind, - inputs: Vec, + args: Vec, + ignore_nulls: bool, + filter: Option>, WindowSpec { partition_by, order_by, @@ -65,6 +67,15 @@ impl Binder { }: WindowSpec, ) -> Result { self.ensure_window_function_allowed()?; + + if ignore_nulls { + bail_not_implemented!("`IGNORE NULLS` is not supported yet"); + } + + if filter.is_some() { + bail_not_implemented!("`FILTER` is not supported yet"); + } + let partition_by = partition_by .into_iter() .map(|arg| self.bind_expr_inner(arg)) @@ -181,7 +192,7 @@ impl Binder { } else { None }; - Ok(WindowFunction::new(kind, partition_by, order_by, inputs, frame)?.into()) + Ok(WindowFunction::new(kind, partition_by, order_by, args, frame)?.into()) } fn bind_window_frame_usize_bounds( diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 3df7b753147ca..80410f767fb62 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2552,6 +2552,10 @@ impl FunctionArgList { } } + pub fn is_args_only(&self) -> bool { + !self.distinct && !self.variadic && self.order_by.is_empty() && !self.ignore_nulls + } + pub fn for_agg(distinct: bool, args: Vec, order_by: Vec) -> Self { Self { distinct,