Skip to content

Commit

Permalink
refactor(expr): support variadic function in #[function] macro (#12178
Browse files Browse the repository at this point in the history
)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Sep 15, 2023
1 parent 467ba4b commit 0032145
Show file tree
Hide file tree
Showing 20 changed files with 363 additions and 570 deletions.
23 changes: 23 additions & 0 deletions e2e_test/batch/functions/format.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,26 @@ query T
SELECT format('Testing %s, %s, %s, %%', 'one', 'two', 'three');
----
Testing one, two, three, %

query T
SELECT format('%s %s', a, b) from (values
('Hello', 'World'),
('Rising', 'Wave')
) as t(a, b);
----
Hello World
Rising Wave

query T
SELECT format(f, a, b) from (values
('%s %s', 'Hello', 'World'),
('%s%s', 'Hello', null),
(null, 'Hello', 'World')
) as t(f, a, b);
----
Hello World
Hello
NULL

query error too few arguments for format()
SELECT format('%s %s', 'Hello');
1 change: 1 addition & 0 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ impl DataChunkTestExt for DataChunk {
"." => None,
"t" => Some(true.into()),
"f" => Some(false.into()),
"(empty)" => Some("".into()),
_ => Some(ScalarImpl::from_text(val_str.as_bytes(), ty).unwrap()),
};
builder.append(datum);
Expand Down
1 change: 0 additions & 1 deletion src/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ pub mod system_param;
pub mod telemetry;
pub mod transaction;

