Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(expr): support variadic function in #[function] macro #12178

Merged
merged 15 commits into from
Sep 15, 2023
Merged
23 changes: 23 additions & 0 deletions e2e_test/batch/functions/format.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,26 @@ query T
SELECT format('Testing %s, %s, %s, %%', 'one', 'two', 'three');
----
Testing one, two, three, %

query T
SELECT format('%s %s', a, b) from (values
('Hello', 'World'),
('Rising', 'Wave')
) as t(a, b);
----
Hello World
Rising Wave

query T
SELECT format(f, a, b) from (values
('%s %s', 'Hello', 'World'),
('%s%s', 'Hello', null),
(null, 'Hello', 'World')
) as t(f, a, b);
----
Hello World
Hello
NULL

query error too few arguments for format()
SELECT format('%s %s', 'Hello');
1 change: 1 addition & 0 deletions src/common/src/array/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ impl DataChunkTestExt for DataChunk {
"." => None,
"t" => Some(true.into()),
"f" => Some(false.into()),
"(empty)" => Some("".into()),
_ => Some(ScalarImpl::from_text(val_str.as_bytes(), ty).unwrap()),
};
builder.append(datum);
Expand Down
1 change: 0 additions & 1 deletion src/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ pub mod system_param;
pub mod telemetry;
pub mod transaction;

