diff --git a/src/batch/src/executor/aggregation/filter.rs b/src/batch/src/executor/aggregation/filter.rs index 2db2320ed3534..9cfbeabffe417 100644 --- a/src/batch/src/executor/aggregation/filter.rs +++ b/src/batch/src/executor/aggregation/filter.rs @@ -75,7 +75,7 @@ impl AggregateFunction for Filter { mod tests { use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_expr::aggregate::{build_append_only, AggCall}; - use risingwave_expr::expr::{build_from_pretty, Expression, LiteralExpression}; + use risingwave_expr::expr::{build_from_pretty, ExpressionBoxExt, LiteralExpression}; use super::*; diff --git a/src/batch/src/executor/project_set.rs b/src/batch/src/executor/project_set.rs index 670933a6bb50c..fa3dfac917e8a 100644 --- a/src/batch/src/executor/project_set.rs +++ b/src/batch/src/executor/project_set.rs @@ -171,7 +171,7 @@ mod tests { use risingwave_common::catalog::{Field, Schema}; use risingwave_common::test_prelude::*; use risingwave_common::types::DataType; - use risingwave_expr::expr::{Expression, InputRefExpression, LiteralExpression}; + use risingwave_expr::expr::{ExpressionBoxExt, InputRefExpression, LiteralExpression}; use risingwave_expr::table_function::repeat; use super::*; diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index a11883e9db781..2748b575f76ae 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -29,7 +29,9 @@ use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; use super::wrapper::{Checked, EvalErrorReport, NonStrict}; use super::InfallibleExpression; -use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression}; +use crate::expr::{ + BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression, +}; use crate::sig::FUNCTION_REGISTRY; use crate::{bail, ExprError, Result}; @@ -156,7 +158,7 @@ impl BuildBoxed for E { prost: &ExprNode, build_child: impl Fn(&ExprNode) -> Result, ) -> Result { - Self::build(prost, build_child).map(Expression::boxed) + Self::build(prost, build_child).map(ExpressionBoxExt::boxed) } } diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 04ab40fca1e86..78955e7e871ab 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -67,6 +67,7 @@ pub use super::{ExprError, Result}; /// should be implemented. Prefer calling and implementing `eval_v2` instead of `eval` if possible, /// to gain the performance benefit of scalar expression. #[async_trait::async_trait] +#[auto_impl::auto_impl(&, Box)] pub trait Expression: std::fmt::Debug + Sync + Send { /// Get the return data type. fn return_type(&self) -> DataType; @@ -101,78 +102,38 @@ pub trait Expression: std::fmt::Debug + Sync + Send { fn eval_const(&self) -> Result { Err(ExprError::NotConstant) } - - /// Wrap the expression in a Box. - fn boxed(self) -> BoxedExpression - where - Self: Sized + Send + 'static, - { - Box::new(self) - } -} - -// TODO: make this an extension, or implement it on a `NonStrict` newtype. -impl dyn Expression { - /// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array. - /// - /// Use with expressions built in non-strict mode. - pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef { - self.eval(input).await.expect("evaluation failed") - } - - /// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable - /// scalar. - /// - /// Use with expressions built in non-strict mode. - pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum { - self.eval_row(input).await.expect("evaluation failed") - } } /// An owned dynamically typed [`Expression`]. pub type BoxedExpression = Box; -// TODO: avoid the overhead of extra boxing. -#[async_trait::async_trait] -impl Expression for BoxedExpression { - fn return_type(&self) -> DataType { - (**self).return_type() - } - - async fn eval(&self, input: &DataChunk) -> Result { - (**self).eval(input).await - } - - async fn eval_v2(&self, input: &DataChunk) -> Result { - (**self).eval_v2(input).await - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - (**self).eval_row(input).await - } - - fn eval_const(&self) -> Result { - (**self).eval_const() - } - - fn boxed(self) -> BoxedExpression { - self +/// Extension trait for boxing expressions. +#[easy_ext::ext(ExpressionBoxExt)] +impl E { + /// Wrap the expression in a Box. + pub fn boxed(self) -> BoxedExpression { + Box::new(self) } } #[derive(Debug)] pub struct InfallibleExpression(E); -impl InfallibleExpression { - pub fn for_test(inner: impl Expression + 'static) -> Self { - Self(inner.boxed()) - } -} - impl InfallibleExpression where E: Expression, { + pub fn for_test(inner: E) -> InfallibleExpression + where + E: 'static, + { + InfallibleExpression(inner.boxed()) + } + + pub fn todo(inner: E) -> Self { + Self(inner) + } + /// Get the return data type. pub fn return_type(&self) -> DataType { self.0.return_type() diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index 889cc43fe6b18..c173c76c330c5 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -22,7 +22,9 @@ use risingwave_common::cast; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Int256, IntoOrdered, JsonbRef, ToText, F64}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::{build_func, Context, Expression, InputRefExpression}; +use risingwave_expr::expr::{ + build_func, Context, Expression, ExpressionBoxExt, InputRefExpression, +}; use risingwave_expr::{function, ExprError, Result}; use risingwave_pb::expr::expr_node::PbType; diff --git a/src/expr/impl/src/table_function/generate_series.rs b/src/expr/impl/src/table_function/generate_series.rs index 586fa60de02c2..dfa09b0e215b8 100644 --- a/src/expr/impl/src/table_function/generate_series.rs +++ b/src/expr/impl/src/table_function/generate_series.rs @@ -159,7 +159,7 @@ mod tests { use risingwave_common::array::DataChunk; use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp}; - use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression}; + use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression}; use risingwave_expr::table_function::build; use risingwave_expr::ExprError; use risingwave_pb::expr::table_function::PbType; diff --git a/src/storage/src/row_serde/value_serde.rs b/src/storage/src/row_serde/value_serde.rs index 5d56cdba2d96d..9048b90c23a53 100644 --- a/src/storage/src/row_serde/value_serde.rs +++ b/src/storage/src/row_serde/value_serde.rs @@ -114,9 +114,10 @@ impl ValueRowSerdeNew for ColumnAwareSerde { // It's okay since we previously banned impure expressions in default columns. build_from_prost(&expr.expect("expr should not be none")) .expect("build_from_prost error") - .eval_row_infallible(&OwnedRow::empty()) + .eval_row(&OwnedRow::empty()) .now_or_never() .expect("constant expression should not be async") + .expect("eval_row failed") }; Some((i, value)) } else { diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index dd0ce9d01c544..b4d69152c2911 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -21,6 +21,7 @@ use risingwave_common::bail; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_expr::aggregate::{AggCall, AggKind}; +use risingwave_expr::expr::InfallibleExpression; use risingwave_storage::StateStore; use crate::common::table::state_table::StateTable; @@ -74,7 +75,11 @@ pub async fn agg_call_filter_res( } if let Some(ref filter) = agg_call.filter { - if let Bool(filter_res) = filter.eval_infallible(chunk).await.as_ref() { + if let Bool(filter_res) = InfallibleExpression::todo(&**filter) + .eval_infallible(chunk) + .await + .as_ref() + { vis &= filter_res.to_bitmap(); } else { bail!("Filter can only receive bool array"); diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 2d3957740c4c4..e3a341ba3a5d0 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair}; use risingwave_common::util::tracing::TracingContext; use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_connector::source::SplitImpl; -use risingwave_expr::expr::{InfallibleExpression, Expression}; +use risingwave_expr::expr::{Expression, InfallibleExpression}; use risingwave_pb::data::PbEpoch; use risingwave_pb::expr::PbInputRef; use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation}; @@ -641,7 +641,7 @@ impl Watermark { pub async fn transform_with_expr( self, - expr: &InfallibleExpression, + expr: &InfallibleExpression, new_col_idx: usize, ) -> Option { let Self { col_idx, val, .. } = self; diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 7e722d8dd363e..8bf94b86b69ea 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -24,11 +24,10 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::{BoxedExpression, InfallibleExpression}; +use risingwave_expr::expr::InfallibleExpression; use risingwave_expr::table_function::ProjectSetSelectItem; use super::error::StreamExecutorError; -use super::test_utils::expr::build_from_pretty; use super::{ ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, PkIndices, PkIndicesRef, StreamExecutorResult, Watermark, @@ -263,7 +262,7 @@ impl Inner { watermark .clone() .transform_with_expr( - &build_from_pretty(""), // TODO + &InfallibleExpression::todo(expr), expr_idx + PROJ_ROW_ID_OFFSET, ) .await diff --git a/src/stream/src/executor/values.rs b/src/stream/src/executor/values.rs index 73dc0f8f08c8c..bb6017ae14410 100644 --- a/src/stream/src/executor/values.rs +++ b/src/stream/src/executor/values.rs @@ -183,7 +183,8 @@ mod tests { let actor_id = progress.actor_id(); let (tx, barrier_receiver) = unbounded_channel(); let value = StructValue::new(vec![Some(1.into()), Some(2.into()), Some(3.into())]); - let exprs = vec![Box::new(LiteralExpression::new( + let exprs = vec![ + Box::new(LiteralExpression::new( DataType::Int16, Some(ScalarImpl::Int16(1)), )) as BoxedExpression, diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index e2ff36b69c782..01a6744addafd 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -23,7 +23,8 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl}; use risingwave_common::{bail, row}; use risingwave_expr::expr::{ - build_func_non_strict, Expression, InfallibleExpression, InputRefExpression, LiteralExpression, + build_func_non_strict, ExpressionBoxExt, InfallibleExpression, InputRefExpression, + LiteralExpression, }; use risingwave_expr::Result as ExprResult; use risingwave_pb::expr::expr_node::Type; diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 32bbfae408893..c3eed40b18391 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -109,7 +109,8 @@ impl ExecutorBuilder for HashJoinExecutorBuilder { build_non_strict_from_prost( delta_expression.delta.as_ref().unwrap(), params.eval_error_report.clone(), - )?.into_inner(), + )? + .into_inner(), ], params.eval_error_report.clone(), )?)