diff --git a/src/frontend/src/binder/expr/function/aggregate.rs b/src/frontend/src/binder/expr/function/aggregate.rs index 77538b799bad2..1e7b76bf7629e 100644 --- a/src/frontend/src/binder/expr/function/aggregate.rs +++ b/src/frontend/src/binder/expr/function/aggregate.rs @@ -53,7 +53,7 @@ impl Binder { ) -> Result { self.ensure_aggregate_allowed()?; - let distinct = f.distinct; + let distinct = f.arg_list.distinct; let filter_expr = f.filter.clone(); let (direct_args, args, order_by) = if matches!(kind, agg_kinds::ordered_set!()) { @@ -105,14 +105,14 @@ impl Binder { assert!(matches!(kind, agg_kinds::ordered_set!())); - if !f.order_by.is_empty() { + if !f.arg_list.order_by.is_empty() { return Err(ErrorCode::InvalidInputSyntax(format!( "ORDER BY is not allowed for ordered-set aggregation `{}`", kind )) .into()); } - if f.distinct { + if f.arg_list.distinct { return Err(ErrorCode::InvalidInputSyntax(format!( "DISTINCT is not allowed for ordered-set aggregation `{}`", kind @@ -128,6 +128,7 @@ impl Binder { })?; let mut direct_args: Vec<_> = f + .arg_list .args .into_iter() .map(|arg| self.bind_function_arg(arg)) @@ -207,19 +208,21 @@ impl Binder { } let args: Vec<_> = f + .arg_list .args .iter() .map(|arg| self.bind_function_arg(arg.clone())) .flatten_ok() .try_collect()?; let order_by = OrderBy::new( - f.order_by + f.arg_list + .order_by .into_iter() .map(|e| self.bind_order_by_expr(e)) .try_collect()?, ); - if f.distinct { + if f.arg_list.distinct { if matches!( kind, AggKind::Builtin(PbAggKind::ApproxCountDistinct) diff --git a/src/frontend/src/binder/expr/function/mod.rs b/src/frontend/src/binder/expr/function/mod.rs index 9c2b9e1c644e1..7e755ac2d5c2c 100644 --- a/src/frontend/src/binder/expr/function/mod.rs +++ b/src/frontend/src/binder/expr/function/mod.rs @@ -61,6 +61,10 @@ const SQL_UDF_MAX_CALLING_DEPTH: u32 = 16; impl Binder { pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { + if f.arg_list.ignore_nulls { + bail_not_implemented!("IGNORE NULLS is not supported yet"); + } + let function_name = match f.name.0.as_slice() { [name] => name.real_value(), [schema, name] => { @@ -108,6 +112,7 @@ impl Binder { } let mut inputs: Vec<_> = f + .arg_list .args .iter() .map(|arg| self.bind_function_arg(arg.clone())) @@ -135,7 +140,11 @@ impl Binder { } UserDefinedFunction::new(func.clone(), scalar_inputs).into() } else { - self.bind_builtin_scalar_function(&function_name, scalar_inputs, f.variadic)? + self.bind_builtin_scalar_function( + &function_name, + scalar_inputs, + f.arg_list.variadic, + )? }; return self.bind_aggregate_function(f, AggKind::WrapScalar(scalar.to_expr_proto())); } @@ -180,7 +189,9 @@ impl Binder { // The actual inline logic for sql udf // Note that we will always create new udf context for each sql udf - let Ok(context) = UdfContext::create_udf_context(&f.args, &Arc::clone(func)) else { + let Ok(context) = + UdfContext::create_udf_context(&f.arg_list.args, &Arc::clone(func)) + else { return Err(ErrorCode::InvalidInputSyntax( "failed to create the `udf_context`, please recheck your function definition and syntax".to_string() ) @@ -265,7 +276,7 @@ impl Binder { return self.bind_aggregate_function(f, AggKind::Builtin(kind)); } - if f.distinct || !f.order_by.is_empty() || f.filter.is_some() { + if f.arg_list.distinct || !f.arg_list.order_by.is_empty() || f.filter.is_some() { return Err(ErrorCode::InvalidInputSyntax(format!( "DISTINCT, ORDER BY or FILTER is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name ) @@ -303,17 +314,18 @@ impl Binder { return Ok(TableFunction::new(function_type, inputs)?.into()); } - self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.variadic) + self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.arg_list.variadic) } fn bind_array_transform(&mut self, f: Function) -> Result { - let [array, lambda] = <[FunctionArg; 2]>::try_from(f.args).map_err(|args| -> RwError { - ErrorCode::BindError(format!( - "`array_transform` expect two inputs `array` and `lambda`, but {} were given", - args.len() - )) - .into() - })?; + let [array, lambda] = + <[FunctionArg; 2]>::try_from(f.arg_list.args).map_err(|args| -> RwError { + ErrorCode::BindError(format!( + "`array_transform` expect two inputs `array` and `lambda`, but {} were given", + args.len() + )) + .into() + })?; let bound_array = self.bind_function_arg(array)?; let [bound_array] = <[ExprImpl; 1]>::try_from(bound_array).map_err(|bound_array| -> RwError { diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 7cd9032890091..aefb66cdc94d1 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -790,30 +790,33 @@ mod tests { }, ], ), - args: [ - Unnamed( - Expr( - Value( - Number( - "0.5", + arg_list: FunctionArgList { + distinct: false, + args: [ + Unnamed( + Expr( + Value( + Number( + "0.5", + ), ), ), ), - ), - Unnamed( - Expr( - Value( - Number( - "0.01", + Unnamed( + Expr( + Value( + Number( + "0.01", + ), ), ), ), - ), - ], - variadic: false, + ], + variadic: false, + order_by: [], + ignore_nulls: false, + }, over: None, - distinct: false, - order_by: [], filter: None, within_group: Some( OrderByExpr { diff --git a/src/frontend/src/binder/relation/table_function.rs b/src/frontend/src/binder/relation/table_function.rs index 22b9c2a344c2c..cc672703cda35 100644 --- a/src/frontend/src/binder/relation/table_function.rs +++ b/src/frontend/src/binder/relation/table_function.rs @@ -18,7 +18,7 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::catalog::{Field, Schema, RW_INTERNAL_TABLE_FUNCTION_NAME}; use risingwave_common::types::DataType; -use risingwave_sqlparser::ast::{Function, FunctionArg, ObjectName, TableAlias}; +use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgList, ObjectName, TableAlias}; use super::watermark::is_watermark_func; use super::{Binder, Relation, Result, WindowTableFunctionKind}; @@ -85,11 +85,8 @@ impl Binder { let func = self.bind_function(Function { scalar_as_agg: false, name, - args, - variadic: false, + arg_list: FunctionArgList::args_only(args), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, }); diff --git a/src/meta/src/controller/rename.rs b/src/meta/src/controller/rename.rs index 220aac34b8a5d..86465e286d958 100644 --- a/src/meta/src/controller/rename.rs +++ b/src/meta/src/controller/rename.rs @@ -18,8 +18,8 @@ use risingwave_pb::expr::expr_node::RexNode; use risingwave_pb::expr::{ExprNode, FunctionCall, UserDefinedFunction}; use risingwave_sqlparser::ast::{ Array, CreateSink, CreateSinkStatement, CreateSourceStatement, CreateSubscriptionStatement, - Distinct, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Query, SelectItem, - SetExpr, Statement, TableAlias, TableFactor, TableWithJoins, + Distinct, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName, + Query, SelectItem, SetExpr, Statement, TableAlias, TableFactor, TableWithJoins, }; use risingwave_sqlparser::parser::Parser; @@ -264,14 +264,18 @@ impl QueryRewriter<'_> { } } - /// Visit function and update all references. - fn visit_function(&self, function: &mut Function) { - for arg in &mut function.args { + fn visit_function_arg_list(&self, arg_list: &mut FunctionArgList) { + for arg in &mut arg_list.args { self.visit_function_arg(arg); } - for expr in &mut function.order_by { + for expr in &mut arg_list.order_by { self.visit_expr(&mut expr.expr) } + } + + /// Visit function and update all references. + fn visit_function(&self, function: &mut Function) { + self.visit_function_arg_list(&mut function.arg_list); if let Some(over) = &mut function.over { for expr in &mut over.partition_by { self.visit_expr(expr); diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index d5cca61b6a186..e77113e28aefd 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2477,6 +2477,79 @@ impl fmt::Display for FunctionArg { } } +/// A list of function arguments, including additional modifiers like `DISTINCT` or `ORDER BY`. +/// This basically holds all the information between the `(` and `)` in a function call. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct FunctionArgList { + /// Aggregate function calls may have a `DISTINCT`, e.g. `count(DISTINCT x)`. + pub distinct: bool, + pub args: Vec, + /// Whether the last argument is variadic, e.g. `foo(a, b, VARIADIC c)`. + pub variadic: bool, + /// Aggregate function calls may have an `ORDER BY`, e.g. `array_agg(x ORDER BY y)`. + pub order_by: Vec, + /// Window function calls may have an `IGNORE NULLS`, e.g. `first_value(x IGNORE NULLS)`. + pub ignore_nulls: bool, +} + +impl fmt::Display for FunctionArgList { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(")?; + if self.distinct { + write!(f, "DISTINCT ")?; + } + if self.variadic { + for arg in &self.args[0..self.args.len() - 1] { + write!(f, "{}, ", arg)?; + } + write!(f, "VARIADIC {}", self.args.last().unwrap())?; + } else { + write!(f, "{}", display_comma_separated(&self.args))?; + } + if !self.order_by.is_empty() { + write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?; + } + if self.ignore_nulls { + write!(f, " IGNORE NULLS")?; + } + write!(f, ")")?; + Ok(()) + } +} + +impl FunctionArgList { + pub fn empty() -> Self { + Self { + distinct: false, + args: vec![], + variadic: false, + order_by: vec![], + ignore_nulls: false, + } + } + + pub fn args_only(args: Vec) -> Self { + Self { + distinct: false, + args, + variadic: false, + order_by: vec![], + ignore_nulls: false, + } + } + + pub fn for_agg(distinct: bool, args: Vec, order_by: Vec) -> Self { + Self { + distinct, + args, + variadic: false, + order_by: order_by, + ignore_nulls: false, + } + } +} + /// A function call #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -2484,14 +2557,8 @@ pub struct Function { /// Whether the function is prefixed with `aggregate:` pub scalar_as_agg: bool, pub name: ObjectName, - pub args: Vec, - /// whether the last argument is variadic, e.g. `foo(a, b, variadic c)` - pub variadic: bool, + pub arg_list: FunctionArgList, pub over: Option, - // aggregate functions may specify eg `COUNT(DISTINCT x)` - pub distinct: bool, - // aggregate functions may contain order_by_clause - pub order_by: Vec, pub filter: Option>, pub within_group: Option>, } @@ -2501,11 +2568,8 @@ impl Function { Self { scalar_as_agg: false, name, - args: vec![], - variadic: false, + arg_list: FunctionArgList::empty(), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, } @@ -2515,26 +2579,9 @@ impl Function { impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.scalar_as_agg { - write!(f, "aggregate:")?; + write!(f, "AGGREGATE:")?; } - write!( - f, - "{}({}", - self.name, - if self.distinct { "DISTINCT " } else { "" }, - )?; - if self.variadic { - for arg in &self.args[0..self.args.len() - 1] { - write!(f, "{}, ", arg)?; - } - write!(f, "VARIADIC {}", self.args.last().unwrap())?; - } else { - write!(f, "{}", display_comma_separated(&self.args))?; - } - if !self.order_by.is_empty() { - write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?; - } - write!(f, ")")?; + write!(f, "{}{}", self.name, self.arg_list)?; if let Some(o) = &self.over { write!(f, " OVER ({})", o)?; } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 996fd9ebe8490..f4e234df0e38a 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -822,9 +822,7 @@ impl Parser<'_> { false }; let name = self.parse_object_name()?; - self.expect_token(&Token::LParen)?; - let distinct = self.parse_all_or_distinct()?; - let (args, order_by, variadic) = self.parse_optional_args()?; + let arg_list = self.parse_argument_list()?; let over = if self.parse_keyword(Keyword::OVER) { // TODO: support window names (`OVER mywin`) in place of inline specification self.expect_token(&Token::LParen)?; @@ -879,11 +877,8 @@ impl Parser<'_> { Ok(Expr::Function(Function { scalar_as_agg, name, - args, - variadic, + arg_list, over, - distinct, - order_by, filter, within_group, })) @@ -4664,17 +4659,24 @@ impl Parser<'_> { } } else { let name = self.parse_object_name()?; - // Postgres,table-valued functions: - if self.consume_token(&Token::LParen) { - // ignore VARIADIC here - let (args, order_by, _variadic) = self.parse_optional_args()?; - // Table-valued functions do not support ORDER BY, should return error if it appears - if !order_by.is_empty() { - parser_err!("Table-valued functions do not support ORDER BY clauses"); + if self.peek_token() == Token::LParen { + // table-valued function + + let arg_list = self.parse_argument_list()?; + if arg_list.distinct { + parser_err!("DISTINCT is not supported in table-valued function calls"); + } + if !arg_list.order_by.is_empty() { + parser_err!("ORDER BY is not supported in table-valued function calls"); + } + if arg_list.ignore_nulls { + parser_err!("IGNORE NULLS is not supported in table-valued function calls"); } - let with_ordinality = self.parse_keywords(&[Keyword::WITH, Keyword::ORDINALITY]); + let args = arg_list.args; + let with_ordinality = self.parse_keywords(&[Keyword::WITH, Keyword::ORDINALITY]); let alias = self.parse_optional_table_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::TableFunction { name, alias, @@ -4957,17 +4959,19 @@ impl Parser<'_> { Ok((variadic, arg)) } - pub fn parse_optional_args(&mut self) -> PResult<(Vec, Vec, bool)> { + pub fn parse_argument_list(&mut self) -> PResult { + self.expect_token(&Token::LParen)?; if self.consume_token(&Token::RParen) { - Ok((vec![], vec![], false)) + Ok(FunctionArgList::empty()) } else { + let distinct = self.parse_all_or_distinct()?; let args = self.parse_comma_separated(Parser::parse_function_args)?; if args .iter() .take(args.len() - 1) .any(|(variadic, _)| *variadic) { - parser_err!("VARIADIC argument must be last"); + parser_err!("VARIADIC argument must be the last"); } let variadic = args.last().map(|(variadic, _)| *variadic).unwrap_or(false); let args = args.into_iter().map(|(_, arg)| arg).collect(); @@ -4977,8 +4981,23 @@ impl Parser<'_> { } else { vec![] }; + + let ignore_nulls = if self.parse_keywords(&[Keyword::IGNORE, Keyword::NULLS]) { + true + } else { + false + }; + + let arg_list = FunctionArgList { + distinct, + args, + order_by, + variadic, + ignore_nulls, + }; + self.expect_token(&Token::RParen)?; - Ok((args, order_by, variadic)) + Ok(arg_list) } } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index 049cf79482032..46486419c629a 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -347,11 +347,10 @@ fn parse_select_count_wildcard() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("COUNT")]), - args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], - variadic: false, + arg_list: FunctionArgList::args_only(vec![FunctionArg::Unnamed( + FunctionArgExpr::Wildcard(None) + )]), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, }), @@ -367,14 +366,15 @@ fn parse_select_count_distinct() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("COUNT")]), - args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::UnaryOp { - op: UnaryOperator::Plus, - expr: Box::new(Expr::Identifier(Ident::new_unchecked("x"))), - }))], - variadic: false, + arg_list: FunctionArgList::for_agg( + true, + vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::UnaryOp { + op: UnaryOperator::Plus, + expr: Box::new(Expr::Identifier(Ident::new_unchecked("x"))), + }))], + vec![] + ), over: None, - distinct: true, - order_by: vec![], filter: None, within_group: None, }), @@ -1166,11 +1166,10 @@ fn parse_select_having() { left: Box::new(Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("COUNT")]), - args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], - variadic: false, + arg_list: FunctionArgList::args_only(vec![FunctionArg::Unnamed( + FunctionArgExpr::Wildcard(None) + )]), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, })), @@ -1908,7 +1907,7 @@ fn parse_named_argument_function() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("FUN")]), - args: vec![ + arg_list: FunctionArgList::args_only(vec![ FunctionArg::Named { name: Ident::new_unchecked("a"), arg: FunctionArgExpr::Expr(Expr::Value(Value::SingleQuotedString( @@ -1921,11 +1920,8 @@ fn parse_named_argument_function() { "2".to_owned() ))), }, - ], - variadic: false, + ]), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, }), @@ -1951,8 +1947,7 @@ fn parse_window_functions() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("row_number")]), - args: vec![], - variadic: false, + arg_list: FunctionArgList::empty(), over: Some(WindowSpec { partition_by: vec![], order_by: vec![OrderByExpr { @@ -1962,8 +1957,6 @@ fn parse_window_functions() { }], window_frame: None, }), - distinct: false, - order_by: vec![], filter: None, within_group: None, }), @@ -1986,29 +1979,30 @@ fn parse_aggregate_with_order_by() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("STRING_AGG")]), - args: vec![ - FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier( - Ident::new_unchecked("a") - ))), - FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier( - Ident::new_unchecked("b") - ))), - ], - variadic: false, + arg_list: FunctionArgList::for_agg( + false, + vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier( + Ident::new_unchecked("a") + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier( + Ident::new_unchecked("b") + ))), + ], + vec![ + OrderByExpr { + expr: Expr::Identifier(Ident::new_unchecked("b")), + asc: Some(true), + nulls_first: None, + }, + OrderByExpr { + expr: Expr::Identifier(Ident::new_unchecked("a")), + asc: Some(false), + nulls_first: None, + } + ] + ), over: None, - distinct: false, - order_by: vec![ - OrderByExpr { - expr: Expr::Identifier(Ident::new_unchecked("b")), - asc: Some(true), - nulls_first: None, - }, - OrderByExpr { - expr: Expr::Identifier(Ident::new_unchecked("a")), - asc: Some(false), - nulls_first: None, - } - ], filter: None, within_group: None, }), @@ -2024,13 +2018,10 @@ fn parse_aggregate_with_filter() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked("sum")]), - args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( - Expr::Identifier(Ident::new_unchecked("a")) - )),], - variadic: false, + arg_list: FunctionArgList::args_only(vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Identifier(Ident::new_unchecked("a"))) + )]), over: None, - distinct: false, - order_by: vec![], filter: Some(Box::new(Expr::BinaryOp { left: Box::new(Expr::Nested(Box::new(Expr::BinaryOp { left: Box::new(Expr::Identifier(Ident::new_unchecked("a"))), @@ -2282,11 +2273,8 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { scalar_as_agg: false, name: ObjectName(vec![Ident::with_quote_unchecked('"', "myfun")]), - args: vec![], - variadic: false, + arg_list: FunctionArgList::empty(), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, }), diff --git a/src/sqlparser/tests/testdata/lambda.yaml b/src/sqlparser/tests/testdata/lambda.yaml index ae3f650d73d44..eceb53af280bd 100644 --- a/src/sqlparser/tests/testdata/lambda.yaml +++ b/src/sqlparser/tests/testdata/lambda.yaml @@ -1,10 +1,10 @@ # This file is automatically generated by `src/sqlparser/tests/parser_test.rs`. - input: select array_transform(array[1,2,3], |x| x * 2) formatted_sql: SELECT array_transform(ARRAY[1, 2, 3], |x| x * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |s| case when s ilike 'apple%' then 'apple' when s ilike 'google%' then 'google' else 'unknown' end) formatted_sql: SELECT array_transform(ARRAY[], |s| CASE WHEN s ILIKE 'apple%' THEN 'apple' WHEN s ILIKE 'google%' THEN 'google' ELSE 'unknown' END) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [ILike { negated: false, expr: Identifier(Ident { value: "s", quote_style: None }), pattern: Value(SingleQuotedString("apple%")), escape_char: None }, ILike { negated: false, expr: Identifier(Ident { value: "s", quote_style: None }), pattern: Value(SingleQuotedString("google%")), escape_char: None }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [ILike { negated: false, expr: Identifier(Ident { value: "s", quote_style: None }), pattern: Value(SingleQuotedString("apple%")), escape_char: None }, ILike { negated: false, expr: Identifier(Ident { value: "s", quote_style: None }), pattern: Value(SingleQuotedString("google%")), escape_char: None }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |x, y| x + y * 2) formatted_sql: SELECT array_transform(ARRAY[], |x, y| x + y * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "array_transform", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/qualified_operator.yaml b/src/sqlparser/tests/testdata/qualified_operator.yaml index 83f8e885989c7..ddab7deeb214f 100644 --- a/src/sqlparser/tests/testdata/qualified_operator.yaml +++ b/src/sqlparser/tests/testdata/qualified_operator.yaml @@ -19,10 +19,10 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "operator", quote_style: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select "operator"(foo.bar); formatted_sql: SELECT "operator"(foo.bar) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select operator operator(+) operator(+) "operator"(9) operator from operator; formatted_sql: SELECT operator OPERATOR(+) OPERATOR(+) "operator"(9) AS operator FROM operator - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(Value(Number("9"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Value(Number("9"))))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select 3 operator(-) 2 - 1; formatted_sql: SELECT 3 OPERATOR(-) 2 - 1 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Value(Number("3")), op: PGQualified(QualifiedOperator { schema: None, name: "-" }), right: BinaryOp { left: Value(Number("2")), op: Minus, right: Value(Number("1")) } })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 06c12a3e7d554..83c624c64309b 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -1,7 +1,7 @@ # This file is automatically generated by `src/sqlparser/tests/parser_test.rs`. - input: SELECT sqrt(id) FROM foo formatted_sql: SELECT sqrt(id) FROM foo - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "sqrt", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT INT '1' formatted_sql: SELECT INT '1' - input: SELECT (foo).v1.v2 FROM foo @@ -99,7 +99,7 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(AtTimeZone { timestamp: TypedString { data_type: Timestamp(true), value: "2022-10-01 12:00:00Z" }, time_zone: Identifier(Ident { value: "zone", quote_style: None }) })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT now() + INTERVAL '14 days' AT TIME ZONE 'UTC'; -- https://www.postgresql.org/message-id/CADT4RqBPdbsZW7HS1jJP319TMRHs1hzUiP=iRJYR6UqgHCrgNQ@mail.gmail.com formatted_sql: SELECT now() + INTERVAL '14 days' AT TIME ZONE 'UTC' - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "now", quote_style: None }]), args: [], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }), op: Plus, right: AtTimeZone { timestamp: Value(Interval { value: "14 days", leading_field: None, leading_precision: None, last_field: None, fractional_seconds_precision: None }), time_zone: Value(SingleQuotedString("UTC")) } })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "now", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: None }), op: Plus, right: AtTimeZone { timestamp: Value(Interval { value: "14 days", leading_field: None, leading_precision: None, last_field: None, fractional_seconds_precision: None }), time_zone: Value(SingleQuotedString("UTC")) } })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT c FROM t WHERE c >= '2019-03-27T22:00:00.000Z'::timestamp AT TIME ZONE 'Europe/Brussels'; -- https://github.com/sqlparser-rs/sqlparser-rs/issues/1266 formatted_sql: SELECT c FROM t WHERE c >= CAST('2019-03-27T22:00:00.000Z' AS TIMESTAMP) AT TIME ZONE 'Europe/Brussels' formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "c", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "t", quote_style: None }]), alias: None, as_of: None }, joins: [] }], lateral_views: [], selection: Some(BinaryOp { left: Identifier(Ident { value: "c", quote_style: None }), op: GtEq, right: AtTimeZone { timestamp: Cast { expr: Value(SingleQuotedString("2019-03-27T22:00:00.000Z")), data_type: Timestamp(false) }, time_zone: Value(SingleQuotedString("Europe/Brussels")) } }), group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' @@ -173,7 +173,7 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "id1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "id2", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a2", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "stream", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "S", quote_style: None }, columns: [] }), as_of: None }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "version", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "V", quote_style: None }, columns: [] }), as_of: Some(ProcessTime) }, join_operator: Inner(On(BinaryOp { left: Identifier(Ident { value: "id1", quote_style: None }), op: Eq, right: Identifier(Ident { value: "id2", quote_style: None }) })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x desc) from unnest(array[1,2,4,5,10]) as x formatted_sql: SELECT percentile_cont(0.3) FROM unnest(ARRAY[1, 2, 4, 5, 10]) AS x - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { scalar_as_agg: false, name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), arg_list: FunctionArgList { distinct: false, args: [Unnamed(Expr(Value(Number("0.3"))))], variadic: false, order_by: [], ignore_nulls: false }, over: None, filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x, y desc) from t error_msg: |- sql parser error: expected ), found: , diff --git a/src/tests/sqlsmith/src/sql_gen/agg.rs b/src/tests/sqlsmith/src/sql_gen/agg.rs index 4953235d4cba4..177603ddb333a 100644 --- a/src/tests/sqlsmith/src/sql_gen/agg.rs +++ b/src/tests/sqlsmith/src/sql_gen/agg.rs @@ -18,7 +18,7 @@ use risingwave_common::types::DataType; use risingwave_expr::aggregate::PbAggKind; use risingwave_expr::sig::SigDataType; use risingwave_sqlparser::ast::{ - Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, OrderByExpr, + Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, ObjectName, OrderByExpr, }; use crate::sql_gen::types::AGG_FUNC_TABLE; @@ -142,11 +142,8 @@ fn make_agg_func( Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked(func_name)]), - args, - variadic: false, + arg_list: FunctionArgList::for_agg(distinct, args, order_by), over: None, - distinct, - order_by, filter, within_group: None, } diff --git a/src/tests/sqlsmith/src/sql_gen/functions.rs b/src/tests/sqlsmith/src/sql_gen/functions.rs index cee18a18081ca..8cd1645ec1f5b 100644 --- a/src/tests/sqlsmith/src/sql_gen/functions.rs +++ b/src/tests/sqlsmith/src/sql_gen/functions.rs @@ -18,8 +18,8 @@ use rand::Rng; use risingwave_common::types::DataType; use risingwave_frontend::expr::ExprType; use risingwave_sqlparser::ast::{ - BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, - TrimWhereField, UnaryOperator, Value, + BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgList, Ident, + ObjectName, TrimWhereField, UnaryOperator, Value, }; use crate::sql_gen::types::{FUNC_TABLE, IMPLICIT_CAST_TABLE, INVARIANT_FUNC_SET}; @@ -258,11 +258,8 @@ pub fn make_simple_func(func_name: &str, exprs: &[Expr]) -> Function { Function { scalar_as_agg: false, name: ObjectName(vec![Ident::new_unchecked(func_name)]), - args, - variadic: false, + arg_list: FunctionArgList::args_only(args), over: None, - distinct: false, - order_by: vec![], filter: None, within_group: None, }