Skip to content

Commit

Permalink
feat(expr): Implement lambda function and array_transform (#11937)
Browse files Browse the repository at this point in the history
Signed-off-by: TennyZhuang <[email protected]>
Co-authored-by: stonepage <[email protected]>
  • Loading branch information
2 people authored and Li0k committed Sep 15, 2023
1 parent 63262e4 commit c8fdb9d
Show file tree
Hide file tree
Showing 21 changed files with 442 additions and 29 deletions.
56 changes: 56 additions & 0 deletions e2e_test/batch/functions/array_transform.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

query T
select array_transform('{1,2,3}'::int[], |x| x * 2);
----
{2,4,6}

query T
select array_transform('{1,2,3}'::int[], |x| (x::double precision+0.5));
----
{1.5,2.5,3.5}

query T
select array_transform('{1,2,3}'::int[], |x| (x::double precision+0.5));
----
{1.5,2.5,3.5}

query T
select array_transform(
array_transform(
array_transform('{1,2,3}'::int[], |x| x * 2),
|x| x + 0.5
),
|x| concat(x::varchar, '!')
)
----
{2.5!,4.5!,6.5!}

query T
select array_transform(
ARRAY['Apple', 'Airbnb', 'Amazon', 'Facebook', 'Google', 'Microsoft', 'Netflix', 'Uber'],
|x| case when x ilike 'A%' then 'A' else 'Other' end
)
----
{A,A,A,Other,Other,Other,Other,Other}

statement ok
create table t(v int, arr int[]);

statement ok
insert into t values (4, '{1,2,3}'), (5, '{4,5,6,8}');

# this makes sure `x + 1` is not extracted as common sub-expression by accident. See #11766
query TT
select array_transform(arr, |x| x + 1), array_transform(arr, |x| x + 1 + 2) from t;
----
{2,3,4} {4,5,6}
{5,6,7,9} {7,8,9,11}

# this clarifies that we do not support referencing columns.
statement error
select array_transform(arr, |x| x + v) from t;

statement ok
drop table t;
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ message ExprNode {
ARRAY_POSITION = 542;
ARRAY_REPLACE = 543;
ARRAY_DIMS = 544;
ARRAY_TRANSFORM = 545;

// Int256 functions
HEX_TO_INT256 = 560;
Expand Down
26 changes: 26 additions & 0 deletions src/common/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use core::fmt;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::future::Future;
use std::hash::Hash;
use std::mem::size_of;

Expand Down Expand Up @@ -270,6 +271,31 @@ impl ListArray {
Ok(arr.into())
}

/// Apply the function on the underlying elements.
/// e.g. `map_inner([[1,2,3],NULL,[4,5]], DOUBLE) = [[2,4,6],NULL,[8,10]]`
pub async fn map_inner<E, Fut, F>(self, f: F) -> std::result::Result<ListArray, E>
where
F: FnOnce(ArrayImpl) -> Fut,
Fut: Future<Output = std::result::Result<ArrayImpl, E>>,
{
let Self {
bitmap,
offsets,
value,
..
} = self;

let new_value = (f)(*value).await?;
let new_value_type = new_value.data_type();

Ok(Self {
offsets,
bitmap,
value: Box::new(new_value),
value_type: new_value_type,
})
}

// Used for testing purposes
pub fn from_iter(
values: impl IntoIterator<Item = Option<ArrayImpl>>,
Expand Down
7 changes: 7 additions & 0 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::expr_array_concat::ArrayConcatExpression;
use super::expr_array_transform::ArrayTransformExpression;
use super::expr_case::CaseExpression;
use super::expr_coalesce::CoalesceExpression;
use super::expr_concat_ws::ConcatWsExpression;
Expand Down Expand Up @@ -92,6 +93,12 @@ pub fn build_func(
ret_type: DataType,
children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
if func == PbType::ArrayTransform {
// TODO: The function framework can't handle the lambda arg now.
let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap();
return Ok(ArrayTransformExpression { array, lambda }.boxed());
}

let args = children
.iter()
.map(|c| c.return_type().into())
Expand Down
67 changes: 67 additions & 0 deletions src/expr/src/expr/expr_array_transform.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use async_trait::async_trait;
use risingwave_common::array::{ArrayRef, DataChunk, Vis};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum, ListValue, ScalarImpl};

use super::{BoxedExpression, Expression};
use crate::Result;

#[derive(Debug)]
pub struct ArrayTransformExpression {
pub(super) array: BoxedExpression,
pub(super) lambda: BoxedExpression,
}

#[async_trait]
impl Expression for ArrayTransformExpression {
fn return_type(&self) -> DataType {
DataType::List(Box::new(self.lambda.return_type()))
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
let lambda_input = self.array.eval_checked(input).await?;
let lambda_input = Arc::unwrap_or_clone(lambda_input).into_list();
let new_list = lambda_input
.map_inner(|flatten_input| async move {
let flatten_len = flatten_input.len();
let chunk =
DataChunk::new(vec![Arc::new(flatten_input)], Vis::Compact(flatten_len));
self.lambda.eval(&chunk).await.map(Arc::unwrap_or_clone)
})
.await?;
Ok(Arc::new(new_list.into()))
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
let lambda_input = self.array.eval_row(input).await?;
let lambda_input = lambda_input.map(ScalarImpl::into_list);
if let Some(lambda_input) = lambda_input {
let mut new_vals = Vec::with_capacity(lambda_input.values().len());
for val in lambda_input.values() {
let row = OwnedRow::new(vec![val.clone()]);
let res = self.lambda.eval_row(&row).await?;
new_vals.push(res);
}
let new_list = ListValue::new(new_vals);
Ok(Some(new_list.into()))
} else {
Ok(None)
}
}
}
1 change: 1 addition & 0 deletions src/expr/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
// These modules define concrete expression structures.
mod expr_array_concat;
mod expr_array_to_string;
mod expr_array_transform;
mod expr_binary_nonnull;
mod expr_binary_nullable;
mod expr_case;
Expand Down
1 change: 1 addition & 0 deletions src/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#![feature(round_ties_even)]
#![feature(generators)]
#![feature(test)]
#![feature(arc_unwrap_or_clone)]

pub mod agg;
mod error;
Expand Down
3 changes: 3 additions & 0 deletions src/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::rc::Rc;
use parse_display::Display;
use risingwave_common::catalog::Field;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::TableAlias;

type LiteResult<T> = std::result::Result<T, ErrorCode>;
Expand Down Expand Up @@ -79,6 +80,8 @@ pub struct BindContext {
/// Map the cte's name to its Relation::Subquery.
/// The `ShareId` of the value is used to help the planner identify the share plan.
pub cte_to_relation: HashMap<String, Rc<(ShareId, BoundQuery, TableAlias)>>,
/// Current lambda functions's arguments
pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
}

/// Holds the context for the `BindContext`'s `ColumnGroup`s.
Expand Down
3 changes: 3 additions & 0 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ impl ExprRewriter for ParamRewriter {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
self.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
Expand Down
80 changes: 77 additions & 3 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ use risingwave_expr::function::window::{
Frame, FrameBound, FrameBounds, FrameExclusion, WindowFuncKind,
};
use risingwave_sqlparser::ast::{
Function, FunctionArg, FunctionArgExpr, WindowFrameBound, WindowFrameExclusion,
self, Function, FunctionArg, FunctionArgExpr, Ident, WindowFrameBound, WindowFrameExclusion,
WindowFrameUnits, WindowSpec,
};

use crate::binder::bind_context::Clause;
use crate::binder::{Binder, BoundQuery, BoundSetExpr};
use crate::expr::{
AggCall, Expr, ExprImpl, ExprType, FunctionCall, Literal, Now, OrderBy, Subquery, SubqueryKind,
TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction,
AggCall, Expr, ExprImpl, ExprType, FunctionCall, FunctionCallWithLambda, Literal, Now, OrderBy,
Subquery, SubqueryKind, TableFunction, TableFunctionType, UserDefinedFunction, WindowFunction,
};
use crate::utils::Condition;

Expand Down Expand Up @@ -101,6 +101,11 @@ impl Binder {
return Ok(ExprImpl::literal_varchar("".to_string()));
}

if function_name == "array_transform" {
// For type inference, we need to bind the array type first.
return self.bind_array_transform(f);
}

let inputs = f
.args
.into_iter()
Expand Down Expand Up @@ -154,6 +159,75 @@ impl Binder {
self.bind_builtin_scalar_function(function_name.as_str(), inputs)
}

fn bind_array_transform(&mut self, f: Function) -> Result<ExprImpl> {
let [array, lambda] = <[FunctionArg; 2]>::try_from(f.args).map_err(|args| -> RwError {
ErrorCode::BindError(format!(
"`array_transform` expect two inputs `array` and `lambda`, but {} were given",
args.len()
))
.into()
})?;

let bound_array = self.bind_function_arg(array)?;
let [bound_array] = <[ExprImpl; 1]>::try_from(bound_array).map_err(|bound_array| -> RwError {
ErrorCode::BindError(format!("The `array` argument for `array_transform` should be bound to one argument, but {} were got", bound_array.len()))
.into()
})?;

let inner_ty = match bound_array.return_type() {
DataType::List(ty) => *ty,
real_type => {
return Err(ErrorCode::BindError(format!(
"The `array` argument for `array_transform` should be an array, but {} were got",
real_type
))
.into())
}
};

let ast::FunctionArgExpr::Expr(ast::Expr::LambdaFunction { args: lambda_args, body: lambda_body }) = lambda.get_expr() else {
return Err(ErrorCode::BindError(
"The `lambda` argument for `array_transform` should be a lambda function".to_string()
).into());
};

let [lambda_arg] = <[Ident; 1]>::try_from(lambda_args).map_err(|args| -> RwError {
ErrorCode::BindError(format!(
"The `lambda` argument for `array_transform` should be a lambda function with one argument, but {} were given",
args.len()
))
.into()
})?;

let bound_lambda = self.bind_unary_lambda_function(inner_ty, lambda_arg, *lambda_body)?;

let lambda_ret_type = bound_lambda.return_type();
let transform_ret_type = DataType::List(Box::new(lambda_ret_type));

Ok(ExprImpl::FunctionCallWithLambda(Box::new(
FunctionCallWithLambda::new_unchecked(
ExprType::ArrayTransform,
vec![bound_array],
bound_lambda,
transform_ret_type,
),
)))
}

fn bind_unary_lambda_function(
&mut self,
input_ty: DataType,
arg: Ident,
body: ast::Expr,
) -> Result<ExprImpl> {
let lambda_args = HashMap::from([(arg.real_value(), (0usize, input_ty))]);
let orig_lambda_args = self.context.lambda_args.replace(lambda_args);
let body = self.bind_expr_inner(body)?;
self.context.lambda_args = orig_lambda_args;

Ok(body)
}

pub(super) fn bind_agg(&mut self, f: Function, kind: AggKind) -> Result<ExprImpl> {
self.ensure_aggregate_allowed()?;

Expand Down
13 changes: 12 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use risingwave_sqlparser::ast::{

use crate::binder::expr::function::SYS_FUNCTION_WITHOUT_ARGS;
use crate::binder::Binder;
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, Parameter, SubqueryKind};
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall, InputRef, Parameter, SubqueryKind};

mod binary_op;
mod column;
Expand Down Expand Up @@ -97,6 +97,17 @@ impl Binder {
// NOTE: Here we don't 100% follow the behavior of Postgres, as it doesn't
// allow `session_user()` while we do.
self.bind_function(Function::no_arg(ObjectName(vec![ident])))
} else if let Some(ref lambda_args) = self.context.lambda_args {
// We don't support capture, so if the expression is in the lambda context,
// we'll not bind it for table columns.
if let Some((arg_idx, arg_type)) = lambda_args.get(&ident.real_value()) {
Ok(InputRef::new(*arg_idx, arg_type.clone()).into())
} else {
Err(
ErrorCode::ItemNotFound(format!("Unknown arg: {}", ident.real_value()))
.into(),
)
}
} else {
self.bind_column(&[ident])
}
Expand Down
Loading

0 comments on commit c8fdb9d

Please sign in to comment.