diff --git a/Cargo.toml b/Cargo.toml index 2ba02c575f4d2..18adb9ede927c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ "src/expr/core", "src/expr/impl", "src/expr/macro", + "src/expr/udf", "src/frontend", "src/frontend/planner_test", "src/java_binding", @@ -48,7 +49,6 @@ members = [ "src/tests/simulation", "src/tests/sqlsmith", "src/tests/state_cleaning_test", - "src/udf", "src/utils/local_stats_alloc", "src/utils/pgwire", "src/utils/runtime", @@ -170,7 +170,7 @@ risingwave_sqlsmith = { path = "./src/tests/sqlsmith" } risingwave_storage = { path = "./src/storage" } risingwave_stream = { path = "./src/stream" } risingwave_test_runner = { path = "./src/test_runner" } -risingwave_udf = { path = "./src/udf" } +risingwave_udf = { path = "./src/expr/udf" } risingwave_variables = { path = "./src/utils/variables" } risingwave_java_binding = { path = "./src/java_binding" } risingwave_jni_core = { path = "src/jni_core" } diff --git a/ci/scripts/run-unit-test.sh b/ci/scripts/run-unit-test.sh index c1b7a1b71782d..6f2093060f370 100755 --- a/ci/scripts/run-unit-test.sh +++ b/ci/scripts/run-unit-test.sh @@ -6,7 +6,7 @@ set -euo pipefail REPO_ROOT=${PWD} echo "+++ Run python UDF SDK unit tests" -cd ${REPO_ROOT}/src/udf/python +cd ${REPO_ROOT}/src/expr/udf/python python3 -m pytest cd ${REPO_ROOT} diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index a0089c2e4b1b0..999c42ec53011 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -19,7 +19,7 @@ from typing import Iterator, List, Optional, Tuple, Any from decimal import Decimal -sys.path.append("src/udf/python") # noqa +sys.path.append("src/expr/udf/python") # noqa from risingwave.udf import udf, udtf, UdfServer diff --git a/integration_tests/feature-store/server/udf.py b/integration_tests/feature-store/server/udf.py index 51651b3534330..fc2bc2fc883e0 100644 --- a/integration_tests/feature-store/server/udf.py +++ b/integration_tests/feature-store/server/udf.py @@ -2,22 +2,20 @@ from typing import Iterator, List, Optional, Tuple, Any from decimal import Decimal -sys.path.append("src/udf/python") # noqa +sys.path.append("src/expr/udf/python") # noqa from risingwave.udf import udf, UdfServer - @udf(input_types=["INT", "VARCHAR"], result_type="INT") def udf_sum(x: int, y: str) -> int: - if y=='mfa+': + if y == "mfa+": return x else: return -x - if __name__ == "__main__": server = UdfServer(location="0.0.0.0:8815") server.add_function(udf_sum) - server.serve() \ No newline at end of file + server.serve() diff --git a/src/common/src/array/data_chunk.rs b/src/common/src/array/data_chunk.rs index b8cd84dbcc932..fff5efc22d1f8 100644 --- a/src/common/src/array/data_chunk.rs +++ b/src/common/src/array/data_chunk.rs @@ -733,7 +733,7 @@ pub trait DataChunkTestExt { /// // T: str /// // TS: Timestamp /// // SRL: Serial - /// // {i,f}: struct + /// // : struct /// ``` fn from_pretty(s: &str) -> Self; @@ -783,7 +783,7 @@ impl DataChunkTestExt for DataChunk { "TZ" => DataType::Timestamptz, "T" => DataType::Varchar, "SRL" => DataType::Serial, - array if array.starts_with('{') && array.ends_with('}') => { + array if array.starts_with('<') && array.ends_with('>') => { DataType::Struct(StructType::unnamed( array[1..array.len() - 1] .split(',') diff --git a/src/common/src/array/stream_chunk.rs b/src/common/src/array/stream_chunk.rs index 689d3cd0df1dc..e024d22ec5172 100644 --- a/src/common/src/array/stream_chunk.rs +++ b/src/common/src/array/stream_chunk.rs @@ -604,7 +604,7 @@ impl StreamChunkTestExt for StreamChunk { /// // TZ: Timestamptz /// // SRL: Serial /// // x[]: array of x - /// // {i,f}: struct + /// // : struct /// ``` fn from_pretty(s: &str) -> Self { let mut chunk_str = String::new(); diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 417ea6a2e0caf..75139c6bed5bf 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -145,7 +145,7 @@ pub enum DataType { #[from_str(regex = "(?i)^interval$")] Interval, #[display("{0}")] - #[from_str(ignore)] + #[from_str(regex = "(?i)^(?P<0>.+)$")] Struct(StructType), #[display("{0}[]")] #[from_str(regex = r"(?i)^(?P<0>.+)\[\]$")] @@ -933,7 +933,7 @@ impl ScalarImpl { Self::List(ListValue::new(builder.finish())) } DataType::Struct(s) => { - if !(str.starts_with('{') && str.ends_with('}')) { + if !(str.starts_with('(') && str.ends_with(')')) { return Err(FromSqlError::from_text(str)); } let mut fields = Vec::with_capacity(s.len()); @@ -1466,5 +1466,17 @@ mod tests { DataType::from_str("interval[]").unwrap(), DataType::List(Box::new(DataType::Interval)) ); + + assert_eq!( + DataType::from_str("record").unwrap(), + DataType::Struct(StructType::unnamed(vec![])) + ); + assert_eq!( + DataType::from_str("struct").unwrap(), + DataType::Struct(StructType::new(vec![ + ("a", DataType::Int32), + ("b", DataType::Varchar) + ])) + ); } } diff --git a/src/common/src/types/struct_type.rs b/src/common/src/types/struct_type.rs index 7aca0e727c82e..91d631b6801b4 100644 --- a/src/common/src/types/struct_type.rs +++ b/src/common/src/types/struct_type.rs @@ -16,6 +16,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::str::FromStr; use std::sync::Arc; +use anyhow::anyhow; use itertools::Itertools; use super::DataType; @@ -153,10 +154,12 @@ impl FromStr for StructType { if s == "record" { return Ok(StructType::unnamed(Vec::new())); } - let s = s.trim_start_matches("struct<").trim_end_matches('>'); + if !(s.starts_with("struct<") && s.ends_with('>')) { + return Err(anyhow!("expect struct<...>")); + }; let mut field_types = Vec::new(); let mut field_names = Vec::new(); - for field in s.split(',') { + for field in s[7..s.len() - 1].split(',') { let field = field.trim(); let mut iter = field.split_whitespace(); let field_name = iter.next().unwrap(); diff --git a/src/connector/src/sink/formatter/debezium_json.rs b/src/connector/src/sink/formatter/debezium_json.rs index 091f63f63f7e3..f3e877c44829e 100644 --- a/src/connector/src/sink/formatter/debezium_json.rs +++ b/src/connector/src/sink/formatter/debezium_json.rs @@ -337,17 +337,17 @@ mod tests { #[test] fn test_chunk_to_json() -> Result<()> { let chunk = StreamChunk::from_pretty( - " i f {i,f} - + 0 0.0 {0,0.0} - + 1 1.0 {1,1.0} - + 2 2.0 {2,2.0} - + 3 3.0 {3,3.0} - + 4 4.0 {4,4.0} - + 5 5.0 {5,5.0} - + 6 6.0 {6,6.0} - + 7 7.0 {7,7.0} - + 8 8.0 {8,8.0} - + 9 9.0 {9,9.0}", + " i f + + 0 0.0 (0,0.0) + + 1 1.0 (1,1.0) + + 2 2.0 (2,2.0) + + 3 3.0 (3,3.0) + + 4 4.0 (4,4.0) + + 5 5.0 (5,5.0) + + 6 6.0 (6,6.0) + + 7 7.0 (7,7.0) + + 8 8.0 (8,8.0) + + 9 9.0 (9,9.0)", ); let schema = Schema::new(vec![ diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 46c672d6da521..8e4cb8439e7df 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -19,13 +19,8 @@ use risingwave_common::types::{DataType, ScalarImpl}; use risingwave_pb::expr::expr_node::{PbType, RexNode}; use risingwave_pb::expr::ExprNode; -use super::expr_array_transform::ArrayTransformExpression; -use super::expr_coalesce::CoalesceExpression; -use super::expr_field::FieldExpression; -use super::expr_in::InExpression; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; -use super::expr_vnode::VnodeExpression; use super::wrapper::checked::Checked; use super::wrapper::non_strict::NonStrict; use super::wrapper::EvalErrorReport; @@ -112,10 +107,6 @@ where RexNode::FuncCall(_) => match prost.function_type() { // Dedicated types E::All | E::Some => SomeAllExpression::build_boxed(prost, build_child), - E::In => InExpression::build_boxed(prost, build_child), - E::Coalesce => CoalesceExpression::build_boxed(prost, build_child), - E::Field => FieldExpression::build_boxed(prost, build_child), - E::Vnode => VnodeExpression::build_boxed(prost, build_child), // General types, lookup in the function signature map _ => FuncCallBuilder::build_boxed(prost, build_child), @@ -193,12 +184,6 @@ pub fn build_func( ret_type: DataType, children: Vec, ) -> Result { - if func == PbType::ArrayTransform { - // TODO: The function framework can't handle the lambda arg now. - let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap(); - return Ok(ArrayTransformExpression { array, lambda }.boxed()); - } - let args = children.iter().map(|c| c.return_type()).collect_vec(); let desc = FUNCTION_REGISTRY .get(func, &args, &ret_type) @@ -300,7 +285,7 @@ impl> Parser { assert_eq!(self.tokens.next(), Some(Token::Colon), "Expected a Colon"); let ty = self.parse_type(); let value = match value.as_str() { - "null" => None, + "null" | "NULL" => None, _ => Some( ScalarImpl::from_text(value.as_bytes(), &ty).expect_str("value", &value), ), @@ -313,7 +298,10 @@ impl> Parser { fn parse_type(&mut self) -> DataType { match self.tokens.next().expect("Unexpected end of input") { - Token::Literal(name) => name.parse::().expect_str("type", &name), + Token::Literal(name) => name + .replace('_', " ") + .parse::() + .expect_str("type", &name), t => panic!("Expected a Literal, got {t:?}"), } } diff --git a/src/expr/core/src/expr/expr_coalesce.rs b/src/expr/core/src/expr/expr_coalesce.rs deleted file mode 100644 index b7916f414136c..0000000000000 --- a/src/expr/core/src/expr/expr_coalesce.rs +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::ops::BitAnd; -use std::sync::Arc; - -use risingwave_common::array::{ArrayRef, DataChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use super::Build; -use crate::expr::{BoxedExpression, Expression}; -use crate::{bail, ensure, Result}; - -#[derive(Debug)] -pub struct CoalesceExpression { - return_type: DataType, - children: Vec, -} - -#[async_trait::async_trait] -impl Expression for CoalesceExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let init_vis = input.visibility(); - let mut input = input.clone(); - let len = input.capacity(); - let mut selection: Vec> = vec![None; len]; - let mut children_array = Vec::with_capacity(self.children.len()); - for (child_idx, child) in self.children.iter().enumerate() { - let res = child.eval(&input).await?; - let res_bitmap = res.null_bitmap(); - let orig_vis = input.visibility(); - for pos in orig_vis.bitand(res_bitmap).iter_ones() { - selection[pos] = Some(child_idx); - } - let new_vis = orig_vis & !res_bitmap; - input.set_visibility(new_vis); - children_array.push(res); - } - let mut builder = self.return_type.create_array_builder(len); - for (i, sel) in selection.iter().enumerate() { - if init_vis.is_set(i) - && let Some(child_idx) = sel - { - builder.append(children_array[*child_idx].value_at(i)); - } else { - builder.append_null() - } - } - Ok(Arc::new(builder.finish())) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - for child in &self.children { - let datum = child.eval_row(input).await?; - if datum.is_some() { - return Ok(datum); - } - } - Ok(None) - } -} - -impl CoalesceExpression { - pub fn new(return_type: DataType, children: Vec) -> Self { - CoalesceExpression { - return_type, - children, - } - } -} - -impl Build for CoalesceExpression { - fn build( - prost: &ExprNode, - build_child: impl Fn(&ExprNode) -> Result, - ) -> Result { - ensure!(prost.get_function_type().unwrap() == Type::Coalesce); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - - let children = func_call_node - .children - .to_vec() - .iter() - .map(build_child) - .collect::>>()?; - Ok(CoalesceExpression::new(ret_type, children)) - } -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::DataChunk; - use risingwave_common::row::OwnedRow; - use risingwave_common::test_prelude::DataChunkTestExt; - use risingwave_common::types::{Scalar, ScalarImpl}; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::expr_node::RexNode; - use risingwave_pb::expr::expr_node::Type::Coalesce; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use crate::expr::expr_coalesce::CoalesceExpression; - use crate::expr::test_utils::make_input_ref; - use crate::expr::{Build, Expression}; - - pub fn make_coalesce_function(children: Vec, ret: TypeName) -> ExprNode { - ExprNode { - function_type: Coalesce as i32, - return_type: Some(PbDataType { - type_name: ret as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { children })), - } - } - - #[tokio::test] - async fn test_coalesce_expr() { - let input_node1 = make_input_ref(0, TypeName::Int32); - let input_node2 = make_input_ref(1, TypeName::Int32); - let input_node3 = make_input_ref(2, TypeName::Int32); - - let data_chunk = DataChunk::from_pretty( - "i i i - 1 . . - . 2 . - . . 3 - . . .", - ); - - let nullif_expr = CoalesceExpression::build_for_test(&make_coalesce_function( - vec![input_node1, input_node2, input_node3], - TypeName::Int32, - )) - .unwrap(); - let res = nullif_expr.eval(&data_chunk).await.unwrap(); - assert_eq!(res.datum_at(0), Some(ScalarImpl::Int32(1))); - assert_eq!(res.datum_at(1), Some(ScalarImpl::Int32(2))); - assert_eq!(res.datum_at(2), Some(ScalarImpl::Int32(3))); - assert_eq!(res.datum_at(3), None); - } - - #[tokio::test] - async fn test_eval_row_coalesce_expr() { - let input_node1 = make_input_ref(0, TypeName::Int32); - let input_node2 = make_input_ref(1, TypeName::Int32); - let input_node3 = make_input_ref(2, TypeName::Int32); - - let nullif_expr = CoalesceExpression::build_for_test(&make_coalesce_function( - vec![input_node1, input_node2, input_node3], - TypeName::Int32, - )) - .unwrap(); - - let row_inputs = vec![ - vec![Some(1), None, None, None], - vec![None, Some(2), None, None], - vec![None, None, Some(3), None], - vec![None, None, None, None], - ]; - - let expected = vec![ - Some(ScalarImpl::Int32(1)), - Some(ScalarImpl::Int32(2)), - Some(ScalarImpl::Int32(3)), - None, - ]; - - for (i, row_input) in row_inputs.iter().enumerate() { - let datum_vec = row_input - .iter() - .map(|o| o.map(|int| int.to_scalar_value())) - .collect(); - let row = OwnedRow::new(datum_vec); - - let result = nullif_expr.eval_row(&row).await.unwrap(); - assert_eq!(result, expected[i]); - } - } -} diff --git a/src/expr/core/src/expr/expr_field.rs b/src/expr/core/src/expr/expr_field.rs deleted file mode 100644 index adddc37633e76..0000000000000 --- a/src/expr/core/src/expr/expr_field.rs +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use anyhow::{anyhow, Context}; -use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum, ScalarImpl}; -use risingwave_common::util::value_encoding::DatumFromProtoExt; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use super::Build; -use crate::expr::{BoxedExpression, Expression}; -use crate::{bail, ensure, Result}; - -/// `FieldExpression` access a field from a struct. -#[derive(Debug)] -pub struct FieldExpression { - return_type: DataType, - input: BoxedExpression, - index: usize, -} - -#[async_trait::async_trait] -impl Expression for FieldExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let array = self.input.eval(input).await?; - if let ArrayImpl::Struct(struct_array) = array.as_ref() { - Ok(struct_array.field_at(self.index).clone()) - } else { - Err(anyhow!("expects a struct array ref").into()) - } - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let struct_datum = self.input.eval_row(input).await?; - struct_datum - .map(|s| match s { - ScalarImpl::Struct(v) => Ok(v.fields()[self.index].clone()), - _ => Err(anyhow!("expects a struct array ref").into()), - }) - .transpose() - .map(|x| x.flatten()) - } -} - -impl FieldExpression { - pub fn new(return_type: DataType, input: BoxedExpression, index: usize) -> Self { - FieldExpression { - return_type, - input, - index, - } - } -} - -impl Build for FieldExpression { - fn build( - prost: &ExprNode, - build_child: impl Fn(&ExprNode) -> Result, - ) -> Result { - ensure!(prost.get_function_type().unwrap() == Type::Field); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - - let children = func_call_node.children.to_vec(); - // Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or - // `InputRef`, the second is i32 `Literal`. - let [first, second]: [_; 2] = children.try_into().unwrap(); - let input = build_child(&first)?; - let RexNode::Constant(value) = second.get_rex_node().unwrap() else { - bail!("Expected Constant as 1st argument"); - }; - let index = Datum::from_protobuf(value, &DataType::Int32) - .context("Failed to deserialize i32")? - .unwrap() - .as_int32() - .to_owned(); - - Ok(FieldExpression::new(ret_type, input, index as usize)) - } -} - -#[cfg(test)] -mod tests { - - use risingwave_common::array::{Array, DataChunk, F32Array, I32Array, StructArray}; - use risingwave_common::types::{DataType, ScalarImpl, StructType}; - use risingwave_pb::data::data_type::TypeName; - - use crate::expr::expr_field::FieldExpression; - use crate::expr::test_utils::{make_field_function, make_i32_literal, make_input_ref}; - use crate::expr::{Build, Expression}; - - #[tokio::test] - async fn test_field_expr() { - let input_node = make_input_ref(0, TypeName::Struct); - let literal_node = make_i32_literal(0); - let field_expr = FieldExpression::build_for_test(&make_field_function( - vec![input_node, literal_node], - TypeName::Int32, - )) - .unwrap(); - let array = StructArray::new( - StructType::unnamed(vec![DataType::Int32, DataType::Float32]), - vec![ - I32Array::from_iter([1, 2, 3, 4, 5]).into_ref(), - F32Array::from_iter([2.0, 2.0, 2.0, 2.0, 2.0]).into_ref(), - ], - [true].into_iter().collect(), - ); - - let data_chunk = DataChunk::new(vec![array.into_ref()], 1); - let res = field_expr.eval(&data_chunk).await.unwrap(); - assert_eq!(res.datum_at(0), Some(ScalarImpl::Int32(1))); - assert_eq!(res.datum_at(1), Some(ScalarImpl::Int32(2))); - assert_eq!(res.datum_at(2), Some(ScalarImpl::Int32(3))); - assert_eq!(res.datum_at(3), Some(ScalarImpl::Int32(4))); - assert_eq!(res.datum_at(4), Some(ScalarImpl::Int32(5))); - } - - #[tokio::test] - async fn test_nested_field_expr() { - let field_node = make_field_function( - vec![make_input_ref(0, TypeName::Struct), make_i32_literal(0)], - TypeName::Int32, - ); - let field_expr = FieldExpression::build_for_test(&make_field_function( - vec![field_node, make_i32_literal(1)], - TypeName::Int32, - )) - .unwrap(); - - let struct_array = StructArray::new( - StructType::unnamed(vec![DataType::Int32, DataType::Float32]), - vec![ - I32Array::from_iter([1, 2, 3, 4, 5]).into_ref(), - F32Array::from_iter([1.0, 2.0, 3.0, 4.0, 5.0]).into_ref(), - ], - [true].into_iter().collect(), - ); - let array = StructArray::new( - StructType::unnamed(vec![DataType::Int32, DataType::Float32]), - vec![ - struct_array.into_ref(), - F32Array::from_iter([2.0, 2.0, 2.0, 2.0, 2.0]).into_ref(), - ], - [true].into_iter().collect(), - ); - - let data_chunk = DataChunk::new(vec![array.into_ref()], 1); - let res = field_expr.eval(&data_chunk).await.unwrap(); - assert_eq!(res.datum_at(0), Some(ScalarImpl::Float32(1.0.into()))); - assert_eq!(res.datum_at(1), Some(ScalarImpl::Float32(2.0.into()))); - assert_eq!(res.datum_at(2), Some(ScalarImpl::Float32(3.0.into()))); - assert_eq!(res.datum_at(3), Some(ScalarImpl::Float32(4.0.into()))); - assert_eq!(res.datum_at(4), Some(ScalarImpl::Float32(5.0.into()))); - } -} diff --git a/src/expr/core/src/expr/expr_in.rs b/src/expr/core/src/expr/expr_in.rs deleted file mode 100644 index cbc5cd244b528..0000000000000 --- a/src/expr/core/src/expr/expr_in.rs +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use futures_util::future::FutureExt; -use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChunk}; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; -use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_common::{bail, ensure}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use super::Build; -use crate::expr::{BoxedExpression, Expression}; -use crate::Result; - -#[derive(Debug)] -pub struct InExpression { - left: BoxedExpression, - set: HashSet, - return_type: DataType, -} - -impl InExpression { - pub fn new( - left: BoxedExpression, - data: impl Iterator, - return_type: DataType, - ) -> Self { - let mut sarg = HashSet::new(); - for datum in data { - sarg.insert(datum); - } - Self { - left, - set: sarg, - return_type, - } - } - - // Returns true if datum exists in set, null if datum is null or datum does not exist in set - // but null does, and false if neither datum nor null exists in set. - fn exists(&self, datum: &Datum) -> Option { - if datum.is_none() { - None - } else if self.set.contains(datum) { - Some(true) - } else if self.set.contains(&None) { - None - } else { - Some(false) - } - } -} - -#[async_trait::async_trait] -impl Expression for InExpression { - fn return_type(&self) -> DataType { - self.return_type.clone() - } - - async fn eval(&self, input: &DataChunk) -> Result { - let input_array = self.left.eval(input).await?; - let mut output_array = BoolArrayBuilder::new(input_array.len()); - for (data, vis) in input_array.iter().zip_eq_fast(input.visibility().iter()) { - if vis { - let ret = self.exists(&data.to_owned_datum()); - output_array.append(ret); - } else { - output_array.append(None); - } - } - Ok(Arc::new(output_array.finish().into())) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - let data = self.left.eval_row(input).await?; - let ret = self.exists(&data); - Ok(ret.map(|b| b.to_scalar_value())) - } -} - -impl Build for InExpression { - fn build( - prost: &ExprNode, - build_child: impl Fn(&ExprNode) -> Result, - ) -> Result { - ensure!(prost.get_function_type().unwrap() == Type::In); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - let children = &func_call_node.children; - - let left_expr = build_child(&children[0])?; - let mut data = Vec::new(); - // Used for const expression below to generate datum. - // Frontend has made sure these can all be folded to constants. - let data_chunk = DataChunk::new_dummy(1); - for child in &children[1..] { - let const_expr = build_child(child)?; - let array = const_expr - .eval(&data_chunk) - .now_or_never() - .expect("constant expression should not be async")?; - let datum = array.value_at(0).to_owned_datum(); - data.push(datum); - } - Ok(InExpression::new(left_expr, data.into_iter(), ret_type)) - } -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::DataChunk; - use risingwave_common::row::OwnedRow; - use risingwave_common::test_prelude::DataChunkTestExt; - use risingwave_common::types::{DataType, Datum, ScalarImpl}; - use risingwave_common::util::value_encoding::DatumToProtoExt; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::expr_node::{RexNode, Type}; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use crate::expr::expr_in::InExpression; - use crate::expr::{Build, Expression, InputRefExpression}; - - #[test] - fn test_in_expr() { - let input_ref_expr_node = ExprNode { - function_type: Type::Unspecified as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::InputRef(0)), - }; - let constant_values = vec![ - ExprNode { - function_type: Type::Unspecified as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(Datum::Some("ABC".into()).to_protobuf())), - }, - ExprNode { - function_type: Type::Unspecified as i32, - return_type: Some(PbDataType { - type_name: TypeName::Varchar as i32, - ..Default::default() - }), - rex_node: Some(RexNode::Constant(Datum::Some("def".into()).to_protobuf())), - }, - ]; - let mut in_children = vec![input_ref_expr_node]; - in_children.extend(constant_values); - let call = FunctionCall { - children: in_children, - }; - let p = ExprNode { - function_type: Type::In as i32, - return_type: Some(PbDataType { - type_name: TypeName::Boolean as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(call)), - }; - assert!(InExpression::build_for_test(&p).is_ok()); - } - - #[tokio::test] - async fn test_eval_search_expr() { - let input_refs = [ - Box::new(InputRefExpression::new(DataType::Varchar, 0)), - Box::new(InputRefExpression::new(DataType::Varchar, 0)), - ]; - let data = [ - vec![ - Some(ScalarImpl::Utf8("abc".into())), - Some(ScalarImpl::Utf8("def".into())), - ], - vec![None, Some(ScalarImpl::Utf8("abc".into()))], - ]; - - let data_chunks = [ - DataChunk::from_pretty( - "T - abc - a - def - abc - .", - ) - .with_invisible_holes(), - DataChunk::from_pretty( - "T - abc - a - .", - ) - .with_invisible_holes(), - ]; - - let expected = vec![ - vec![Some(true), Some(false), Some(true), Some(true), None], - vec![Some(true), None, None], - ]; - - for (i, input_ref) in input_refs.into_iter().enumerate() { - let search_expr = - InExpression::new(input_ref, data[i].clone().into_iter(), DataType::Boolean); - let vis = data_chunks[i].visibility(); - let res = search_expr - .eval(&data_chunks[i]) - .await - .unwrap() - .compact(vis, expected[i].len()); - - for (i, expect) in expected[i].iter().enumerate() { - assert_eq!(res.datum_at(i), expect.map(ScalarImpl::Bool)); - } - } - } - - #[tokio::test] - async fn test_eval_row_search_expr() { - let input_refs = [ - Box::new(InputRefExpression::new(DataType::Varchar, 0)), - Box::new(InputRefExpression::new(DataType::Varchar, 0)), - ]; - - let data = [ - vec![ - Some(ScalarImpl::Utf8("abc".into())), - Some(ScalarImpl::Utf8("def".into())), - ], - vec![None, Some(ScalarImpl::Utf8("abc".into()))], - ]; - - let row_inputs = vec![ - vec![Some("abc"), Some("a"), Some("def"), None], - vec![Some("abc"), Some("a"), None], - ]; - - let expected = [ - vec![Some(true), Some(false), Some(true), None], - vec![Some(true), None, None], - ]; - - for (i, input_ref) in input_refs.into_iter().enumerate() { - let search_expr = - InExpression::new(input_ref, data[i].clone().into_iter(), DataType::Boolean); - - for (j, row_input) in row_inputs[i].iter().enumerate() { - let row_input = vec![row_input.map(|s| s.into())]; - let row = OwnedRow::new(row_input); - let result = search_expr.eval_row(&row).await.unwrap(); - assert_eq!(result, expected[i][j].map(ScalarImpl::Bool)); - } - } - } -} diff --git a/src/expr/core/src/expr/expr_input_ref.rs b/src/expr/core/src/expr/expr_input_ref.rs index 8e8ac4364ba64..377bcf958bfe9 100644 --- a/src/expr/core/src/expr/expr_input_ref.rs +++ b/src/expr/core/src/expr/expr_input_ref.rs @@ -44,6 +44,10 @@ impl Expression for InputRefExpression { let cell = input.index(self.idx).as_ref().cloned(); Ok(cell) } + + fn input_ref_index(&self) -> Option { + Some(self.idx) + } } impl InputRefExpression { diff --git a/src/expr/core/src/expr/expr_vnode.rs b/src/expr/core/src/expr/expr_vnode.rs deleted file mode 100644 index 200c1a2f03fa9..0000000000000 --- a/src/expr/core/src/expr/expr_vnode.rs +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2023 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use risingwave_common::array::{ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I16ArrayBuilder}; -use risingwave_common::hash::VirtualNode; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum}; -use risingwave_pb::expr::expr_node::{RexNode, Type}; -use risingwave_pb::expr::ExprNode; - -use super::{BoxedExpression, Build, Expression}; -use crate::expr::InputRefExpression; -use crate::{bail, ensure, Result}; - -#[derive(Debug)] -pub struct VnodeExpression { - dist_key_indices: Vec, -} - -impl VnodeExpression { - pub fn new(dist_key_indices: Vec) -> Self { - VnodeExpression { dist_key_indices } - } -} - -impl Build for VnodeExpression { - fn build( - prost: &ExprNode, - _build_child: impl Fn(&ExprNode) -> Result, - ) -> Result { - ensure!(prost.get_function_type().unwrap() == Type::Vnode); - ensure!(DataType::from(prost.get_return_type().unwrap()) == DataType::Int16); - - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - - let dist_key_input_refs = func_call_node - .get_children() - .iter() - .map(InputRefExpression::from_prost) - .map(|input| input.index()) - .collect(); - - Ok(VnodeExpression::new(dist_key_input_refs)) - } -} - -#[async_trait::async_trait] -impl Expression for VnodeExpression { - fn return_type(&self) -> DataType { - DataType::Int16 - } - - async fn eval(&self, input: &DataChunk) -> Result { - let vnodes = VirtualNode::compute_chunk(input, &self.dist_key_indices); - let mut builder = I16ArrayBuilder::new(input.capacity()); - vnodes - .into_iter() - .for_each(|vnode| builder.append(Some(vnode.to_scalar()))); - Ok(Arc::new(ArrayImpl::from(builder.finish()))) - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - Ok(Some( - VirtualNode::compute_row(input, &self.dist_key_indices) - .to_scalar() - .into(), - )) - } -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::{DataChunk, DataChunkTestExt}; - use risingwave_common::hash::VirtualNode; - use risingwave_common::row::Row; - use risingwave_pb::data::data_type::TypeName; - use risingwave_pb::data::PbDataType; - use risingwave_pb::expr::expr_node::RexNode; - use risingwave_pb::expr::expr_node::Type::Vnode; - use risingwave_pb::expr::{ExprNode, FunctionCall}; - - use super::VnodeExpression; - use crate::expr::test_utils::make_input_ref; - use crate::expr::{Build, Expression}; - - pub fn make_vnode_function(children: Vec) -> ExprNode { - ExprNode { - function_type: Vnode as i32, - return_type: Some(PbDataType { - type_name: TypeName::Int16 as i32, - ..Default::default() - }), - rex_node: Some(RexNode::FuncCall(FunctionCall { children })), - } - } - - #[tokio::test] - async fn test_vnode_expr_eval() { - let input_node1 = make_input_ref(0, TypeName::Int32); - let input_node2 = make_input_ref(0, TypeName::Int64); - let input_node3 = make_input_ref(0, TypeName::Varchar); - let vnode_expr = VnodeExpression::build_for_test(&make_vnode_function(vec![ - input_node1, - input_node2, - input_node3, - ])) - .unwrap(); - let chunk = DataChunk::from_pretty( - "i I T - 1 10 abc - 2 32 def - 3 88 ghi", - ); - let actual = vnode_expr.eval(&chunk).await.unwrap(); - actual.iter().for_each(|vnode| { - let vnode = vnode.unwrap().into_int16(); - assert!(vnode >= 0); - assert!((vnode as usize) < VirtualNode::COUNT); - }); - } - - #[tokio::test] - async fn test_vnode_expr_eval_row() { - let input_node1 = make_input_ref(0, TypeName::Int32); - let input_node2 = make_input_ref(0, TypeName::Int64); - let input_node3 = make_input_ref(0, TypeName::Varchar); - let vnode_expr = VnodeExpression::build_for_test(&make_vnode_function(vec![ - input_node1, - input_node2, - input_node3, - ])) - .unwrap(); - let chunk = DataChunk::from_pretty( - "i I T - 1 10 abc - 2 32 def - 3 88 ghi", - ); - let rows: Vec<_> = chunk.rows().map(|row| row.into_owned_row()).collect(); - for row in rows { - let actual = vnode_expr.eval_row(&row).await.unwrap(); - let vnode = actual.unwrap().into_int16(); - assert!(vnode >= 0); - assert!((vnode as usize) < VirtualNode::COUNT); - } - } -} diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index efbba9e668469..e393fd9170564 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -33,15 +33,10 @@ // These modules define concrete expression structures. mod and_or; -mod expr_array_transform; -mod expr_coalesce; -mod expr_field; -mod expr_in; mod expr_input_ref; mod expr_literal; mod expr_some_all; pub(crate) mod expr_udf; -mod expr_vnode; pub(crate) mod wrapper; mod build; @@ -101,6 +96,11 @@ pub trait Expression: std::fmt::Debug + Sync + Send { fn eval_const(&self) -> Result { Err(ExprError::NotConstant) } + + /// Get the index if the expression is an `InputRef`. + fn input_ref_index(&self) -> Option { + None + } } /// An owned dynamically typed [`Expression`]. diff --git a/src/expr/core/src/expr/wrapper/checked.rs b/src/expr/core/src/expr/wrapper/checked.rs index b3b1375c4fa82..920a1f0e42cad 100644 --- a/src/expr/core/src/expr/wrapper/checked.rs +++ b/src/expr/core/src/expr/wrapper/checked.rs @@ -50,4 +50,8 @@ impl Expression for Checked { fn eval_const(&self) -> Result { self.0.eval_const() } + + fn input_ref_index(&self) -> Option { + self.0.input_ref_index() + } } diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index 1c6795fd86eb4..f28e74eed14e2 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -141,4 +141,8 @@ where fn eval_const(&self) -> Result { self.inner.eval_const() // do not handle error } + + fn input_ref_index(&self) -> Option { + self.inner.input_ref_index() + } } diff --git a/src/expr/core/src/expr/expr_array_transform.rs b/src/expr/impl/src/scalar/array_transform.rs similarity index 59% rename from src/expr/core/src/expr/expr_array_transform.rs rename to src/expr/impl/src/scalar/array_transform.rs index 19b678db87559..3f0ea8f276450 100644 --- a/src/expr/core/src/expr/expr_array_transform.rs +++ b/src/expr/impl/src/scalar/array_transform.rs @@ -18,14 +18,13 @@ use async_trait::async_trait; use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum, ListValue, ScalarImpl}; - -use super::{BoxedExpression, Expression}; -use crate::Result; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; #[derive(Debug)] -pub struct ArrayTransformExpression { - pub(super) array: BoxedExpression, - pub(super) lambda: BoxedExpression, +struct ArrayTransformExpression { + array: BoxedExpression, + lambda: BoxedExpression, } #[async_trait] @@ -61,3 +60,39 @@ impl Expression for ArrayTransformExpression { } } } + +#[build_function("array_transform(anyarray, any) -> anyarray")] +fn build(_: DataType, children: Vec) -> Result { + let [array, lambda] = <[BoxedExpression; 2]>::try_from(children).unwrap(); + Ok(Box::new(ArrayTransformExpression { array, lambda })) +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::{DataChunk, DataChunkTestExt}; + use risingwave_common::row::Row; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_array_transform() { + let expr = + build_from_pretty("(array_transform:int4[] $0:int4[] (multiply:int4 $0:int4 2:int4))"); + let (input, expected) = DataChunk::from_pretty( + "i[] i[] + {1,2,3} {2,4,6}", + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/coalesce.rs b/src/expr/impl/src/scalar/coalesce.rs new file mode 100644 index 0000000000000..8e7fb925a03ff --- /dev/null +++ b/src/expr/impl/src/scalar/coalesce.rs @@ -0,0 +1,116 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::BitAnd; +use std::sync::Arc; + +use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, Datum}; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; + +#[derive(Debug)] +pub struct CoalesceExpression { + return_type: DataType, + children: Vec, +} + +#[async_trait::async_trait] +impl Expression for CoalesceExpression { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn eval(&self, input: &DataChunk) -> Result { + let init_vis = input.visibility(); + let mut input = input.clone(); + let len = input.capacity(); + let mut selection: Vec> = vec![None; len]; + let mut children_array = Vec::with_capacity(self.children.len()); + for (child_idx, child) in self.children.iter().enumerate() { + let res = child.eval(&input).await?; + let res_bitmap = res.null_bitmap(); + let orig_vis = input.visibility(); + for pos in orig_vis.bitand(res_bitmap).iter_ones() { + selection[pos] = Some(child_idx); + } + let new_vis = orig_vis & !res_bitmap; + input.set_visibility(new_vis); + children_array.push(res); + } + let mut builder = self.return_type.create_array_builder(len); + for (i, sel) in selection.iter().enumerate() { + if init_vis.is_set(i) + && let Some(child_idx) = sel + { + builder.append(children_array[*child_idx].value_at(i)); + } else { + builder.append_null() + } + } + Ok(Arc::new(builder.finish())) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + for child in &self.children { + let datum = child.eval_row(input).await?; + if datum.is_some() { + return Ok(datum); + } + } + Ok(None) + } +} + +#[build_function("coalesce(...) -> any", type_infer = "panic")] +fn build(return_type: DataType, children: Vec) -> Result { + Ok(Box::new(CoalesceExpression { + return_type, + children, + })) +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_coalesce_expr() { + let expr = build_from_pretty("(coalesce:int4 $0:int4 $1:int4 $2:int4)"); + let (input, expected) = DataChunk::from_pretty( + "i i i i + 1 . . 1 + . 2 . 2 + . . 3 3 + . . . .", + ) + .split_column_at(3); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/field.rs b/src/expr/impl/src/scalar/field.rs new file mode 100644 index 0000000000000..a15b8f59664ff --- /dev/null +++ b/src/expr/impl/src/scalar/field.rs @@ -0,0 +1,99 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::anyhow; +use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, Datum, ScalarImpl}; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; + +/// `FieldExpression` access a field from a struct. +#[derive(Debug)] +pub struct FieldExpression { + return_type: DataType, + input: BoxedExpression, + index: usize, +} + +#[async_trait::async_trait] +impl Expression for FieldExpression { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn eval(&self, input: &DataChunk) -> Result { + let array = self.input.eval(input).await?; + if let ArrayImpl::Struct(struct_array) = array.as_ref() { + Ok(struct_array.field_at(self.index).clone()) + } else { + Err(anyhow!("expects a struct array ref").into()) + } + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + let struct_datum = self.input.eval_row(input).await?; + struct_datum + .map(|s| match s { + ScalarImpl::Struct(v) => Ok(v.fields()[self.index].clone()), + _ => Err(anyhow!("expects a struct array ref").into()), + }) + .transpose() + .map(|x| x.flatten()) + } +} + +#[build_function("field(struct, int4) -> any", type_infer = "panic")] +fn build(return_type: DataType, children: Vec) -> Result { + // Field `func_call_node` have 2 child nodes, the first is Field `FuncCall` or + // `InputRef`, the second is i32 `Literal`. + let [input, index]: [_; 2] = children.try_into().unwrap(); + let index = index.eval_const()?.unwrap().into_int32() as usize; + Ok(Box::new(FieldExpression { + return_type, + input, + index, + })) +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::{DataChunk, DataChunkTestExt}; + use risingwave_common::row::Row; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_field_expr() { + let expr = build_from_pretty("(field:int4 $0:struct 0:int4)"); + let (input, expected) = DataChunk::from_pretty( + " i + (1,2.0) 1 + (2,2.0) 2 + (3,2.0) 3", + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/in_.rs b/src/expr/impl/src/scalar/in_.rs new file mode 100644 index 0000000000000..be037ecbd9b6e --- /dev/null +++ b/src/expr/impl/src/scalar/in_.rs @@ -0,0 +1,165 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; +use std::fmt::Debug; +use std::sync::Arc; + +use futures_util::FutureExt; +use risingwave_common::array::{ArrayBuilder, ArrayRef, BoolArrayBuilder, DataChunk}; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, Datum, Scalar, ToOwnedDatum}; +use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; + +#[derive(Debug)] +pub struct InExpression { + left: BoxedExpression, + set: HashSet, + return_type: DataType, +} + +impl InExpression { + pub fn new( + left: BoxedExpression, + data: impl Iterator, + return_type: DataType, + ) -> Self { + Self { + left, + set: data.collect(), + return_type, + } + } + + // Returns true if datum exists in set, null if datum is null or datum does not exist in set + // but null does, and false if neither datum nor null exists in set. + fn exists(&self, datum: &Datum) -> Option { + if datum.is_none() { + None + } else if self.set.contains(datum) { + Some(true) + } else if self.set.contains(&None) { + None + } else { + Some(false) + } + } +} + +#[async_trait::async_trait] +impl Expression for InExpression { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn eval(&self, input: &DataChunk) -> Result { + let input_array = self.left.eval(input).await?; + let mut output_array = BoolArrayBuilder::new(input_array.len()); + for (data, vis) in input_array.iter().zip_eq_fast(input.visibility().iter()) { + if vis { + // TODO: avoid `to_owned_datum()` + let ret = self.exists(&data.to_owned_datum()); + output_array.append(ret); + } else { + output_array.append(None); + } + } + Ok(Arc::new(output_array.finish().into())) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + let data = self.left.eval_row(input).await?; + let ret = self.exists(&data); + Ok(ret.map(|b| b.to_scalar_value())) + } +} + +#[build_function("in(any, ...) -> boolean")] +fn build(return_type: DataType, children: Vec) -> Result { + let mut iter = children.into_iter(); + let left_expr = iter.next().unwrap(); + let mut data = Vec::with_capacity(iter.size_hint().0); + let data_chunk = DataChunk::new_dummy(1); + for child in iter { + let array = child + .eval(&data_chunk) + .now_or_never() + .expect("constant expression should not be async")?; + let datum = array.value_at(0).to_owned_datum(); + data.push(datum); + } + Ok(Box::new(InExpression::new( + left_expr, + data.into_iter(), + return_type, + ))) +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::{build_from_pretty, Expression}; + + #[tokio::test] + async fn test_in_expr() { + let expr = build_from_pretty("(in:boolean $0:varchar abc:varchar def:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T B + abc t + a f + def t + abc t + . .", + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } + + #[tokio::test] + async fn test_in_expr_null() { + let expr = build_from_pretty("(in:boolean $0:varchar abc:varchar null:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T B + abc t + a . + . .", + ) + .split_column_at(1); + + // test eval + let output = expr.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/mod.rs b/src/expr/impl/src/scalar/mod.rs index 900e38cdd9cec..1d1f3e4bc1500 100644 --- a/src/expr/impl/src/scalar/mod.rs +++ b/src/expr/impl/src/scalar/mod.rs @@ -27,12 +27,14 @@ mod array_replace; mod array_sort; mod array_sum; mod array_to_string; +mod array_transform; mod ascii; mod bitwise_op; mod cardinality; mod case; mod cast; mod cmp; +mod coalesce; mod concat_op; mod concat_ws; mod conjunction; @@ -41,8 +43,10 @@ mod delay; mod encdec; mod exp; mod extract; +mod field; mod format; mod format_type; +mod in_; mod int256; mod jsonb_access; mod jsonb_build; @@ -71,6 +75,7 @@ mod substr; mod timestamptz; mod to_char; mod to_jsonb; +mod vnode; pub use to_jsonb::*; mod to_timestamp; mod translate; diff --git a/src/expr/impl/src/scalar/vnode.rs b/src/expr/impl/src/scalar/vnode.rs new file mode 100644 index 0000000000000..cedad2d7a05aa --- /dev/null +++ b/src/expr/impl/src/scalar/vnode.rs @@ -0,0 +1,94 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use risingwave_common::array::{ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I16ArrayBuilder}; +use risingwave_common::hash::VirtualNode; +use risingwave_common::row::OwnedRow; +use risingwave_common::types::{DataType, Datum}; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; + +#[derive(Debug)] +struct VnodeExpression { + dist_key_indices: Vec, +} + +#[build_function("vnode(...) -> int2")] +fn build(_: DataType, children: Vec) -> Result { + let dist_key_indices = children + .into_iter() + .map(|child| child.input_ref_index().unwrap()) + .collect(); + + Ok(Box::new(VnodeExpression { dist_key_indices })) +} + +#[async_trait::async_trait] +impl Expression for VnodeExpression { + fn return_type(&self) -> DataType { + DataType::Int16 + } + + async fn eval(&self, input: &DataChunk) -> Result { + let vnodes = VirtualNode::compute_chunk(input, &self.dist_key_indices); + let mut builder = I16ArrayBuilder::new(input.capacity()); + vnodes + .into_iter() + .for_each(|vnode| builder.append(Some(vnode.to_scalar()))); + Ok(Arc::new(ArrayImpl::from(builder.finish()))) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + Ok(Some( + VirtualNode::compute_row(input, &self.dist_key_indices) + .to_scalar() + .into(), + )) + } +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::{DataChunk, DataChunkTestExt}; + use risingwave_common::hash::VirtualNode; + use risingwave_common::row::Row; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_vnode_expr_eval() { + let expr = build_from_pretty("(vnode:int2 $0:int4 $0:int8 $0:varchar)"); + let input = DataChunk::from_pretty( + "i I T + 1 10 abc + 2 32 def + 3 88 ghi", + ); + + // test eval + let output = expr.eval(&input).await.unwrap(); + for vnode in output.iter() { + let vnode = vnode.unwrap().into_int16(); + assert!((0..VirtualNode::COUNT as i16).contains(&vnode)); + } + + // test eval_row + for row in input.rows() { + let result = expr.eval_row(&row.to_owned_row()).await.unwrap(); + let vnode = result.unwrap().into_int16(); + assert!((0..VirtualNode::COUNT as i16).contains(&vnode)); + } + } +} diff --git a/src/udf/Cargo.toml b/src/expr/udf/Cargo.toml similarity index 100% rename from src/udf/Cargo.toml rename to src/expr/udf/Cargo.toml diff --git a/src/udf/examples/client.rs b/src/expr/udf/examples/client.rs similarity index 100% rename from src/udf/examples/client.rs rename to src/expr/udf/examples/client.rs diff --git a/src/udf/python/.gitignore b/src/expr/udf/python/.gitignore similarity index 100% rename from src/udf/python/.gitignore rename to src/expr/udf/python/.gitignore diff --git a/src/udf/python/CHANGELOG.md b/src/expr/udf/python/CHANGELOG.md similarity index 100% rename from src/udf/python/CHANGELOG.md rename to src/expr/udf/python/CHANGELOG.md diff --git a/src/udf/python/README.md b/src/expr/udf/python/README.md similarity index 100% rename from src/udf/python/README.md rename to src/expr/udf/python/README.md diff --git a/src/udf/python/publish.md b/src/expr/udf/python/publish.md similarity index 100% rename from src/udf/python/publish.md rename to src/expr/udf/python/publish.md diff --git a/src/udf/python/pyproject.toml b/src/expr/udf/python/pyproject.toml similarity index 100% rename from src/udf/python/pyproject.toml rename to src/expr/udf/python/pyproject.toml diff --git a/src/udf/python/risingwave/__init__.py b/src/expr/udf/python/risingwave/__init__.py similarity index 100% rename from src/udf/python/risingwave/__init__.py rename to src/expr/udf/python/risingwave/__init__.py diff --git a/src/udf/python/risingwave/test_udf.py b/src/expr/udf/python/risingwave/test_udf.py similarity index 100% rename from src/udf/python/risingwave/test_udf.py rename to src/expr/udf/python/risingwave/test_udf.py diff --git a/src/udf/python/risingwave/udf.py b/src/expr/udf/python/risingwave/udf.py similarity index 100% rename from src/udf/python/risingwave/udf.py rename to src/expr/udf/python/risingwave/udf.py diff --git a/src/udf/python/risingwave/udf/health_check.py b/src/expr/udf/python/risingwave/udf/health_check.py similarity index 100% rename from src/udf/python/risingwave/udf/health_check.py rename to src/expr/udf/python/risingwave/udf/health_check.py diff --git a/src/udf/src/error.rs b/src/expr/udf/src/error.rs similarity index 100% rename from src/udf/src/error.rs rename to src/expr/udf/src/error.rs diff --git a/src/udf/src/external.rs b/src/expr/udf/src/external.rs similarity index 100% rename from src/udf/src/external.rs rename to src/expr/udf/src/external.rs diff --git a/src/udf/src/lib.rs b/src/expr/udf/src/lib.rs similarity index 100% rename from src/udf/src/lib.rs rename to src/expr/udf/src/lib.rs