pub mod format;
pub mod metrics;
pub mod test_utils;
pub mod types;
Expand Down
110 changes: 78 additions & 32 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,14 @@ impl FunctionAttr {
return self.generate_table_function_descriptor(user_fn, build_fn);
}
let name = self.name.clone();
let mut args = Vec::with_capacity(self.args.len());
for ty in &self.args {
args.push(data_type_name(ty));
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let args = match variadic {
true => &self.args[..self.args.len() - 1],
false => &self.args[..],
}
.iter()
.map(|ty| data_type_name(ty))
.collect_vec();
let ret = data_type_name(&self.ret);

let pb_type = format_ident!("{}", utils::to_camel_case(&name));
Expand All @@ -82,6 +86,7 @@ impl FunctionAttr {
unsafe { crate::sig::func::_register(#descriptor_type {
func: risingwave_pb::expr::expr_node::Type::#pb_type,
inputs_type: &[#(#args),*],
variadic: #variadic,
ret_type: #ret,
build: #build_fn,
deprecated: #deprecated,
Expand All @@ -99,7 +104,8 @@ impl FunctionAttr {
user_fn: &UserFunctionAttr,
optimize_const: bool,
) -> Result<TokenStream2> {
let num_args = self.args.len();
let variadic = matches!(self.args.last(), Some(t) if t == "...");
let num_args = self.args.len() - if variadic { 1 } else { 0 };
let fn_name = format_ident!("{}", user_fn.name);
let struct_name = match optimize_const {
true => format_ident!("{}OptimizeConst", utils::to_camel_case(&self.ident_name())),
Expand Down Expand Up @@ -148,8 +154,6 @@ impl FunctionAttr {
let inputs = idents("i", &children_indices);
let prebuilt_inputs = idents("i", &prebuilt_indices);
let non_prebuilt_inputs = idents("i", &non_prebuilt_indices);
let all_child = idents("child", &(0..num_args).collect_vec());
let child = idents("child", &children_indices);
let array_refs = idents("array", &children_indices);
let arrays = idents("a", &children_indices);
let datums = idents("v", &children_indices);
Expand Down Expand Up @@ -210,15 +214,41 @@ impl FunctionAttr {
} else {
quote! { () }
};
let generic = if self.ret == "boolean" && user_fn.generic == 3 {

// ensure the number of children matches the number of arguments
let check_children = match variadic {
true => quote! { crate::ensure!(children.len() >= #num_args); },
false => quote! { crate::ensure!(children.len() == #num_args); },
};

// evaluate variadic arguments in `eval`
let eval_variadic = variadic.then(|| {
quote! {
let mut columns = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
columns.push(child.eval_checked(input).await?);
}
let variadic_input = DataChunk::new(columns, input.vis().clone());
}
});
// evaluate variadic arguments in `eval_row`
let eval_row_variadic = variadic.then(|| {
quote! {
let mut row = Vec::with_capacity(self.children.len() - #num_args);
for child in &self.children[#num_args..] {
row.push(child.eval_row(input).await?);
}
let variadic_row = OwnedRow::new(row);
}
});

let generic = (self.ret == "boolean" && user_fn.generic == 3).then(|| {
// XXX: for generic compare functions, we need to specify the compatible type
let compatible_type = types::ref_type(types::min_compatible_type(&self.args))
.parse::<TokenStream2>()
.unwrap();
quote! { ::<_, _, #compatible_type> }
} else {
quote! {}
};
});
let prebuilt_arg = match (&self.prebuild, optimize_const) {
// use the prebuilt argument
(Some(_), true) => quote! { &self.prebuilt_arg, },
Expand All @@ -227,18 +257,21 @@ impl FunctionAttr {
// no prebuilt argument
(None, _) => quote! {},
};
let context = match user_fn.context {
true => quote! { &self.context, },
false => quote! {},
};
let writer = match user_fn.write {
true => quote! { &mut writer, },
false => quote! {},
};
let variadic_args = variadic.then(|| quote! { variadic_row, });
let context = user_fn.context.then(|| quote! { &self.context, });
let writer = user_fn.write.then(|| quote! { &mut writer, });
let await_ = user_fn.async_.then(|| quote! { .await });
// call the user defined function
// inputs: [ Option<impl ScalarRef> ]
let mut output = quote! { #fn_name #generic(#(#non_prebuilt_inputs,)* #prebuilt_arg #context #writer) #await_ };
let mut output = quote! { #fn_name #generic(
#(#non_prebuilt_inputs,)*
#prebuilt_arg
#variadic_args
#context
#writer
) #await_ };
// handle error if the function returns `Result`
// wrap a `Some` if the function doesn't return `Option`
output = match user_fn.return_type_kind {
// XXX: we don't support void type yet. return null::int for now.
_ if self.ret == "void" => quote! { { #output; Option::<i32>::None } },
Expand All @@ -257,7 +290,7 @@ impl FunctionAttr {
}
};
};
// output: Option<impl ScalarRef or Scalar>
// now the `output` is: Option<impl ScalarRef or Scalar>
let append_output = match user_fn.write {
true => quote! {{
let mut writer = builder.writer().begin();
Expand Down Expand Up @@ -292,13 +325,20 @@ impl FunctionAttr {
};
// the main body in `eval`
let eval = if let Some(batch_fn) = &self.batch_fn {
assert!(
!variadic,
"customized batch function is not supported for variadic functions"
);
// user defined batch function
let fn_name = format_ident!("{}", batch_fn);
quote! {
let c = #fn_name(#(#arrays),*);
Ok(Arc::new(c.into()))
}
} else if (types::is_primitive(&self.ret) || self.ret == "boolean") && user_fn.is_pure() {
} else if (types::is_primitive(&self.ret) || self.ret == "boolean")
&& user_fn.is_pure()
&& !variadic
{
// SIMD optimization for primitive types
match self.args.len() {
0 => quote! {
Expand Down Expand Up @@ -330,27 +370,34 @@ impl FunctionAttr {
}
} else {
// no optimization
let array_zip = match num_args {
let array_zip = match children_indices.len() {
0 => quote! { std::iter::repeat(()).take(input.capacity()) },
_ => quote! { multizip((#(#arrays.iter(),)*)) },
};
let let_variadic = variadic.then(|| {
quote! {
let variadic_row = variadic_input.row_at_unchecked_vis(i);
}
});
quote! {
let mut builder = #builder_type::with_type(input.capacity(), self.context.return_type.clone());

match input.vis() {
Vis::Bitmap(vis) => {
// allow using `zip` for performance
#[allow(clippy::disallowed_methods)]
for ((#(#inputs,)*), visible) in #array_zip.zip(vis.iter()) {
for (i, ((#(#inputs,)*), visible)) in #array_zip.zip(vis.iter()).enumerate() {
if !visible {
builder.append_null();
continue;
}
#let_variadic
#append_output
}
}
Vis::Compact(_) => {
for (#(#inputs,)*) in #array_zip {
for (i, (#(#inputs,)*)) in #array_zip.enumerate() {
#let_variadic
#append_output
}
}
Expand All @@ -374,19 +421,17 @@ impl FunctionAttr {
use crate::expr::{Context, BoxedExpression};
use crate::Result;

crate::ensure!(children.len() == #num_args);
#check_children
let prebuilt_arg = #prebuild_const;
let context = Context {
return_type,
arg_types: children.iter().map(|c| c.return_type()).collect(),
};
let mut iter = children.into_iter();
#(let #all_child = iter.next().unwrap();)*

#[derive(Debug)]
struct #struct_name {
context: Context,
#(#child: BoxedExpression,)*
children: Vec<BoxedExpression>,
prebuilt_arg: #prebuilt_arg_type,
}
#[async_trait::async_trait]
Expand All @@ -395,25 +440,26 @@ impl FunctionAttr {
self.context.return_type.clone()
}
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
// evaluate children and downcast arrays
#(
let #array_refs = self.#child.eval_checked(input).await?;
let #array_refs = self.children[#children_indices].eval_checked(input).await?;
let #arrays: &#arg_arrays = #array_refs.as_ref().into();
)*
#eval_variadic
#eval
}
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
#(
let #datums = self.#child.eval_row(input).await?;
let #datums = self.children[#children_indices].eval_row(input).await?;
let #inputs: Option<#arg_types> = #datums.as_ref().map(|s| s.as_scalar_ref_impl().try_into().unwrap());
)*
#eval_row_variadic
Ok(#row_output)
}
}

Ok(Box::new(#struct_name {
context,
#(#child,)*
children,
prebuilt_arg,
}))
}
Expand Down
27 changes: 23 additions & 4 deletions src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ mod utils;
/// - [Rust Function Signature](#rust-function-signature)
/// - [Nullable Arguments](#nullable-arguments)
/// - [Return Value](#return-value)
/// - [Variadic Function](#variadic-function)
/// - [Optimization](#optimization)
/// - [Functions Returning Strings](#functions-returning-strings)
/// - [Preprocessing Constant Arguments](#preprocessing-constant-arguments)
Expand All @@ -62,13 +63,15 @@ mod utils;
/// invocation. The signature follows this pattern:
///
/// ```text
/// name ( [arg_types],* ) [ -> [setof] return_type ]
/// name ( [arg_types],* [...] ) [ -> [setof] return_type ]
/// ```
///
/// Where `name` is the function name, which must match the function name defined in `prost`.
/// Where `name` is the function name in `snake_case`, which must match the function name defined
/// in `prost`.
///
/// The allowed data types are listed in the `name` column of the appendix's [type matrix].
/// Wildcards or `auto` can also be used, as explained below.
/// `arg_types` is a comma-separated list of argument types. The allowed data types are listed in
/// the `name` column of the appendix's [type matrix]. Wildcards or `auto` can also be used, as
/// explained below. If the function is variadic, the last argument can be denoted as `...`.
///
/// When `setof` appears before the return type, this indicates that the function is a set-returning
/// function (table function), meaning it can return multiple values instead of just one. For more
Expand Down Expand Up @@ -203,6 +206,21 @@ mod utils;
///
/// Therefore, try to avoid returning `Option` and `Result` whenever possible.
///
/// ## Variadic Function
///
/// Variadic functions accept an `impl Row` input to represent trailing arguments.
/// For example:
///
/// ```ignore
/// #[function("concat_ws(varchar, ...) -> varchar")]
/// fn concat_ws(sep: &str, vals: impl Row) -> Option<Box<str>> {
/// let mut string_iter = vals.iter().flatten();
/// // ...
/// }
/// ```
///
/// See `risingwave_common::row::Row` for more details.
///
/// ## Functions Returning Strings
///
/// For functions that return varchar types, you can also use the writer style function signature to
Expand Down Expand Up @@ -569,6 +587,7 @@ impl FunctionAttr {
/// Builds an identifier-friendly name from this function's signature.
///
/// The function name, the argument types (joined by `_`), and the return type
/// are concatenated with `_`. The result is then sanitized in order:
/// `[]` becomes `list`, `...` becomes `variadic`, each of `<`, `>`, space and
/// `,` becomes `_`, and finally `__` is replaced by `_` in a single pass.
fn ident_name(&self) -> String {
    let joined_args = self.args.join("_");
    let raw = format!("{}_{}_{}", self.name, joined_args, self.ret);

    // Spell out the symbolic type markers before stripping punctuation.
    let spelled = raw.replace("[]", "list").replace("...", "variadic");

    // Map the remaining punctuation to underscores, then shrink doubled underscores.
    let underscored = spelled.replace(['<', '>', ' ', ','], "_");
    underscored.replace("__", "_")
}
Expand Down
7 changes: 4 additions & 3 deletions src/expr/macro/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ const TYPE_MATRIX: &str = "
/// Maps a data type to its corresponding data type name.
pub fn data_type(ty: &str) -> &str {
// XXX:
// For functions that contain `any` type, there are special handlings in the frontend,
// and the signature won't be accessed. So we simply return a placeholder here.
if ty == "any" {
// For functions that contain `any` type, or `...` variable arguments,
// there is special handling in the frontend, and the signature won't be accessed.
// So we simply return a placeholder here.
if ty == "any" || ty == "..." {
return "Int32";
}
lookup_matrix(ty, 1)
Expand Down
3 changes: 3 additions & 0 deletions src/expr/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ pub enum ExprError {
#[error("field name must not be null")]
FieldNameNull,

#[error("too few arguments for format()")]
TooFewArguments,

#[error("invalid state: {0}")]
InvalidState(String),
}
Expand Down
5 changes: 0 additions & 5 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ use risingwave_pb::expr::ExprNode;
use super::expr_array_transform::ArrayTransformExpression;
use super::expr_case::CaseExpression;
use super::expr_coalesce::CoalesceExpression;
use super::expr_concat_ws::ConcatWsExpression;
use super::expr_field::FieldExpression;
use super::expr_in::InExpression;
use super::expr_nested_construct::NestedConstructExpression;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
Expand Down Expand Up @@ -56,10 +54,7 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
E::In => InExpression::try_from_boxed(prost),
E::Case => CaseExpression::try_from_boxed(prost),
E::Coalesce => CoalesceExpression::try_from_boxed(prost),
E::ConcatWs => ConcatWsExpression::try_from_boxed(prost),
E::Field => FieldExpression::try_from_boxed(prost),
E::Array => NestedConstructExpression::try_from_boxed(prost),
E::Row => NestedConstructExpression::try_from_boxed(prost),
E::Vnode => VnodeExpression::try_from_boxed(prost),

_ => {
Expand Down
Loading