pub mod format;
pub mod metrics;
pub mod test_utils;
pub mod types;
Expand Down
110 changes: 78 additions & 32 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ impl FunctionAttr {
return self.generate_table_function_descriptor(user_fn, build_fn);
}
let name = self.name.clone();
let mut args = Vec::with_capacity(self.args.len());
for ty in &self.args {
args.push(data_type_name(ty));
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let args = match variadic {
true => &self.args[..self.args.len() - 1],
false => &self.args[..],
}
.iter()
.map(|ty| data_type_name(ty))
.collect_vec();
let ret = data_type_name(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
Expand All @@ -82,6 +86,7 @@ impl FunctionAttr {
unsafe { crate::sig::func::_register(#descriptor_type {
func: risingwave_pb::expr::expr_node::Type::#pb_type,
inputs_type: &[#(#args),*],
variadic: #variadic,
ret_type: #ret,
build: #build_fn,
deprecated: #deprecated,
Expand All @@ -99,7 +104,8 @@ impl FunctionAttr {
user_fn: &UserFunctionAttr,
optimize_const: bool,
) -> Result<TokenStream2> {
let num_args = self.args.len();
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let num_args = self.args.len() - if variadic { 1 } else { 0 };
let fn_name = format_ident!("{}", user_fn.name);
let struct_name = match optimize_const {
true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
Expand Down Expand Up @@ -148,8 +154,6 @@ impl FunctionAttr {
let inputs = idents("i", &children_indices);
let prebuilt_inputs = idents("i", &prebuilt_indices);
let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
let all_child = idents("child", &(0..num_args).collect_vec());
let child = idents("child", &children_indices);
let array_refs = idents("array", &children_indices);
let arrays = idents("a", &children_indices);
let datums = idents("v", &children_indices);
Expand Down Expand Up @@ -210,15 +214,41 @@ impl FunctionAttr {
} else {
quote! { () }
};
let generic = if self.ret == "boolean" && user_fn.generic == 3 {

// ensure the number of children matches the number of arguments
let check_children = match variadic {
true => quote! { crate::ensure!(children.len() >= #num_args); },
false => quote! { crate::ensure!(children.len() == #num_args); },
};

// evaluate variadic arguments in `eval`
let eval_variadic = variadic.then(|| {
quote! {
let mut columns = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
columns.push(child.eval_checked(input).await?);
}
let variadic_input = DataChunk::new(columns, input.vis().clone());
}
});
// evaluate variadic arguments in `eval_row`
let eval_row_variadic = variadic.then(|| {
quote! {
let mut row = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
row.push(child.eval_row(input).await?);
}
let variadic_row = OwnedRow::new(row);
}
});

let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
// XXX: for generic compare functions, we need to specify the compatible type
let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
.parse::<TokenStream2>()
.unwrap();
quote! { ::<_, _, #compatible_type> }
} else {
quote! {}
};
});
let prebuilt_arg = match (&self.prebuild, optimize_const) {
// use the prebuilt argument
(Some(_), true) => quote! { &self.prebuilt_arg, },
Expand All @@ -227,18 +257,21 @@ impl FunctionAttr {
// no prebuilt argument
(None, _) => quote! {},
};
let context = match user_fn.context {
true => quote! { &self.context, },
false => quote! {},
};
let writer = match user_fn.write {
true => quote! { &mut writer, },
false => quote! {},
};
let variadic_args = variadic.then(|| quote! { variadic_row, });
let context = user_fn.context.then(|| quote! { &self.context, });
let writer = user_fn.write.then(|| quote! { &mut writer, });
let await_ = user_fn.async_.then(|| quote! { .await });
// call the user defined function
// inputs: [ Option<impl ScalarRef> ]
let mut output = quote! { #fn_name #generic(#(#non_prebuilt_inputs,)* #prebuilt_arg #context #writer) #await_ };
let mut output = quote! { #fn_name #generic(
#(#non_prebuilt_inputs,)*
#prebuilt_arg
#variadic_args
#context
#writer
) #await_ };
// handle error if the function returns `Result`
// wrap a `Some` if the function doesn't return `Option`
output = match user_fn.return_type_kind {
// XXX: we don't support void type yet. return null::int for now.
_ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
Expand All @@ -257,7 +290,7 @@ impl FunctionAttr {
}
};
};
// output: Option<impl ScalarRef or Scalar>
// now the `output` is: Option<impl ScalarRef or Scalar>
let append_output = match user_fn.write {
true => quote! {{
let mut writer = builder.writer().begin();
Expand Down Expand Up @@ -292,13 +325,20 @@ impl FunctionAttr {
};
// the main body in `eval`
let eval = if let Some(batch_fn) = &self.batch_fn {
assert!(
!variadic,
"customized batch function is not supported for variadic functions"
);
// user defined batch function
let fn_name = format_ident!("{}", batch_fn);
quote! {
let c = #fn_name(#(#arrays),*);
Ok(Arc::new(c.into()))
}
} else if (types::is_primitive(&self.ret) || self.ret == "boolean") && user_fn.is_pure() {
} else if (types::is_primitive(&self.ret) || self.ret == "boolean")
&& user_fn.is_pure()
&& !variadic
{
// SIMD optimization for primitive types
match self.args.len() {
0 => quote! {
Expand Down Expand Up @@ -330,27 +370,34 @@ impl FunctionAttr {
}
} else {
// no optimization
let array_zip = match num_args {
let array_zip = match children_indices.len() {
0 => quote! { std::iter::repeat(()).take(input.capacity()) },
_ => quote! { multizip((#(#arrays.iter(),)*)) },
};
let let_variadic = variadic.then(|| {
quote! {
let variadic_row = variadic_input.row_at_unchecked_vis(i);
}
});
quote! {
let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());

match input.vis() {
Vis::Bitmap(vis) => {
// allow using `zip` for performance
#[allow(clippy::disallowed_methods)]
for ((#(#inputs,)*), visible) in #array_zip.zip(vis.iter()) {
for (i, ((#(#inputs,)*), visible)) in #array_zip.zip(vis.iter()).enumerate() {
if !visible {
builder.append_null();
continue;
}
#let_variadic
#append_output
}
}
Vis::Compact(_) => {
for (#(#inputs,)*) in #array_zip {
for (i, (#(#inputs,)*)) in #array_zip.enumerate() {
#let_variadic
#append_output
}
}
Expand All @@ -374,19 +421,17 @@ impl FunctionAttr {
use crate::expr::{Context, BoxedExpression};
use crate::Result;

crate::ensure!(children.len() == #num_args);
#check_children
let prebuilt_arg = #prebuild_const;
let context = Context {
return_type,
arg_types: children.iter().map(|c| c.return_type()).collect(),
};
let mut iter = children.into_iter();
#(let #all_child = iter.next().unwrap();)*

#[derive(Debug)]
struct #struct_name {
context: Context,
#(#child: BoxedExpression,)*
children: Vec<BoxedExpression>,
prebuilt_arg: #prebuilt_arg_type,
}
#[async_trait::async_trait]
Expand All @@ -395,25 +440,26 @@ impl FunctionAttr {
self.context.return_type.clone()
}
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
// evaluate children and downcast arrays
#(
let #array_refs = self.#child.eval_checked(input).await?;
let #array_refs = self.children[#children_indices].eval_checked(input).await?;
let #arrays: &#arg_arrays = #array_refs.as_ref().into();
)*
#eval_variadic
#eval
}
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
#(
let #datums = self.#child.eval_row(input).await?;
let #datums = self.children[#children_indices].eval_row(input).await?;
let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
)*
#eval_row_variadic
Ok(#row_output)
}
}

Ok(Box::new(#struct_name {
context,
#(#child,)*
children,
prebuilt_arg,
}))
}
Expand Down
27 changes: 23 additions & 4 deletions src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod utils;
/// - [Rust Function Signature](#rust-function-signature)
/// - [Nullable Arguments](#nullable-arguments)
/// - [Return Value](#return-value)
/// - [Variadic Function](#variadic-function)
/// - [Optimization](#optimization)
/// - [Functions Returning Strings](#functions-returning-strings)
/// - [Preprocessing Constant Arguments](#preprocessing-constant-arguments)
Expand All @@ -62,13 +63,15 @@ mod utils;
/// invocation. The signature follows this pattern:
///
/// ```text
/// name ( [arg_types],* ) [ -> [setof] return_type ]
/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
/// ```
///
/// Where `name` is the function name, which must match the function name defined in `prost`.
/// Where `name` is the function name in `snake_case`, which must match the function name defined
/// in `prost`.
///
/// The allowed data types are listed in the `name` column of the appendix's [type matrix].
/// Wildcards or `auto` can also be used, as explained below.
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed in
/// the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
/// explained below. If the function is variadic, the last argument can be denoted as `...`.
///
/// When `setof` appears before the return type, this indicates that the function is a set-returning
/// function (table function), meaning it can return multiple values instead of just one. For more
Expand Down Expand Up @@ -203,6 +206,21 @@ mod utils;
///
/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
///
/// ## Variadic Function
///
/// Variadic functions accept an `impl Row` input to represent trailing arguments.
/// For example:
///
/// ```ignore
/// #[function("concat_ws(varchar, ...) -> varchar")]
/// fn concat_ws(sep: &str, vals: impl Row) -> Option<Box<str>> {
/// let mut string_iter = vals.iter().flatten();
/// // ...
/// }
/// ```
///
/// See `risingwave_common::row::Row` for more details.
///
/// ## Functions Returning Strings
///
/// For functions that return varchar types, you can also use the writer style function signature to
Expand Down Expand Up @@ -569,6 +587,7 @@ impl FunctionAttr {
fn ident_name(&self) -> String {
    // Start from `<name>_<arg1>_<arg2>..._<ret>`.
    let joined_args = self.args.join("_");
    let raw = format!("{}_{}_{}", self.name, joined_args, self.ret);

    // Spell out tokens that cannot appear in an identifier: the list suffix `[]`
    // and the variadic marker `...`.
    let spelled_out = raw.replace("[]", "list").replace("...", "variadic");

    // Turn the remaining punctuation into underscores, then collapse the
    // double underscores this may have produced (single pass).
    let sanitized = spelled_out.replace(['<', '>', ' ', ','], "_");
    sanitized.replace("__", "_")
}
Expand Down
7 changes: 4 additions & 3 deletions src/expr/macro/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ const TYPE_MATRIX: &str = "
/// Maps a data type to its corresponding data type name.
pub fn data_type(ty: &str) -> &str {
// XXX:
// For functions that contain `any` type, there are special handlings in the frontend,
// and the signature won't be accessed. So we simply return a placeholder here.
if ty == "any" {
// For functions that contain the `any` type or `...` variable arguments,
// there is special handling in the frontend, and the signature won't be accessed.
// So we simply return a placeholder here.
if ty == "any" || ty == "..." {
return "Int32";
}
lookup_matrix(ty, 1)
Expand Down
3 changes: 3 additions & 0 deletions src/expr/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pub enum ExprError {
#[error("field name must not be null")]
FieldNameNull,

#[error("too few arguments for format()")]
TooFewArguments,

#[error("invalid state: {0}")]
InvalidState(String),
}
Expand Down
5 changes: 0 additions & 5 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ use risingwave_pb::expr::ExprNode;
use super::expr_array_transform::ArrayTransformExpression;
use super::expr_case::CaseExpression;
use super::expr_coalesce::CoalesceExpression;
use super::expr_concat_ws::ConcatWsExpression;
use super::expr_field::FieldExpression;
use super::expr_in::InExpression;
use super::expr_nested_construct::NestedConstructExpression;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
Expand Down Expand Up @@ -56,10 +54,7 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
E::In => InExpression::try_from_boxed(prost),
E::Case => CaseExpression::try_from_boxed(prost),
E::Coalesce => CoalesceExpression::try_from_boxed(prost),
E::ConcatWs => ConcatWsExpression::try_from_boxed(prost),
E::Field => FieldExpression::try_from_boxed(prost),
E::Array => NestedConstructExpression::try_from_boxed(prost),
E::Row => NestedConstructExpression::try_from_boxed(prost),
E::Vnode => VnodeExpression::try_from_boxed(prost),

_ => {
Expand Down
Loading

0 comments on commit 0032145

Please sign in to comment.