From 00321457dafc2c07590d7213ee076f6bf07ff073 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 15 Sep 2023 16:57:25 +0800 Subject: [PATCH] refactor(expr): support variadic function in `#[function]` macro (#12178) Signed-off-by: Runji Wang --- e2e_test/batch/functions/format.slt.part | 23 ++ src/common/src/array/data_chunk.rs | 1 + src/common/src/lib.rs | 1 - src/expr/macro/src/gen.rs | 110 +++++--- src/expr/macro/src/lib.rs | 27 +- src/expr/macro/src/types.rs | 7 +- src/expr/src/error.rs | 3 + src/expr/src/expr/build.rs | 5 - src/expr/src/expr/expr_concat_ws.rs | 250 ------------------ src/expr/src/expr/expr_nested_construct.rs | 153 ----------- src/expr/src/expr/mod.rs | 2 - src/expr/src/sig/func.rs | 22 +- src/expr/src/vector_op/array.rs | 28 ++ src/expr/src/vector_op/concat_ws.rs | 70 +++++ .../src => expr/src/vector_op}/format.rs | 102 ++++++- src/expr/src/vector_op/mod.rs | 3 + .../tests/testdata/input/format.yaml | 8 +- .../tests/testdata/output/format.yaml | 30 +-- src/frontend/src/binder/expr/function.rs | 72 +---- src/frontend/src/expr/type_inference/func.rs | 16 ++ 20 files changed, 363 insertions(+), 570 deletions(-) delete mode 100644 src/expr/src/expr/expr_concat_ws.rs delete mode 100644 src/expr/src/expr/expr_nested_construct.rs create mode 100644 src/expr/src/vector_op/array.rs create mode 100644 src/expr/src/vector_op/concat_ws.rs rename src/{common/src => expr/src/vector_op}/format.rs (50%) diff --git a/e2e_test/batch/functions/format.slt.part b/e2e_test/batch/functions/format.slt.part index 92b4fc1553a65..ab6090737e304 100644 --- a/e2e_test/batch/functions/format.slt.part +++ b/e2e_test/batch/functions/format.slt.part @@ -7,3 +7,26 @@ query T SELECT format('Testing %s, %s, %s, %%', 'one', 'two', 'three'); ---- Testing one, two, three, % + +query T +SELECT format('%s %s', a, b) from (values + ('Hello', 'World'), + ('Rising', 'Wave') +) as t(a, b); +---- +Hello World +Rising Wave + +query T +SELECT format(f, a, b) from (values + ('%s %s', 'Hello', 'World'), + ('%s%s', 'Hello', null), + (null, 'Hello', 'World') +) as t(f, a, b); +---- +Hello World +Hello +NULL + +query error too few arguments for format() +SELECT format('%s %s', 'Hello'); diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index f335b56a60edb..cc4bef12cccff 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -779,6 +779,7 @@ impl DataChunkTestExt for DataChunk { "." => None, "t" => Some(true.into()), "f" => Some(false.into()), + "(empty)" => Some("".into()), _ => Some(ScalarImpl::from_text(val_str.as_bytes(), ty).unwrap()), }; builder.append(datum); diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 5f7e9fd476da5..da58e53b8c52d 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -72,7 +72,6 @@ pub mod system_param; pub mod telemetry; pub mod transaction; -pub mod format; pub mod metrics; pub mod test_utils; pub mod types; diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 9e8b4d9d3c1b2..56464f2bf2809 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -59,10 +59,14 @@ impl FunctionAttr { return self.generate_table_function_descriptor(user_fn, build_fn); } let name = self.name.clone(); - let mut args = Vec::with_capacity(self.args.len()); - for ty in &self.args { - args.push(data_type_name(ty)); + let variadic = matches!(self.args.last(), Some(t) if t == "..."); + let args = match variadic { + true => &self.args[..self.args.len() - 1], + false => &self.args[..], } + .iter() + .map(|ty| data_type_name(ty)) + .collect_vec(); let ret = data_type_name(&self.ret); let pb_type = format_ident!("{}", utils::to_camel_case(&name)); @@ -82,6 +86,7 @@ impl FunctionAttr { unsafe { crate::sig::func::_register(#descriptor_type { func: risingwave_pb::expr::expr_node::Type::#pb_type, inputs_type: &[#(#args),*], + variadic: #variadic, ret_type: #ret, build: #build_fn, deprecated: #deprecated, @@ -99,7 +104,8 @@ impl FunctionAttr { user_fn: &UserFunctionAttr, optimize_const: bool, ) -> Result { - let num_args = self.args.len(); + let variadic = matches!(self.args.last(), Some(t) if t == "..."); + let num_args = self.args.len() - if variadic { 1 } else { 0 }; let fn_name = format_ident!("{}", user_fn.name); let struct_name = match optimize_const { true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())), @@ -148,8 +154,6 @@ impl FunctionAttr { let inputs = idents("i", &children_indices); let prebuilt_inputs = idents("i", &prebuilt_indices); let non_prebuilt_inputs = idents("i", &non_prebuilt_indices); - let all_child = idents("child", &(0..num_args).collect_vec()); - let child = idents("child", &children_indices); let array_refs = idents("array", &children_indices); let arrays = idents("a", &children_indices); let datums = idents("v", &children_indices); @@ -210,15 +214,41 @@ impl FunctionAttr { } else { quote! { () } }; - let generic = if self.ret == "boolean" && user_fn.generic == 3 { + + // ensure the number of children matches the number of arguments + let check_children = match variadic { + true => quote! { crate::ensure!(children.len() >= #num_args); }, + false => quote! { crate::ensure!(children.len() == #num_args); }, + }; + + // evaluate variadic arguments in `eval` + let eval_variadic = variadic.then(|| { + quote! { + let mut columns = Vec::with_capacity(self.children.len() - #num_args); + for child in &self.children[#num_args..] { + columns.push(child.eval_checked(input).await?); + } + let variadic_input = DataChunk::new(columns, input.vis().clone()); + } + }); + // evaluate variadic arguments in `eval_row` + let eval_row_variadic = variadic.then(|| { + quote! { + let mut row = Vec::with_capacity(self.children.len() - #num_args); + for child in &self.children[#num_args..] { + row.push(child.eval_row(input).await?); + } + let variadic_row = OwnedRow::new(row); + } + }); + + let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| { // XXX: for generic compare functions, we need to specify the compatible type let compatible_type = types::ref_type(types::min_compatible_type(&self.args)) .parse::() .unwrap(); quote! { ::<_, _, #compatible_type> } - } else { - quote! {} - }; + }); let prebuilt_arg = match (&self.prebuild, optimize_const) { // use the prebuilt argument (Some(_), true) => quote! { &self.prebuilt_arg, }, @@ -227,18 +257,21 @@ impl FunctionAttr { // no prebuilt argument (None, _) => quote! {}, }; - let context = match user_fn.context { - true => quote! { &self.context, }, - false => quote! {}, - }; - let writer = match user_fn.write { - true => quote! { &mut writer, }, - false => quote! {}, - }; + let variadic_args = variadic.then(|| quote! { variadic_row, }); + let context = user_fn.context.then(|| quote! { &self.context, }); + let writer = user_fn.write.then(|| quote! { &mut writer, }); let await_ = user_fn.async_.then(|| quote! { .await }); // call the user defined function // inputs: [ Option ] - let mut output = quote! { #fn_name #generic(#(#non_prebuilt_inputs,)* #prebuilt_arg #context #writer) #await_ }; + let mut output = quote! { #fn_name #generic( + #(#non_prebuilt_inputs,)* + #prebuilt_arg + #variadic_args + #context + #writer + ) #await_ }; + // handle error if the function returns `Result` + // wrap a `Some` if the function doesn't return `Option` output = match user_fn.return_type_kind { // XXX: we don't support void type yet. return null::int for now. _ if self.ret == "void" => quote! { { #output; Option::::None } }, @@ -257,7 +290,7 @@ impl FunctionAttr { } }; }; - // output: Option + // now the `output` is: Option let append_output = match user_fn.write { true => quote! {{ let mut writer = builder.writer().begin(); @@ -292,13 +325,20 @@ impl FunctionAttr { }; // the main body in `eval` let eval = if let Some(batch_fn) = &self.batch_fn { + assert!( + !variadic, + "customized batch function is not supported for variadic functions" + ); // user defined batch function let fn_name = format_ident!("{}", batch_fn); quote! { let c = #fn_name(#(#arrays),*); Ok(Arc::new(c.into())) } - } else if (types::is_primitive(&self.ret) || self.ret == "boolean") && user_fn.is_pure() { + } else if (types::is_primitive(&self.ret) || self.ret == "boolean") + && user_fn.is_pure() + && !variadic + { // SIMD optimization for primitive types match self.args.len() { 0 => quote! { @@ -330,10 +370,15 @@ impl FunctionAttr { } } else { // no optimization - let array_zip = match num_args { + let array_zip = match children_indices.len() { 0 => quote! { std::iter::repeat(()).take(input.capacity()) }, _ => quote! { multizip((#(#arrays.iter(),)*)) }, }; + let let_variadic = variadic.then(|| { + quote! { + let variadic_row = variadic_input.row_at_unchecked_vis(i); + } + }); quote! { let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone()); @@ -341,16 +386,18 @@ impl FunctionAttr { Vis::Bitmap(vis) => { // allow using `zip` for performance #[allow(clippy::disallowed_methods)] - for ((#(#inputs,)*), visible) in #array_zip.zip(vis.iter()) { + for (i, ((#(#inputs,)*), visible)) in #array_zip.zip(vis.iter()).enumerate() { if !visible { builder.append_null(); continue; } + #let_variadic #append_output } } Vis::Compact(_) => { - for (#(#inputs,)*) in #array_zip { + for (i, (#(#inputs,)*)) in #array_zip.enumerate() { + #let_variadic #append_output } } @@ -374,19 +421,17 @@ impl FunctionAttr { use crate::expr::{Context, BoxedExpression}; use crate::Result; - crate::ensure!(children.len() == #num_args); + #check_children let prebuilt_arg = #prebuild_const; let context = Context { return_type, arg_types: children.iter().map(|c| c.return_type()).collect(), }; - let mut iter = children.into_iter(); - #(let #all_child = iter.next().unwrap();)* #[derive(Debug)] struct #struct_name { context: Context, - #(#child: BoxedExpression,)* + children: Vec, prebuilt_arg: #prebuilt_arg_type, } #[async_trait::async_trait] @@ -395,25 +440,26 @@ impl FunctionAttr { self.context.return_type.clone() } async fn eval(&self, input: &DataChunk) -> Result { - // evaluate children and downcast arrays #( - let #array_refs = self.#child.eval_checked(input).await?; + let #array_refs = self.children[#children_indices].eval_checked(input).await?; let #arrays: &#arg_arrays = #array_refs.as_ref().into(); )* + #eval_variadic #eval } async fn eval_row(&self, input: &OwnedRow) -> Result { #( - let #datums = self.#child.eval_row(input).await?; + let #datums = self.children[#children_indices].eval_row(input).await?; let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap()); )* + #eval_row_variadic Ok(#row_output) } } Ok(Box::new(#struct_name { context, - #(#child,)* + children, prebuilt_arg, })) } diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 4d8c48ca9ccac..c6ebffdff660f 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -38,6 +38,7 @@ mod utils; /// - [Rust Function Signature](#rust-function-signature) /// - [Nullable Arguments](#nullable-arguments) /// - [Return Value](#return-value) +/// - [Variadic Function](#variadic-function) /// - [Optimization](#optimization) /// - [Functions Returning Strings](#functions-returning-strings) /// - [Preprocessing Constant Arguments](#preprocessing-constant-arguments) @@ -62,13 +63,15 @@ mod utils; /// invocation. The signature follows this pattern: /// /// ```text -/// name ( [arg_types],* ) [ -> [setof] return_type ] +/// name ( [arg_types],* [...] ) [ -> [setof] return_type ] /// ``` /// -/// Where `name` is the function name, which must match the function name defined in `prost`. +/// Where `name` is the function name in `snake_case`, which must match the function name defined +/// in `prost`. /// -/// The allowed data types are listed in the `name` column of the appendix's [type matrix]. -/// Wildcards or `auto` can also be used, as explained below. +/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed in +/// in the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as +/// explained below. If the function is variadic, the last argument can be denoted as `...`. /// /// When `setof` appears before the return type, this indicates that the function is a set-returning /// function (table function), meaning it can return multiple values instead of just one. For more @@ -203,6 +206,21 @@ mod utils; /// /// Therefore, try to avoid returning `Option` and `Result` whenever possible. /// +/// ## Variadic Function +/// +/// Variadic functions accept a `impl Row` input to represent tailing arguments. +/// For example: +/// +/// ```ignore +/// #[function("concat_ws(varchar, ...) -> varchar")] +/// fn concat_ws(sep: &str, vals: impl Row) -> Option> { +/// let mut string_iter = vals.iter().flatten(); +/// // ... +/// } +/// ``` +/// +/// See `risingwave_common::row::Row` for more details. +/// /// ## Functions Returning Strings /// /// For functions that return varchar types, you can also use the writer style function signature to @@ -569,6 +587,7 @@ impl FunctionAttr { fn ident_name(&self) -> String { format!("{}_{}_{}", self.name, self.args.join("_"), self.ret) .replace("[]", "list") + .replace("...", "variadic") .replace(['<', '>', ' ', ','], "_") .replace("__", "_") } diff --git a/src/expr/macro/src/types.rs b/src/expr/macro/src/types.rs index 53f224b79a773..f0868697757f7 100644 --- a/src/expr/macro/src/types.rs +++ b/src/expr/macro/src/types.rs @@ -40,9 +40,10 @@ const TYPE_MATRIX: &str = " /// Maps a data type to its corresponding data type name. pub fn data_type(ty: &str) -> &str { // XXX: - // For functions that contain `any` type, there are special handlings in the frontend, - // and the signature won't be accessed. So we simply return a placeholder here. - if ty == "any" { + // For functions that contain `any` type, or `...` variable arguments, + // there are special handlings in the frontend, and the signature won't be accessed. + // So we simply return a placeholder here. + if ty == "any" || ty == "..." { return "Int32"; } lookup_matrix(ty, 1) diff --git a/src/expr/src/error.rs b/src/expr/src/error.rs index d1ae1eb35ad0c..1128c05e76a77 100644 --- a/src/expr/src/error.rs +++ b/src/expr/src/error.rs @@ -77,6 +77,9 @@ pub enum ExprError { #[error("field name must not be null")] FieldNameNull, + #[error("too few arguments for format()")] + TooFewArguments, + #[error("invalid state: {0}")] InvalidState(String), } diff --git a/src/expr/src/expr/build.rs b/src/expr/src/expr/build.rs index 1f34adead8855..b2b6db6eafc3c 100644 --- a/src/expr/src/expr/build.rs +++ b/src/expr/src/expr/build.rs @@ -22,10 +22,8 @@ use risingwave_pb::expr::ExprNode; use super::expr_array_transform::ArrayTransformExpression; use super::expr_case::CaseExpression; use super::expr_coalesce::CoalesceExpression; -use super::expr_concat_ws::ConcatWsExpression; use super::expr_field::FieldExpression; use super::expr_in::InExpression; -use super::expr_nested_construct::NestedConstructExpression; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; @@ -56,10 +54,7 @@ pub fn build_from_prost(prost: &ExprNode) -> Result { E::In => InExpression::try_from_boxed(prost), E::Case => CaseExpression::try_from_boxed(prost), E::Coalesce => CoalesceExpression::try_from_boxed(prost), - E::ConcatWs => ConcatWsExpression::try_from_boxed(prost), E::Field => FieldExpression::try_from_boxed(prost), - E::Array => NestedConstructExpression::try_from_boxed(prost), - E::Row => NestedConstructExpression::try_from_boxed(prost), E::Vnode => VnodeExpression::try_from_boxed(prost), _ => { diff --git a/src/expr/src/expr/expr_concat_ws.rs b/src/expr/src/expr/expr_concat_ws.rs deleted file mode 100644 index 5bca7d0aea75c..0000000000000 --- a/src/expr/src/expr/expr_concat_ws.rs +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::convert::TryFrom; -use std::fmt::Write; -use std::sync::Arc; - -use risingwave_common::array::{ - Array, ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, Utf8ArrayBuilder, -}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; -use crate::{bail, ensure, ExprError, Result}; - -#[derive(Debug)] -pub struct ConcatWsExpression { - return_type: DataType, - sep_expr: BoxedExpression, - string_exprs: Vec, -} - -#[async_trait::async_trait] -impl Expression for ConcatWsExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let sep_column = self.sep_expr.eval_checked(input).await?; - let sep_column = sep_column.as_utf8(); - - let mut string_columns = Vec::with_capacity(self.string_exprs.len()); - for expr in &self.string_exprs { - string_columns.push(expr.eval_checked(input).await?); - } - let string_columns_ref = string_columns - .iter() - .map(|c| c.as_utf8()) - .collect::>(); - - let row_len = input.capacity(); - let vis = input.vis(); - let mut builder = Utf8ArrayBuilder::new(row_len); - - for row_idx in 0..row_len { - if !vis.is_set(row_idx) { - builder.append(None); - continue; - } - let sep = match sep_column.value_at(row_idx) { - Some(sep) => sep, - None => { - builder.append(None); - continue; - } - }; - - let mut writer = builder.writer().begin(); - - let mut string_columns = string_columns_ref.iter(); - for string_column in string_columns.by_ref() { - if let Some(string) = string_column.value_at(row_idx) { - writer.write_str(string).unwrap(); - break; - } - } - - for string_column in string_columns { - if let Some(string) = string_column.value_at(row_idx) { - writer.write_str(sep).unwrap(); - writer.write_str(string).unwrap(); - } - } - - writer.finish(); - } - Ok(Arc::new(ArrayImpl::from(builder.finish()))) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let sep = self.sep_expr.eval_row(input).await?; - let sep = match sep { - Some(sep) => sep, - None => return Ok(None), - }; - - let mut strings = Vec::with_capacity(self.string_exprs.len()); - for expr in &self.string_exprs { - strings.push(expr.eval_row(input).await?); - } - let mut final_string = String::new(); - - let mut strings_iter = strings.iter(); - if let Some(string) = strings_iter.by_ref().flatten().next() { - final_string.push_str(string.as_utf8()) - } - - for string in strings_iter.flatten() { - final_string.push_str(sep.as_utf8()); - final_string.push_str(string.as_utf8()); - } - - Ok(Some(final_string.into())) - } -} - -impl ConcatWsExpression { - pub fn new( - return_type: DataType, - sep_expr: BoxedExpression, - string_exprs: Vec, - ) -> Self { - ConcatWsExpression { - return_type, - sep_expr, - string_exprs, - } - } -} - -impl<'a> TryFrom<&'a ExprNode> for ConcatWsExpression { - type Error = ExprError; - - fn try_from(prost: &'a ExprNode) -> Result { - ensure!(prost.get_function_type().unwrap() == Type::ConcatWs); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - - let children = &func_call_node.children; - let sep_expr = expr_build_from_prost(&children[0])?; - - let string_exprs = children[1..] - .iter() - .map(expr_build_from_prost) - .collect::>>()?; - Ok(ConcatWsExpression::new(ret_type, sep_expr, string_exprs)) - } -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use risingwave_common::array::{DataChunk, DataChunkTestExt}; - use risingwave_common::row::OwnedRow; - use risingwave_common::types::Datum; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::expr_node::RexNode; - use risingwave_pb::expr::expr_node::Type::ConcatWs; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use crate::expr::expr_concat_ws::ConcatWsExpression; - use crate::expr::test_utils::make_input_ref; - use crate::expr::Expression; - - pub fn make_concat_ws_function(children: Vec, ret: TypeName) -> ExprNode { - ExprNode { - function_type: ConcatWs as i32, - return_type: Some(PbDataType { - type_name: ret as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { children })), - } - } - - #[tokio::test] - async fn test_eval_concat_ws_expr() { - let input_node1 = make_input_ref(0, TypeName::Varchar); - let input_node2 = make_input_ref(1, TypeName::Varchar); - let input_node3 = make_input_ref(2, TypeName::Varchar); - let input_node4 = make_input_ref(3, TypeName::Varchar); - let concat_ws_expr = ConcatWsExpression::try_from(&make_concat_ws_function( - vec![input_node1, input_node2, input_node3, input_node4], - TypeName::Varchar, - )) - .unwrap(); - - let chunk = DataChunk::from_pretty( - " - T T T T - , a b c - . a b c - , . b c - , . . . - . . . .", - ); - - let actual = concat_ws_expr.eval(&chunk).await.unwrap(); - let actual = actual - .iter() - .map(|r| r.map(|s| s.into_utf8())) - .collect_vec(); - - let expected = vec![Some("a,b,c"), None, Some("b,c"), Some(""), None]; - - assert_eq!(actual, expected); - } - - #[tokio::test] - async fn test_eval_row_concat_ws_expr() { - let input_node1 = make_input_ref(0, TypeName::Varchar); - let input_node2 = make_input_ref(1, TypeName::Varchar); - let input_node3 = make_input_ref(2, TypeName::Varchar); - let input_node4 = make_input_ref(3, TypeName::Varchar); - let concat_ws_expr = ConcatWsExpression::try_from(&make_concat_ws_function( - vec![input_node1, input_node2, input_node3, input_node4], - TypeName::Varchar, - )) - .unwrap(); - - let row_inputs = vec![ - vec![Some(","), Some("a"), Some("b"), Some("c")], - vec![None, Some("a"), Some("b"), Some("c")], - vec![Some(","), None, Some("b"), Some("c")], - vec![Some(","), None, None, None], - vec![None, None, None, None], - ]; - - let expected = [Some("a,b,c"), None, Some("b,c"), Some(""), None]; - - for (i, row_input) in row_inputs.iter().enumerate() { - let datum_vec: Vec = row_input.iter().map(|e| e.map(|s| s.into())).collect(); - let row = OwnedRow::new(datum_vec); - - let result = concat_ws_expr.eval_row(&row).await.unwrap(); - let expected = expected[i].map(|s| s.into()); - - assert_eq!(result, expected); - } - } -} diff --git a/src/expr/src/expr/expr_nested_construct.rs b/src/expr/src/expr/expr_nested_construct.rs deleted file mode 100644 index ece26ed138258..0000000000000 --- a/src/expr/src/expr/expr_nested_construct.rs +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::convert::TryFrom; -use std::sync::Arc; - -use risingwave_common::array::{ - ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, ListArrayBuilder, ListValue, StructArray, - StructValue, -}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum, Scalar}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use crate::expr::{build_from_prost as expr_build_from_prost, BoxedExpression, Expression}; -use crate::{bail, ensure, ExprError, Result}; - -#[derive(Debug)] -pub struct NestedConstructExpression { - data_type: DataType, - elements: Vec, -} - -#[async_trait::async_trait] -impl Expression for NestedConstructExpression { - fn return_type(&self) -> DataType { - self.data_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let mut columns = Vec::with_capacity(self.elements.len()); - for e in &self.elements { - columns.push(e.eval_checked(input).await?); - } - - if let DataType::Struct(ty) = &self.data_type { - let array = StructArray::new(ty.clone(), columns, input.vis().to_bitmap()); - Ok(Arc::new(ArrayImpl::Struct(array))) - } else if let DataType::List { .. } = &self.data_type { - let chunk = DataChunk::new(columns, input.vis().clone()); - let mut builder = ListArrayBuilder::with_type(input.capacity(), self.data_type.clone()); - for row in chunk.rows_with_holes() { - if let Some(row) = row { - builder.append_row_ref(row); - } else { - builder.append_null(); - } - } - Ok(Arc::new(ArrayImpl::List(builder.finish()))) - } else { - Err(ExprError::UnsupportedFunction( - "expects struct or list type".to_string(), - )) - } - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let mut datums = Vec::with_capacity(self.elements.len()); - for e in &self.elements { - datums.push(e.eval_row(input).await?); - } - if let DataType::Struct { .. } = &self.data_type { - Ok(Some(StructValue::new(datums).to_scalar_value())) - } else if let DataType::List(_) = &self.data_type { - Ok(Some(ListValue::new(datums).to_scalar_value())) - } else { - Err(ExprError::UnsupportedFunction( - "expects struct or list type".to_string(), - )) - } - } -} - -impl NestedConstructExpression { - pub fn new(data_type: DataType, elements: Vec) -> Self { - NestedConstructExpression { - data_type, - elements, - } - } -} - -impl<'a> TryFrom<&'a ExprNode> for NestedConstructExpression { - type Error = ExprError; - - fn try_from(prost: &'a ExprNode) -> Result { - ensure!([Type::Array, Type::Row].contains(&prost.get_function_type().unwrap())); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - let elements = func_call_node - .children - .iter() - .map(expr_build_from_prost) - .collect::>>()?; - Ok(NestedConstructExpression::new(ret_type, elements)) - } -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::{DataChunk, ListValue}; - use risingwave_common::row::OwnedRow; - use risingwave_common::types::{DataType, Scalar, ScalarImpl}; - - use super::NestedConstructExpression; - use crate::expr::{BoxedExpression, Expression, LiteralExpression}; - - #[tokio::test] - async fn test_eval_array_expr() { - let expr = NestedConstructExpression { - data_type: DataType::List(DataType::Int32.into()), - elements: vec![i32_expr(1.into()), i32_expr(2.into())], - }; - - let arr = expr.eval(&DataChunk::new_dummy(2)).await.unwrap(); - assert_eq!(arr.len(), 2); - } - - #[tokio::test] - async fn test_eval_row_array_expr() { - let expr = NestedConstructExpression { - data_type: DataType::List(DataType::Int32.into()), - elements: vec![i32_expr(1.into()), i32_expr(2.into())], - }; - - let scalar_impl = expr - .eval_row(&OwnedRow::new(vec![])) - .await - .unwrap() - .unwrap(); - let expected = ListValue::new(vec![Some(1.into()), Some(2.into())]).to_scalar_value(); - assert_eq!(expected, scalar_impl); - } - - fn i32_expr(v: ScalarImpl) -> BoxedExpression { - Box::new(LiteralExpression::new(DataType::Int32, Some(v))) - } -} diff --git a/src/expr/src/expr/mod.rs b/src/expr/src/expr/mod.rs index 8a424f4908448..33509506753fa 100644 --- a/src/expr/src/expr/mod.rs +++ b/src/expr/src/expr/mod.rs @@ -37,12 +37,10 @@ mod expr_binary_nonnull; mod expr_binary_nullable; mod expr_case; mod expr_coalesce; -mod expr_concat_ws; mod expr_field; mod expr_in; mod expr_input_ref; mod expr_literal; -mod expr_nested_construct; mod expr_some_all; pub(crate) mod expr_udf; mod expr_unary; diff --git a/src/expr/src/sig/func.rs b/src/expr/src/sig/func.rs index 6e665ead3fc40..e8e0d19ec11d7 100644 --- a/src/expr/src/sig/func.rs +++ b/src/expr/src/sig/func.rs @@ -40,30 +40,30 @@ pub fn func_sigs() -> impl Iterator { } #[derive(Default, Clone, Debug)] -pub struct FuncSigMap(HashMap<(PbType, usize), Vec>); +pub struct FuncSigMap(HashMap>); impl FuncSigMap { /// Inserts a function signature. pub fn insert(&mut self, desc: FuncSign) { - self.0 - .entry((desc.func, desc.inputs_type.len())) - .or_default() - .push(desc) + self.0.entry(desc.func).or_default().push(desc) } /// Returns a function signature with the same type, argument types and return type. /// Deprecated functions are included. pub fn get(&self, ty: PbType, args: &[DataTypeName], ret: DataTypeName) -> Option<&FuncSign> { - let v = self.0.get(&(ty, args.len()))?; + let v = self.0.get(&ty)?; v.iter() - .find(|d| d.inputs_type == args && d.ret_type == ret) + .find(|d| (d.variadic || d.inputs_type == args) && d.ret_type == ret) } /// Returns all function signatures with the same type and number of arguments. /// Deprecated functions are excluded. pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> Vec<&FuncSign> { - match self.0.get(&(ty, nargs)) { - Some(v) => v.iter().filter(|d| !d.deprecated).collect(), + match self.0.get(&ty) { + Some(v) => v + .iter() + .filter(|d| (d.variadic || d.inputs_type.len() == nargs) && !d.deprecated) + .collect(), None => vec![], } } @@ -74,6 +74,7 @@ impl FuncSigMap { pub struct FuncSign { pub func: PbType, pub inputs_type: &'static [DataTypeName], + pub variadic: bool, pub ret_type: DataTypeName, pub build: fn(return_type: DataType, children: Vec) -> Result, /// Whether the function is deprecated and should not be used in the frontend. @@ -128,11 +129,10 @@ mod tests { // convert FUNC_SIG_MAP to a more convenient map for testing let mut new_map: BTreeMap, Vec>> = BTreeMap::new(); - for ((func, num_args), sigs) in &FUNC_SIG_MAP.0 { + for (func, sigs) in &FUNC_SIG_MAP.0 { for sig in sigs { // validate the FUNC_SIG_MAP is consistent assert_eq!(func, &sig.func); - assert_eq!(num_args, &sig.inputs_type.len()); // exclude deprecated functions if sig.deprecated { continue; diff --git a/src/expr/src/vector_op/array.rs b/src/expr/src/vector_op/array.rs new file mode 100644 index 0000000000000..26f1fa2492064 --- /dev/null +++ b/src/expr/src/vector_op/array.rs @@ -0,0 +1,28 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::{ListValue, StructValue}; +use risingwave_common::row::Row; +use risingwave_common::types::ToOwnedDatum; +use risingwave_expr_macro::function; + +#[function("array(...) -> list")] +fn array(row: impl Row) -> ListValue { + ListValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) +} + +#[function("row(...) -> struct")] +fn row_(row: impl Row) -> StructValue { + StructValue::new(row.iter().map(|d| d.to_owned_datum()).collect()) +} diff --git a/src/expr/src/vector_op/concat_ws.rs b/src/expr/src/vector_op/concat_ws.rs new file mode 100644 index 0000000000000..293adb3079c0d --- /dev/null +++ b/src/expr/src/vector_op/concat_ws.rs @@ -0,0 +1,70 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Write; + +use risingwave_common::row::Row; +use risingwave_common::types::ToText; +use risingwave_expr_macro::function; + +/// Concatenates all but the first argument, with separators. The first argument is used as the +/// separator string, and should not be NULL. Other NULL arguments are ignored. +#[function("concat_ws(varchar, ...) -> varchar")] +fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { + let mut string_iter = vals.iter().flatten(); + if let Some(string) = string_iter.next() { + string.write(writer).unwrap(); + } + for string in string_iter { + write!(writer, "{}", sep).unwrap(); + string.write(writer).unwrap(); + } + Some(()) +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + + use crate::expr::build_from_pretty; + + #[tokio::test] + async fn test_concat_ws() { + let concat_ws = + build_from_pretty("(concat_ws:varchar $0:varchar $1:varchar $2:varchar $3:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T T T T T + , a b c a,b,c + , . b c b,c + . a b c . + , . . . (empty) + . . . . .", + ) + .split_column_at(4); + + // test eval + let output = concat_ws.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = concat_ws.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/common/src/format.rs b/src/expr/src/vector_op/format.rs similarity index 50% rename from src/common/src/format.rs rename to src/expr/src/vector_op/format.rs index 4bd5e8c905a4d..081a1e2ef1fb6 100644 --- a/src/common/src/format.rs +++ b/src/expr/src/vector_op/format.rs @@ -12,7 +12,53 @@ // See the License for the specific language governing permissions and // limitations under the License. -use thiserror::Error; +use std::fmt::Write; +use std::str::FromStr; + +use risingwave_common::row::Row; +use risingwave_common::types::{ScalarRefImpl, ToText}; +use risingwave_expr_macro::function; + +use super::string::quote_ident; +use crate::{ExprError, Result}; + +/// Formats arguments according to a format string. +#[function( + "format(varchar, ...) -> varchar", + prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_string().into()))?" +)] +fn format(formatter: &Formatter, row: impl Row, writer: &mut impl Write) -> Result<()> { + let mut args = row.iter(); + for node in &formatter.nodes { + match node { + FormatterNode::Literal(literal) => writer.write_str(literal).unwrap(), + FormatterNode::Specifier(sp) => { + let arg = args.next().ok_or(ExprError::TooFewArguments)?; + match sp.ty { + SpecifierType::SimpleString => { + if let Some(scalar) = arg { + scalar.write(writer).unwrap(); + } + } + SpecifierType::SqlIdentifier => match arg { + Some(ScalarRefImpl::Utf8(arg)) => quote_ident(arg, writer), + _ => { + return Err(ExprError::UnsupportedFunction( + "unsupported data for specifier type 'I'".to_string(), + )) + } + }, + SpecifierType::SqlLiteral => { + return Err(ExprError::UnsupportedFunction( + "unsupported specifier type 'L'".to_string(), + )) + } + } + } + } + } + Ok(()) +} /// The type of format conversion to use to produce the format specifier's output. #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -31,7 +77,7 @@ pub enum SpecifierType { impl TryFrom for SpecifierType { type Error = (); - fn try_from(c: char) -> Result { + fn try_from(c: char) -> std::result::Result { match c { 's' => Ok(SpecifierType::SimpleString), 'I' => Ok(SpecifierType::SqlIdentifier), @@ -42,34 +88,36 @@ impl TryFrom for SpecifierType { } #[derive(Debug)] -pub struct Specifier { +struct Specifier { // TODO: support position, flags and width. - pub ty: SpecifierType, + ty: SpecifierType, } #[derive(Debug)] -pub enum FormatterNode { +enum FormatterNode { Specifier(Specifier), Literal(String), } #[derive(Debug)] -pub struct Formatter { +struct Formatter { nodes: Vec, } -#[derive(Debug, Error)] -pub enum ParseFormatError { +#[derive(Debug, thiserror::Error)] +enum ParseFormatError { #[error("unrecognized format() type specifier \"{0}\"")] UnrecognizedSpecifierType(char), #[error("unterminated format() type specifier")] UnterminatedSpecifier, } -impl Formatter { +impl FromStr for Formatter { + type Err = ParseFormatError; + /// Parse the format string into a high-efficient representation. /// - pub fn parse(format: &str) -> Result { + fn from_str(format: &str) -> std::result::Result { // 8 is a good magic number here, it can cover an input like 'Testing %s, %s, %s, %%'. let mut nodes = Vec::with_capacity(8); let mut after_percent = false; @@ -106,8 +154,38 @@ impl Formatter { Ok(Formatter { nodes }) } +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; - pub fn nodes(&self) -> &[FormatterNode] { - &self.nodes + use crate::expr::build_from_pretty; + + #[tokio::test] + async fn test_format() { + let format = build_from_pretty("(format:varchar $0:varchar $1:varchar $2:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T T T T + Hello%s World . HelloWorld + %s%s Hello World HelloWorld + %I && . \"&&\" + . a b .", + ) + .split_column_at(3); + + // test eval + let output = format.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = format.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } } } diff --git a/src/expr/src/vector_op/mod.rs b/src/expr/src/vector_op/mod.rs index 7125ca33ec3de..0027abfa65542 100644 --- a/src/expr/src/vector_op/mod.rs +++ b/src/expr/src/vector_op/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod arithmetic_op; +pub mod array; pub mod array_access; pub mod array_concat; pub mod array_distinct; @@ -31,12 +32,14 @@ pub mod cardinality; pub mod cast; pub mod cmp; pub mod concat_op; +pub mod concat_ws; pub mod conjunction; pub mod date_trunc; pub mod delay; pub mod encdec; pub mod exp; pub mod extract; +pub mod format; pub mod format_type; pub mod int256; pub mod jsonb_access; diff --git a/src/frontend/planner_test/tests/testdata/input/format.yaml b/src/frontend/planner_test/tests/testdata/input/format.yaml index 010d188766555..dd0df4dcd02df 100644 --- a/src/frontend/planner_test/tests/testdata/input/format.yaml +++ b/src/frontend/planner_test/tests/testdata/input/format.yaml @@ -11,19 +11,19 @@ CREATE TABLE t1(v1 varchar, v2 int, v3 int); SELECT format('Testing %s, %I, %s, %%', v1, v2, v3) FROM t1; expected_outputs: - - binder_error + - batch_plan - sql: | SELECT format('Testing %s, %s, %s, %%', 'one', 'two'); expected_outputs: - - binder_error + - batch_error - sql: | SELECT format('Testing %s, %s, %s, %', 'one', 'two', 'three'); expected_outputs: - - binder_error + - batch_error - sql: | SELECT format('Testing %s, %f, %d, %', 'one', 'two', 'three'); expected_outputs: - - binder_error + - batch_error - sql: | SELECT format(); expected_outputs: diff --git a/src/frontend/planner_test/tests/testdata/output/format.yaml b/src/frontend/planner_test/tests/testdata/output/format.yaml index 06ace2257308c..1c42a0f9252ac 100644 --- a/src/frontend/planner_test/tests/testdata/output/format.yaml +++ b/src/frontend/planner_test/tests/testdata/output/format.yaml @@ -7,38 +7,24 @@ SELECT format('Testing %s, %I, %s, %%', v1, v2, v3) FROM t1; batch_plan: |- BatchExchange { order: [], dist: Single } - └─BatchProject { exprs: [ConcatWs('':Varchar, 'Testing ':Varchar, t1.v1, ', ':Varchar, QuoteIdent(t1.v2), ', ':Varchar, t1.v3::Varchar, ', %':Varchar) as $expr1] } + └─BatchProject { exprs: [Format('Testing %s, %I, %s, %%':Varchar, t1.v1, t1.v2, t1.v3::Varchar) as $expr1] } └─BatchScan { table: t1, columns: [t1.v1, t1.v2, t1.v3], distribution: SomeShard } - sql: | CREATE TABLE t1(v1 varchar, v2 int, v3 int); SELECT format('Testing %s, %I, %s, %%', v1, v2, v3) FROM t1; - binder_error: |- - Bind error: failed to bind expression: format('Testing %s, %I, %s, %%', v1, v2, v3) - - Caused by: - Feature is not yet implemented: QuoteIdent[Int32] - Tracking issue: https://github.com/risingwavelabs/risingwave/issues/112 + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [Format('Testing %s, %I, %s, %%':Varchar, t1.v1, t1.v2::Varchar, t1.v3::Varchar) as $expr1] } + └─BatchScan { table: t1, columns: [t1.v1, t1.v2, t1.v3], distribution: SomeShard } - sql: | SELECT format('Testing %s, %s, %s, %%', 'one', 'two'); - binder_error: |- - Bind error: failed to bind expression: format('Testing %s, %s, %s, %%', 'one', 'two') - - Caused by: - Bind error: Function `format` required 3 arguments based on the `formatstr`, but 2 found. + batch_error: 'Expr error: too few arguments for format()' - sql: | SELECT format('Testing %s, %s, %s, %', 'one', 'two', 'three'); - binder_error: |- - Bind error: failed to bind expression: format('Testing %s, %s, %s, %', 'one', 'two', 'three') - - Caused by: - Bind error: unterminated format() type specifier + batch_error: 'Expr error: Parse error: unterminated format() type specifier' - sql: | SELECT format('Testing %s, %f, %d, %', 'one', 'two', 'three'); - binder_error: |- - Bind error: failed to bind expression: format('Testing %s, %f, %d, %', 'one', 'two', 'three') - - Caused by: - Bind error: unrecognized format() type specifier "f" + batch_error: 'Expr error: Parse error: unrecognized format() type specifier "f"' - sql: | SELECT format(); binder_error: |- diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c3bdf1febcccb..44c79a1620c79 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -22,7 +22,6 @@ use itertools::Itertools; use risingwave_common::array::ListValue; use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME; use risingwave_common::error::{ErrorCode, Result, RwError}; -use risingwave_common::format::{Formatter, FormatterNode, SpecifierType}; use risingwave_common::session_config::USER_NAME_WILD_CARD; use risingwave_common::types::{DataType, ScalarImpl, Timestamptz}; use risingwave_common::{GIT_SHA, RW_VERSION}; @@ -755,7 +754,7 @@ impl Binder { rewrite(ExprType::ConcatWs, Binder::rewrite_concat_to_concat_ws), ), ("concat_ws", raw_call(ExprType::ConcatWs)), - ("format", rewrite(ExprType::ConcatWs, Binder::rewrite_format_to_concat_ws)), + ("format", raw_call(ExprType::Format)), ("translate", raw_call(ExprType::Translate)), ("split_part", raw_call(ExprType::SplitPart)), ("char_length", raw_call(ExprType::CharLength)), @@ -1204,75 +1203,6 @@ impl Binder { } } - fn rewrite_format_to_concat_ws(inputs: Vec) -> Result> { - let Some((format_expr, args)) = inputs.split_first() else { - return Err(ErrorCode::BindError( - "Function `format` takes at least 1 arguments (0 given)".to_string(), - ) - .into()); - }; - let ExprImpl::Literal(expr_literal) = format_expr else { - return Err(ErrorCode::BindError( - "Function `format` takes a literal string as the first argument".to_string(), - ) - .into()); - }; - let Some(ScalarImpl::Utf8(format_str)) = expr_literal.get_data() else { - return Err(ErrorCode::BindError( - "Function `format` takes a literal string as the first argument".to_string(), - ) - .into()); - }; - let formatter = Formatter::parse(format_str) - .map_err(|err| -> RwError { ErrorCode::BindError(err.to_string()).into() })?; - - let specifier_count = formatter - .nodes() - .iter() - .filter(|f_node| matches!(f_node, FormatterNode::Specifier(_))) - .count(); - if specifier_count != args.len() { - return Err(ErrorCode::BindError(format!( - "Function `format` required {} arguments based on the `formatstr`, but {} found.", - specifier_count, - args.len() - )) - .into()); - } - - // iter the args. - let mut j = 0; - let new_args = [Ok(ExprImpl::literal_varchar("".to_string()))] - .into_iter() - .chain(formatter.nodes().iter().map(move |f_node| -> Result<_> { - let new_arg = match f_node { - FormatterNode::Specifier(sp) => { - // We've checked the count. - let arg = &args[j]; - j += 1; - match sp.ty { - SpecifierType::SimpleString => arg.clone(), - SpecifierType::SqlIdentifier => { - FunctionCall::new(ExprType::QuoteIdent, vec![arg.clone()])?.into() - } - SpecifierType::SqlLiteral => { - return Err::<_, RwError>( - ErrorCode::BindError( - "unsupported specifier type 'L'".to_string(), - ) - .into(), - ) - } - } - } - FormatterNode::Literal(literal) => ExprImpl::literal_varchar(literal.clone()), - }; - Ok(new_arg) - })) - .try_collect()?; - Ok(new_args) - } - /// Make sure inputs only have 2 value and rewrite the arguments. /// Nullif(expr1,expr2) -> Case(Equal(expr1 = expr2),null,expr1). fn rewrite_nullif_to_case_when(inputs: Vec) -> Result> { diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index f359b3bcea642..69be8376b46ca 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -342,6 +342,21 @@ fn infer_type_for_special( .try_collect()?; Ok(Some(DataType::Varchar)) } + ExprType::Format => { + ensure_arity!("format", 1 <= | inputs |); + let inputs_owned = std::mem::take(inputs); + *inputs = inputs_owned + .into_iter() + .enumerate() + .map(|(i, input)| match i { + // 0-th arg must be string + 0 => input.cast_implicit(DataType::Varchar).map_err(Into::into), + // subsequent can be any type, using the output format + _ => input.cast_output(), + }) + .try_collect()?; + Ok(Some(DataType::Varchar)) + } ExprType::IsNotNull => { ensure_arity!("is_not_null", | inputs | == 1); match inputs[0].return_type() { @@ -1203,6 +1218,7 @@ mod tests { sig_map.insert(FuncSign { func: DUMMY_FUNC, inputs_type: formals, + variadic: false, ret_type: DUMMY_RET, build: |_, _| unreachable!(), deprecated: false,