From c8df1eff684c97b4b10f2d8a6b078823431fc62c Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 17 Oct 2023 20:07:28 +0800 Subject: [PATCH 1/8] impl wrapper Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/build.rs | 11 +++-- src/expr/core/src/expr/mod.rs | 44 +++++++++++++++++++ src/stream/src/executor/dynamic_filter.rs | 4 +- src/stream/src/executor/filter.rs | 8 ++-- src/stream/src/executor/hash_join.rs | 20 ++++----- src/stream/src/executor/hop_window.rs | 21 ++++++--- src/stream/src/executor/integration_tests.rs | 2 +- src/stream/src/executor/mod.rs | 6 +-- src/stream/src/executor/project.rs | 11 ++--- src/stream/src/executor/project_set.rs | 7 ++- src/stream/src/executor/temporal_join.rs | 6 +-- src/stream/src/executor/test_utils.rs | 10 ++++- src/stream/src/executor/values.rs | 20 +++++---- src/stream/src/executor/watermark_filter.rs | 10 ++--- src/stream/src/from_proto/hash_join.rs | 8 ++-- src/stream/src/from_proto/temporal_join.rs | 4 +- .../tests/integration_tests/hop_window.rs | 11 ++++- .../tests/integration_tests/project_set.rs | 8 ++-- 18 files changed, 144 insertions(+), 67 deletions(-) diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 1ea03bd36f42..a11883e9db78 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -28,6 +28,7 @@ use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; use super::wrapper::{Checked, EvalErrorReport, NonStrict}; +use super::InfallibleExpression; use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression}; use crate::sig::FUNCTION_REGISTRY; use crate::{bail, ExprError, Result}; @@ -41,8 +42,10 @@ pub fn build_from_prost(prost: &ExprNode) -> Result { pub fn build_non_strict_from_prost( prost: &ExprNode, error_report: impl EvalErrorReport + 'static, -) -> Result { - ExprBuilder::new_non_strict(error_report).build(prost) +) -> Result { + 
ExprBuilder::new_non_strict(error_report) + .build(prost) + .map(InfallibleExpression) } /// Build an expression from protobuf with possibly some wrappers attached to each node. @@ -217,9 +220,9 @@ pub fn build_func_non_strict( ret_type: DataType, children: Vec, error_report: impl EvalErrorReport + 'static, -) -> Result { +) -> Result { let expr = build_func(func, ret_type, children)?; - let wrapped = ExprBuilder::new_non_strict(error_report).wrap(expr); + let wrapped = InfallibleExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); Ok(wrapped) } diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 37e0104371a3..04ab40fca1e8 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -160,6 +160,50 @@ impl Expression for BoxedExpression { } } +#[derive(Debug)] +pub struct InfallibleExpression(E); + +impl InfallibleExpression { + pub fn for_test(inner: impl Expression + 'static) -> Self { + Self(inner.boxed()) + } +} + +impl InfallibleExpression +where + E: Expression, +{ + /// Get the return data type. + pub fn return_type(&self) -> DataType { + self.0.return_type() + } + + /// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array. + /// + /// Use with expressions built in non-strict mode. + pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef { + self.0.eval(input).await.expect("evaluation failed") + } + + /// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable + /// scalar. + /// + /// Use with expressions built in non-strict mode. + pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum { + self.0.eval_row(input).await.expect("evaluation failed") + } + + /// Unwrap the inner expression. + pub fn into_inner(self) -> E { + self.0 + } + + /// Get a reference to the inner expression. + pub fn inner(&self) -> &E { + &self.0 + } +} + /// An optional context that can be used in a function. 
/// /// # Example diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index e8eb4da545f2..b86a2a3d67c1 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -26,7 +26,7 @@ use risingwave_common::row::{self, once, OwnedRow, OwnedRow as RowData, Row}; use risingwave_common::types::{DataType, Datum, DefaultOrd, ScalarImpl, ToDatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::{ - build_func_non_strict, BoxedExpression, InputRefExpression, LiteralExpression, + build_func_non_strict, InfallibleExpression, InputRefExpression, LiteralExpression, }; use risingwave_pb::expr::expr_node::Type as ExprNodeType; use risingwave_pb::expr::expr_node::Type::{ @@ -97,7 +97,7 @@ impl DynamicFilterExecutor, + condition: Option, ) -> Result<(Vec, Bitmap), StreamExecutorError> { let mut new_ops = Vec::with_capacity(chunk.capacity()); let mut new_visibility = BitmapBuilder::with_capacity(chunk.capacity()); diff --git a/src/stream/src/executor/filter.rs b/src/stream/src/executor/filter.rs index ef593f873428..236f203dcbe1 100644 --- a/src/stream/src/executor/filter.rs +++ b/src/stream/src/executor/filter.rs @@ -19,7 +19,7 @@ use risingwave_common::array::{Array, ArrayImpl, Op, StreamChunk}; use risingwave_common::buffer::BitmapBuilder; use risingwave_common::catalog::Schema; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use super::*; @@ -34,14 +34,14 @@ pub struct FilterExecutor { /// Expression of the current filter, note that the filter must always have the same output for /// the same input. 
- expr: BoxedExpression, + expr: InfallibleExpression, } impl FilterExecutor { pub fn new( ctx: ActorContextRef, input: Box, - expr: BoxedExpression, + expr: InfallibleExpression, executor_id: u64, ) -> Self { let input_info = input.info(); @@ -190,8 +190,8 @@ mod tests { use risingwave_common::array::StreamChunk; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; - use risingwave_expr::expr::build_from_pretty; + use super::super::test_utils::expr::build_from_pretty; use super::super::test_utils::MockSource; use super::super::*; use super::*; diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index 7aed840679c8..cca953022369 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -28,7 +28,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ToOwnedDatum}; use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use risingwave_expr::ExprError; use risingwave_storage::StateStore; use tokio::time::Instant; @@ -242,9 +242,9 @@ pub struct HashJoinExecutor, /// Optional non-equi join conditions - cond: Option, + cond: Option, /// Column indices of watermark output and offset expression of each inequality, respectively. - inequality_pairs: Vec<(Vec, Option)>, + inequality_pairs: Vec<(Vec, Option)>, /// The output watermark of each inequality condition and its value is the minimum of the /// calculation result of both side. It will be used to generate watermark into downstream /// and do state cleaning if `clean_state` field of that inequality is `true`. 
@@ -313,7 +313,7 @@ struct EqJoinArgs<'a, K: HashKey, S: StateStore> { side_l: &'a mut JoinSide, side_r: &'a mut JoinSide, actual_output_data_types: &'a [DataType], - cond: &'a mut Option, + cond: &'a mut Option, inequality_watermarks: &'a [Option], chunk: StreamChunk, append_only_optimize: bool, @@ -448,8 +448,8 @@ impl HashJoinExecutor, executor_id: u64, - cond: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + cond: Option, + inequality_pairs: Vec<(usize, usize, bool, Option)>, op_info: String, state_table_l: StateTable, degree_state_table_l: StateTable, @@ -912,7 +912,7 @@ impl HashJoinExecutor input_watermark.val = value.unwrap(), @@ -1275,11 +1275,11 @@ mod tests { use risingwave_common::hash::{Key128, Key64}; use risingwave_common::types::ScalarImpl; use risingwave_common::util::sort_util::OrderType; - use risingwave_expr::expr::build_from_pretty; use risingwave_storage::memory::MemoryStateStore; use super::*; use crate::common::table::state_table::StateTable; + use crate::executor::test_utils::expr::build_from_pretty; use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt}; use crate::executor::{ActorContext, Barrier, EpochPair}; @@ -1327,7 +1327,7 @@ mod tests { (state_table, degree_state_table) } - fn create_cond(condition_text: Option) -> BoxedExpression { + fn create_cond(condition_text: Option) -> InfallibleExpression { build_from_pretty( condition_text .as_deref() @@ -1339,7 +1339,7 @@ mod tests { with_condition: bool, null_safe: bool, condition_text: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + inequality_pairs: Vec<(usize, usize, bool, Option)>, ) -> (MessageSender, MessageSender, BoxedMessageStream) { let schema = Schema { fields: vec![ diff --git a/src/stream/src/executor/hop_window.rs b/src/stream/src/executor/hop_window.rs index c6fffcd94896..7c8c44330c86 100644 --- a/src/stream/src/executor/hop_window.rs +++ b/src/stream/src/executor/hop_window.rs @@ -19,7 +19,7 @@ use 
futures_async_stream::try_stream; use itertools::Itertools; use risingwave_common::array::{DataChunk, Op}; use risingwave_common::types::Interval; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use risingwave_expr::ExprError; use super::error::StreamExecutorError; @@ -33,8 +33,8 @@ pub struct HopWindowExecutor { pub time_col_idx: usize, pub window_slide: Interval, pub window_size: Interval, - window_start_exprs: Vec, - window_end_exprs: Vec, + window_start_exprs: Vec, + window_end_exprs: Vec, pub output_indices: Vec, chunk_size: usize, } @@ -48,8 +48,8 @@ impl HopWindowExecutor { time_col_idx: usize, window_slide: Interval, window_size: Interval, - window_start_exprs: Vec, - window_end_exprs: Vec, + window_start_exprs: Vec, + window_end_exprs: Vec, output_indices: Vec, chunk_size: usize, ) -> Self { @@ -251,6 +251,7 @@ mod tests { use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{DataType, Interval}; use risingwave_expr::expr::test_utils::make_hop_window_expression; + use risingwave_expr::expr::InfallibleExpression; use crate::executor::test_utils::MockSource; use crate::executor::{ActorContext, Executor, ExecutorInfo, StreamChunk}; @@ -302,8 +303,14 @@ mod tests { 2, window_slide, window_size, - window_start_exprs, - window_end_exprs, + window_start_exprs + .into_iter() + .map(InfallibleExpression::for_test) + .collect(), + window_end_exprs + .into_iter() + .map(InfallibleExpression::for_test) + .collect(), output_indices, CHUNK_SIZE, ) diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs index a9c219a25641..1ae92bf3dcd1 100644 --- a/src/stream/src/executor/integration_tests.rs +++ b/src/stream/src/executor/integration_tests.rs @@ -152,7 +152,7 @@ async fn test_merger_sum_aggr() { vec![], vec![ // TODO: use the new streaming_if_null expression here, and add `None` tests - Box::new(InputRefExpression::new(DataType::Int64, 1)), 
+ InfallibleExpression::for_test(InputRefExpression::new(DataType::Int64, 1)), ], 3, MultiMap::new(), diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 99b090e21a24..2d3957740c4c 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair}; use risingwave_common::util::tracing::TracingContext; use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_connector::source::SplitImpl; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::{InfallibleExpression, Expression}; use risingwave_pb::data::PbEpoch; use risingwave_pb::expr::PbInputRef; use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation}; @@ -641,7 +641,7 @@ impl Watermark { pub async fn transform_with_expr( self, - expr: &BoxedExpression, + expr: &InfallibleExpression, new_col_idx: usize, ) -> Option { let Self { col_idx, val, .. } = self; @@ -651,7 +651,7 @@ impl Watermark { OwnedRow::new(row) }; let val = expr.eval_row_infallible(&row).await?; - Some(Self::new(new_col_idx, expr.return_type(), val)) + Some(Self::new(new_col_idx, expr.inner().return_type(), val)) } /// Transform the watermark with the given output indices. If this watermark is not in the diff --git a/src/stream/src/executor/project.rs b/src/stream/src/executor/project.rs index 56a31bde901b..2e40568a7167 100644 --- a/src/stream/src/executor/project.rs +++ b/src/stream/src/executor/project.rs @@ -21,7 +21,7 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::ToOwnedDatum; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use super::*; @@ -38,7 +38,7 @@ struct Inner { info: ExecutorInfo, /// Expressions of the current projection. 
- exprs: Vec, + exprs: Vec, /// All the watermark derivations, (input_column_index, output_column_index). And the /// derivation expression is the project's expression itself. watermark_derivations: MultiMap, @@ -58,7 +58,7 @@ impl ProjectExecutor { ctx: ActorContextRef, input: Box, pk_indices: PkIndices, - exprs: Vec, + exprs: Vec, executor_id: u64, watermark_derivations: MultiMap, nondecreasing_expr_indices: Vec, @@ -233,11 +233,12 @@ mod tests { use risingwave_common::array::{DataChunk, StreamChunk}; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, Datum}; - use risingwave_expr::expr::{self, build_from_pretty, Expression, ValueImpl}; + use risingwave_expr::expr::{self, Expression, ValueImpl}; use super::super::test_utils::MockSource; use super::super::*; use super::*; + use crate::executor::test_utils::expr::build_from_pretty; use crate::executor::test_utils::StreamExecutorTestExt; #[tokio::test] @@ -345,7 +346,7 @@ mod tests { let a_expr = build_from_pretty("(add:int8 $0:int8 1:int8)"); let b_expr = build_from_pretty("(subtract:int8 $0:int8 1:int8)"); - let c_expr = DummyNondecreasingExpr.boxed(); + let c_expr = InfallibleExpression::for_test(DummyNondecreasingExpr); let project = Box::new(ProjectExecutor::new( ActorContext::create(123), diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 6867e3d55bfd..7e722d8dd363 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -24,9 +24,11 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_expr::expr::{BoxedExpression, InfallibleExpression}; use risingwave_expr::table_function::ProjectSetSelectItem; use super::error::StreamExecutorError; +use super::test_utils::expr::build_from_pretty; use super::{ 
ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, PkIndices, PkIndicesRef, StreamExecutorResult, Watermark, @@ -260,7 +262,10 @@ impl Inner { ProjectSetSelectItem::Expr(expr) => { watermark .clone() - .transform_with_expr(expr, expr_idx + PROJ_ROW_ID_OFFSET) + .transform_with_expr( + &build_from_pretty(""), // TODO + expr_idx + PROJ_ROW_ID_OFFSET, + ) .await } ProjectSetSelectItem::TableFunction(_) => { diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index 3c8cde63c4ca..bb456ed570d6 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -32,7 +32,7 @@ use risingwave_common::hash::{HashKey, NullBitmap}; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::table::batch_table::storage_table::StorageTable; @@ -57,7 +57,7 @@ pub struct TemporalJoinExecutor, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, output_indices: Vec, pk_indices: PkIndices, schema: Schema, @@ -338,7 +338,7 @@ impl TemporalJoinExecutor left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, pk_indices: PkIndices, output_indices: Vec, table_output_indices: Vec, diff --git a/src/stream/src/executor/test_utils.rs b/src/stream/src/executor/test_utils.rs index bb4864ac04ef..7d3c48de3a5e 100644 --- a/src/stream/src/executor/test_utils.rs +++ b/src/stream/src/executor/test_utils.rs @@ -34,11 +34,11 @@ pub mod prelude { pub use risingwave_common::test_prelude::StreamChunkTestExt; pub use risingwave_common::types::DataType; pub use risingwave_common::util::sort_util::OrderType; - pub use 
risingwave_expr::expr::build_from_pretty; pub use risingwave_storage::memory::MemoryStateStore; pub use risingwave_storage::StateStore; pub use crate::common::table::state_table::StateTable; + pub use crate::executor::test_utils::expr::build_from_pretty; pub use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt}; pub use crate::executor::{ActorContext, BoxedMessageStream, Executor, PkIndices}; } @@ -263,6 +263,14 @@ pub trait StreamExecutorTestExt: MessageStream + Unpin { // FIXME: implement on any `impl MessageStream` if the analyzer works well. impl StreamExecutorTestExt for BoxedMessageStream {} +pub mod expr { + use risingwave_expr::expr::InfallibleExpression; + + pub fn build_from_pretty(s: impl AsRef) -> InfallibleExpression { + InfallibleExpression::for_test(risingwave_expr::expr::build_from_pretty(s)) + } +} + pub mod agg_executor { use std::sync::atomic::AtomicU64; use std::sync::Arc; diff --git a/src/stream/src/executor/values.rs b/src/stream/src/executor/values.rs index 624b2531bf7b..73dc0f8f08c8 100644 --- a/src/stream/src/executor/values.rs +++ b/src/stream/src/executor/values.rs @@ -21,7 +21,7 @@ use risingwave_common::array::{DataChunk, Op, StreamChunk}; use risingwave_common::catalog::Schema; use risingwave_common::ensure; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::BoxedExpression; +use risingwave_expr::expr::InfallibleExpression; use tokio::sync::mpsc::UnboundedReceiver; use super::{ @@ -40,7 +40,7 @@ pub struct ValuesExecutor { barrier_receiver: UnboundedReceiver, progress: CreateMviewProgress, - rows: vec::IntoIter>, + rows: vec::IntoIter>, pk_indices: PkIndices, identity: String, schema: Schema, @@ -51,7 +51,7 @@ impl ValuesExecutor { pub fn new( ctx: ActorContextRef, progress: CreateMviewProgress, - rows: Vec>, + rows: Vec>, schema: Schema, barrier_receiver: UnboundedReceiver, executor_id: u64, @@ -167,7 +167,7 @@ mod tests { }; use risingwave_common::catalog::{Field, Schema}; use 
risingwave_common::types::{DataType, ScalarImpl, StructType}; - use risingwave_expr::expr::{BoxedExpression, LiteralExpression}; + use risingwave_expr::expr::{BoxedExpression, InfallibleExpression, LiteralExpression}; use tokio::sync::mpsc::unbounded_channel; use super::ValuesExecutor; @@ -183,8 +183,7 @@ mod tests { let actor_id = progress.actor_id(); let (tx, barrier_receiver) = unbounded_channel(); let value = StructValue::new(vec![Some(1.into()), Some(2.into()), Some(3.into())]); - let exprs = vec![ - Box::new(LiteralExpression::new( + let exprs = vec![Box::new(LiteralExpression::new( DataType::Int16, Some(ScalarImpl::Int16(1)), )) as BoxedExpression, @@ -202,11 +201,11 @@ mod tests { vec![], ), Some(ScalarImpl::Struct(value)), - )) as BoxedExpression, + )), Box::new(LiteralExpression::new( DataType::Int64, Some(ScalarImpl::Int64(0)), - )) as BoxedExpression, + )), ]; let fields = exprs .iter() // for each column @@ -215,7 +214,10 @@ mod tests { let values_executor_struct = ValuesExecutor::new( ActorContext::create(actor_id), progress, - vec![exprs], + vec![exprs + .into_iter() + .map(InfallibleExpression::for_test) + .collect()], Schema { fields }, barrier_receiver, 10005, diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index ad332112ef26..e2ff36b69c78 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -23,7 +23,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl}; use risingwave_common::{bail, row}; use risingwave_expr::expr::{ - build_func_non_strict, BoxedExpression, Expression, InputRefExpression, LiteralExpression, + build_func_non_strict, Expression, InfallibleExpression, InputRefExpression, LiteralExpression, }; use risingwave_expr::Result as ExprResult; use risingwave_pb::expr::expr_node::Type; @@ -44,7 +44,7 @@ use crate::task::ActorEvalErrorReport; pub struct 
WatermarkFilterExecutor { input: BoxedExecutor, /// The expression used to calculate the watermark value. - watermark_expr: BoxedExpression, + watermark_expr: InfallibleExpression, /// The column we should generate watermark and filter on. event_time_col_idx: usize, ctx: ActorContextRef, @@ -55,7 +55,7 @@ pub struct WatermarkFilterExecutor { impl WatermarkFilterExecutor { pub fn new( input: BoxedExecutor, - watermark_expr: BoxedExpression, + watermark_expr: InfallibleExpression, event_time_col_idx: usize, ctx: ActorContextRef, table: StateTable, @@ -298,7 +298,7 @@ impl WatermarkFilterExecutor { event_time_col_idx: usize, watermark: ScalarImpl, eval_error_report: ActorEvalErrorReport, - ) -> ExprResult { + ) -> ExprResult { build_func_non_strict( Type::GreaterThanOrEqual, DataType::Boolean, @@ -350,11 +350,11 @@ mod tests { use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_common::types::Date; use risingwave_common::util::sort_util::OrderType; - use risingwave_expr::expr::build_from_pretty; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::table::Distribution; use super::*; + use crate::executor::test_utils::expr::build_from_pretty; use crate::executor::test_utils::{MessageSender, MockSource}; use crate::executor::ActorContext; diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 44799af9405c..32bbfae40889 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::types::DataType; use risingwave_expr::expr::{ - build_func_non_strict, build_non_strict_from_prost, BoxedExpression, InputRefExpression, + build_func_non_strict, build_non_strict_from_prost, InfallibleExpression, InputRefExpression, }; pub use risingwave_pb::expr::expr_node::Type as ExprType; use risingwave_pb::plan_common::JoinType as JoinTypeProto; @@ -109,7 
+109,7 @@ impl ExecutorBuilder for HashJoinExecutorBuilder { build_non_strict_from_prost( delta_expression.delta.as_ref().unwrap(), params.eval_error_report.clone(), - )?, + )?.into_inner(), ], params.eval_error_report.clone(), )?) @@ -175,8 +175,8 @@ struct HashJoinExecutorDispatcherArgs { pk_indices: PkIndices, output_indices: Vec, executor_id: u64, - cond: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + cond: Option, + inequality_pairs: Vec<(usize, usize, bool, Option)>, op_info: String, state_table_l: StateTable, degree_state_table_l: StateTable, diff --git a/src/stream/src/from_proto/temporal_join.rs b/src/stream/src/from_proto/temporal_join.rs index 8b7b3b6af133..758f1aac96b5 100644 --- a/src/stream/src/from_proto/temporal_join.rs +++ b/src/stream/src/from_proto/temporal_join.rs @@ -18,7 +18,7 @@ use risingwave_common::catalog::{ColumnDesc, TableId, TableOption}; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::OrderType; -use risingwave_expr::expr::{build_non_strict_from_prost, BoxedExpression}; +use risingwave_expr::expr::{build_non_strict_from_prost, InfallibleExpression}; use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc}; use risingwave_storage::table::batch_table::storage_table::StorageTable; use risingwave_storage::table::Distribution; @@ -190,7 +190,7 @@ struct TemporalJoinExecutorDispatcherArgs { left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, pk_indices: PkIndices, output_indices: Vec, table_output_indices: Vec, diff --git a/src/stream/tests/integration_tests/hop_window.rs b/src/stream/tests/integration_tests/hop_window.rs index 167857cc7d9f..04ccd8dbf51b 100644 --- a/src/stream/tests/integration_tests/hop_window.rs +++ b/src/stream/tests/integration_tests/hop_window.rs @@ -15,6 +15,7 @@ use risingwave_common::types::test_utils::IntervalTestExt; use 
risingwave_common::types::{Interval, Timestamp}; use risingwave_expr::expr::test_utils::make_hop_window_expression; +use risingwave_expr::expr::InfallibleExpression; use risingwave_stream::executor::{ExecutorInfo, HopWindowExecutor}; use crate::prelude::*; @@ -55,8 +56,14 @@ fn create_executor(output_indices: Vec) -> (MessageSender, BoxedMessageSt TIME_COL_IDX, window_slide, window_size, - window_start_exprs, - window_end_exprs, + window_start_exprs + .into_iter() + .map(InfallibleExpression::for_test) + .collect(), + window_end_exprs + .into_iter() + .map(InfallibleExpression::for_test) + .collect(), output_indices, CHUNK_SIZE, ) diff --git a/src/stream/tests/integration_tests/project_set.rs b/src/stream/tests/integration_tests/project_set.rs index bf1354c25b83..61a879256108 100644 --- a/src/stream/tests/integration_tests/project_set.rs +++ b/src/stream/tests/integration_tests/project_set.rs @@ -29,10 +29,10 @@ fn create_executor() -> (MessageSender, BoxedMessageStream) { }; let (tx, source) = MockSource::channel(schema, PkIndices::new()); - let test_expr = build_from_pretty("(add:int8 $0:int8 $1:int8)"); - let test_expr_watermark = build_from_pretty("(add:int8 $0:int8 1:int8)"); - let tf1 = repeat(build_from_pretty("1:int4"), 1); - let tf2 = repeat(build_from_pretty("2:int4"), 2); + let test_expr = build_from_pretty("(add:int8 $0:int8 $1:int8)").into_inner(); + let test_expr_watermark = build_from_pretty("(add:int8 $0:int8 1:int8)").into_inner(); + let tf1 = repeat(build_from_pretty("1:int4").into_inner(), 1); + let tf2 = repeat(build_from_pretty("2:int4").into_inner(), 2); let project_set = Box::new(ProjectSetExecutor::new( ActorContext::create(123), From 0058837ddcdc1665459c675b897811614ecd8733 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 17 Oct 2023 20:41:09 +0800 Subject: [PATCH 2/8] split traits and fix todos Signed-off-by: Bugen Zhao --- src/batch/src/executor/aggregation/filter.rs | 2 +- src/batch/src/executor/project_set.rs | 2 +- 
src/expr/core/src/expr/build.rs | 6 +- src/expr/core/src/expr/mod.rs | 75 +++++-------------- src/expr/impl/src/scalar/cast.rs | 4 +- .../src/table_function/generate_series.rs | 2 +- src/storage/src/row_serde/value_serde.rs | 3 +- src/stream/src/executor/aggregation/mod.rs | 7 +- src/stream/src/executor/mod.rs | 4 +- src/stream/src/executor/project_set.rs | 5 +- src/stream/src/executor/values.rs | 3 +- src/stream/src/executor/watermark_filter.rs | 3 +- src/stream/src/from_proto/hash_join.rs | 3 +- 13 files changed, 46 insertions(+), 73 deletions(-) diff --git a/src/batch/src/executor/aggregation/filter.rs b/src/batch/src/executor/aggregation/filter.rs index 2db2320ed353..9cfbeabffe41 100644 --- a/src/batch/src/executor/aggregation/filter.rs +++ b/src/batch/src/executor/aggregation/filter.rs @@ -75,7 +75,7 @@ impl AggregateFunction for Filter { mod tests { use risingwave_common::test_prelude::StreamChunkTestExt; use risingwave_expr::aggregate::{build_append_only, AggCall}; - use risingwave_expr::expr::{build_from_pretty, Expression, LiteralExpression}; + use risingwave_expr::expr::{build_from_pretty, ExpressionBoxExt, LiteralExpression}; use super::*; diff --git a/src/batch/src/executor/project_set.rs b/src/batch/src/executor/project_set.rs index 670933a6bb50..fa3dfac917e8 100644 --- a/src/batch/src/executor/project_set.rs +++ b/src/batch/src/executor/project_set.rs @@ -171,7 +171,7 @@ mod tests { use risingwave_common::catalog::{Field, Schema}; use risingwave_common::test_prelude::*; use risingwave_common::types::DataType; - use risingwave_expr::expr::{Expression, InputRefExpression, LiteralExpression}; + use risingwave_expr::expr::{ExpressionBoxExt, InputRefExpression, LiteralExpression}; use risingwave_expr::table_function::repeat; use super::*; diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index a11883e9db78..2748b575f76a 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -29,7 +29,9 @@ use 
super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; use super::wrapper::{Checked, EvalErrorReport, NonStrict}; use super::InfallibleExpression; -use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression}; +use crate::expr::{ + BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression, +}; use crate::sig::FUNCTION_REGISTRY; use crate::{bail, ExprError, Result}; @@ -156,7 +158,7 @@ impl BuildBoxed for E { prost: &ExprNode, build_child: impl Fn(&ExprNode) -> Result, ) -> Result { - Self::build(prost, build_child).map(Expression::boxed) + Self::build(prost, build_child).map(ExpressionBoxExt::boxed) } } diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 04ab40fca1e8..78955e7e871a 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -67,6 +67,7 @@ pub use super::{ExprError, Result}; /// should be implemented. Prefer calling and implementing `eval_v2` instead of `eval` if possible, /// to gain the performance benefit of scalar expression. #[async_trait::async_trait] +#[auto_impl::auto_impl(&, Box)] pub trait Expression: std::fmt::Debug + Sync + Send { /// Get the return data type. fn return_type(&self) -> DataType; @@ -101,78 +102,38 @@ pub trait Expression: std::fmt::Debug + Sync + Send { fn eval_const(&self) -> Result { Err(ExprError::NotConstant) } - - /// Wrap the expression in a Box. - fn boxed(self) -> BoxedExpression - where - Self: Sized + Send + 'static, - { - Box::new(self) - } -} - -// TODO: make this an extension, or implement it on a `NonStrict` newtype. -impl dyn Expression { - /// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array. - /// - /// Use with expressions built in non-strict mode. 
- pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef { - self.eval(input).await.expect("evaluation failed") - } - - /// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable - /// scalar. - /// - /// Use with expressions built in non-strict mode. - pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum { - self.eval_row(input).await.expect("evaluation failed") - } } /// An owned dynamically typed [`Expression`]. pub type BoxedExpression = Box; -// TODO: avoid the overhead of extra boxing. -#[async_trait::async_trait] -impl Expression for BoxedExpression { - fn return_type(&self) -> DataType { - (**self).return_type() - } - - async fn eval(&self, input: &DataChunk) -> Result { - (**self).eval(input).await - } - - async fn eval_v2(&self, input: &DataChunk) -> Result { - (**self).eval_v2(input).await - } - - async fn eval_row(&self, input: &OwnedRow) -> Result { - (**self).eval_row(input).await - } - - fn eval_const(&self) -> Result { - (**self).eval_const() - } - - fn boxed(self) -> BoxedExpression { - self +/// Extension trait for boxing expressions. +#[easy_ext::ext(ExpressionBoxExt)] +impl E { + /// Wrap the expression in a Box. + pub fn boxed(self) -> BoxedExpression { + Box::new(self) } } #[derive(Debug)] pub struct InfallibleExpression(E); -impl InfallibleExpression { - pub fn for_test(inner: impl Expression + 'static) -> Self { - Self(inner.boxed()) - } -} - impl InfallibleExpression where E: Expression, { + pub fn for_test(inner: E) -> InfallibleExpression + where + E: 'static, + { + InfallibleExpression(inner.boxed()) + } + + pub fn todo(inner: E) -> Self { + Self(inner) + } + /// Get the return data type. 
pub fn return_type(&self) -> DataType { self.0.return_type() diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index 889cc43fe6b1..c173c76c330c 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -22,7 +22,9 @@ use risingwave_common::cast; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Int256, IntoOrdered, JsonbRef, ToText, F64}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::{build_func, Context, Expression, InputRefExpression}; +use risingwave_expr::expr::{ + build_func, Context, Expression, ExpressionBoxExt, InputRefExpression, +}; use risingwave_expr::{function, ExprError, Result}; use risingwave_pb::expr::expr_node::PbType; diff --git a/src/expr/impl/src/table_function/generate_series.rs b/src/expr/impl/src/table_function/generate_series.rs index 586fa60de02c..dfa09b0e215b 100644 --- a/src/expr/impl/src/table_function/generate_series.rs +++ b/src/expr/impl/src/table_function/generate_series.rs @@ -159,7 +159,7 @@ mod tests { use risingwave_common::array::DataChunk; use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp}; - use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression}; + use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression}; use risingwave_expr::table_function::build; use risingwave_expr::ExprError; use risingwave_pb::expr::table_function::PbType; diff --git a/src/storage/src/row_serde/value_serde.rs b/src/storage/src/row_serde/value_serde.rs index 5d56cdba2d96..9048b90c23a5 100644 --- a/src/storage/src/row_serde/value_serde.rs +++ b/src/storage/src/row_serde/value_serde.rs @@ -114,9 +114,10 @@ impl ValueRowSerdeNew for ColumnAwareSerde { // It's okay since we previously banned impure expressions in default columns. 
build_from_prost(&expr.expect("expr should not be none")) .expect("build_from_prost error") - .eval_row_infallible(&OwnedRow::empty()) + .eval_row(&OwnedRow::empty()) .now_or_never() .expect("constant expression should not be async") + .expect("eval_row failed") }; Some((i, value)) } else { diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index dd0ce9d01c54..b4d69152c291 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -21,6 +21,7 @@ use risingwave_common::bail; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_expr::aggregate::{AggCall, AggKind}; +use risingwave_expr::expr::InfallibleExpression; use risingwave_storage::StateStore; use crate::common::table::state_table::StateTable; @@ -74,7 +75,11 @@ pub async fn agg_call_filter_res( } if let Some(ref filter) = agg_call.filter { - if let Bool(filter_res) = filter.eval_infallible(chunk).await.as_ref() { + if let Bool(filter_res) = InfallibleExpression::todo(&**filter) + .eval_infallible(chunk) + .await + .as_ref() + { vis &= filter_res.to_bitmap(); } else { bail!("Filter can only receive bool array"); diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 2d3957740c4c..e3a341ba3a5d 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair}; use risingwave_common::util::tracing::TracingContext; use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_connector::source::SplitImpl; -use risingwave_expr::expr::{InfallibleExpression, Expression}; +use risingwave_expr::expr::{Expression, InfallibleExpression}; use risingwave_pb::data::PbEpoch; use risingwave_pb::expr::PbInputRef; use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation}; @@ -641,7 +641,7 @@ impl Watermark { pub async fn 
transform_with_expr( self, - expr: &InfallibleExpression, + expr: &InfallibleExpression, new_col_idx: usize, ) -> Option { let Self { col_idx, val, .. } = self; diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 7e722d8dd363..8bf94b86b69e 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -24,11 +24,10 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::{BoxedExpression, InfallibleExpression}; +use risingwave_expr::expr::InfallibleExpression; use risingwave_expr::table_function::ProjectSetSelectItem; use super::error::StreamExecutorError; -use super::test_utils::expr::build_from_pretty; use super::{ ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, PkIndices, PkIndicesRef, StreamExecutorResult, Watermark, @@ -263,7 +262,7 @@ impl Inner { watermark .clone() .transform_with_expr( - &build_from_pretty(""), // TODO + &InfallibleExpression::todo(expr), expr_idx + PROJ_ROW_ID_OFFSET, ) .await diff --git a/src/stream/src/executor/values.rs b/src/stream/src/executor/values.rs index 73dc0f8f08c8..bb6017ae1441 100644 --- a/src/stream/src/executor/values.rs +++ b/src/stream/src/executor/values.rs @@ -183,7 +183,8 @@ mod tests { let actor_id = progress.actor_id(); let (tx, barrier_receiver) = unbounded_channel(); let value = StructValue::new(vec![Some(1.into()), Some(2.into()), Some(3.into())]); - let exprs = vec![Box::new(LiteralExpression::new( + let exprs = vec![ + Box::new(LiteralExpression::new( DataType::Int16, Some(ScalarImpl::Int16(1)), )) as BoxedExpression, diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index e2ff36b69c78..01a6744addaf 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ 
b/src/stream/src/executor/watermark_filter.rs @@ -23,7 +23,8 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl}; use risingwave_common::{bail, row}; use risingwave_expr::expr::{ - build_func_non_strict, Expression, InfallibleExpression, InputRefExpression, LiteralExpression, + build_func_non_strict, ExpressionBoxExt, InfallibleExpression, InputRefExpression, + LiteralExpression, }; use risingwave_expr::Result as ExprResult; use risingwave_pb::expr::expr_node::Type; diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 32bbfae40889..c3eed40b1839 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -109,7 +109,8 @@ impl ExecutorBuilder for HashJoinExecutorBuilder { build_non_strict_from_prost( delta_expression.delta.as_ref().unwrap(), params.eval_error_report.clone(), - )?.into_inner(), + )? + .into_inner(), ], params.eval_error_report.clone(), )?) 
From 0de57921986bc53ba32952e91add28f6b39c5222 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 17 Oct 2023 20:44:35 +0800 Subject: [PATCH 3/8] rename to non strict expression Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/build.rs | 14 ++++++++------ src/expr/core/src/expr/mod.rs | 8 ++++---- src/expr/core/src/expr/wrapper/checked.rs | 2 +- src/expr/core/src/expr/wrapper/mod.rs | 7 +++---- src/expr/core/src/expr/wrapper/non_strict.rs | 2 +- src/stream/src/executor/aggregation/mod.rs | 4 ++-- src/stream/src/executor/dynamic_filter.rs | 4 ++-- src/stream/src/executor/filter.rs | 6 +++--- src/stream/src/executor/hash_join.rs | 16 ++++++++-------- src/stream/src/executor/hop_window.rs | 16 ++++++++-------- src/stream/src/executor/integration_tests.rs | 2 +- src/stream/src/executor/mod.rs | 4 ++-- src/stream/src/executor/project.rs | 8 ++++---- src/stream/src/executor/project_set.rs | 4 ++-- src/stream/src/executor/temporal_join.rs | 6 +++--- src/stream/src/executor/test_utils.rs | 6 +++--- src/stream/src/executor/values.rs | 10 +++++----- src/stream/src/executor/watermark_filter.rs | 10 +++++----- src/stream/src/from_proto/hash_join.rs | 6 +++--- src/stream/src/from_proto/temporal_join.rs | 4 ++-- src/stream/tests/integration_tests/hop_window.rs | 6 +++--- 21 files changed, 73 insertions(+), 72 deletions(-) diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 2748b575f76a..7dffbcd42d66 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -27,8 +27,10 @@ use super::expr_in::InExpression; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; -use super::wrapper::{Checked, EvalErrorReport, NonStrict}; -use super::InfallibleExpression; +use super::wrapper::checked::Checked; +use super::wrapper::non_strict::NonStrict; +use super::wrapper::EvalErrorReport; +use super::NonStrictExpression; use crate::expr::{ BoxedExpression, Expression, 
ExpressionBoxExt, InputRefExpression, LiteralExpression, }; @@ -44,10 +46,10 @@ pub fn build_from_prost(prost: &ExprNode) -> Result { pub fn build_non_strict_from_prost( prost: &ExprNode, error_report: impl EvalErrorReport + 'static, -) -> Result { +) -> Result { ExprBuilder::new_non_strict(error_report) .build(prost) - .map(InfallibleExpression) + .map(NonStrictExpression) } /// Build an expression from protobuf with possibly some wrappers attached to each node. @@ -222,9 +224,9 @@ pub fn build_func_non_strict( ret_type: DataType, children: Vec, error_report: impl EvalErrorReport + 'static, -) -> Result { +) -> Result { let expr = build_func(func, ret_type, children)?; - let wrapped = InfallibleExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); + let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); Ok(wrapped) } diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 78955e7e871a..7b65eb1fe814 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -117,17 +117,17 @@ impl E { } #[derive(Debug)] -pub struct InfallibleExpression(E); +pub struct NonStrictExpression(E); -impl InfallibleExpression +impl NonStrictExpression where E: Expression, { - pub fn for_test(inner: E) -> InfallibleExpression + pub fn for_test(inner: E) -> NonStrictExpression where E: 'static, { - InfallibleExpression(inner.boxed()) + NonStrictExpression(inner.boxed()) } pub fn todo(inner: E) -> Self { diff --git a/src/expr/core/src/expr/wrapper/checked.rs b/src/expr/core/src/expr/wrapper/checked.rs index 1e049ad48101..b3b1375c4fa8 100644 --- a/src/expr/core/src/expr/wrapper/checked.rs +++ b/src/expr/core/src/expr/wrapper/checked.rs @@ -22,7 +22,7 @@ use crate::expr::{Expression, ValueImpl}; /// A wrapper of [`Expression`] that does extra checks after evaluation. #[derive(Debug)] -pub struct Checked(pub E); +pub(crate) struct Checked(pub E); // TODO: avoid the overhead of extra boxing. 
#[async_trait] diff --git a/src/expr/core/src/expr/wrapper/mod.rs b/src/expr/core/src/expr/wrapper/mod.rs index 48241d05de45..c93da021c882 100644 --- a/src/expr/core/src/expr/wrapper/mod.rs +++ b/src/expr/core/src/expr/wrapper/mod.rs @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod checked; -mod non_strict; +pub(crate) mod checked; +pub(crate) mod non_strict; -pub use checked::Checked; -pub use non_strict::{EvalErrorReport, NonStrict}; +pub use non_strict::EvalErrorReport; diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index 0859cea27aa4..04213c194e62 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -46,7 +46,7 @@ impl EvalErrorReport for ! { /// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad /// with NULL for each failed row. /// - Report all error occurred during row-level evaluation to the [`EvalErrorReport`]. 
-pub struct NonStrict { +pub(crate) struct NonStrict { inner: E, report: R, } diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index b4d69152c291..b25d95855b42 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -21,7 +21,7 @@ use risingwave_common::bail; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_expr::aggregate::{AggCall, AggKind}; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use risingwave_storage::StateStore; use crate::common::table::state_table::StateTable; @@ -75,7 +75,7 @@ pub async fn agg_call_filter_res( } if let Some(ref filter) = agg_call.filter { - if let Bool(filter_res) = InfallibleExpression::todo(&**filter) + if let Bool(filter_res) = NonStrictExpression::todo(&**filter) .eval_infallible(chunk) .await .as_ref() diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index b86a2a3d67c1..5828ae1e15c1 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -26,7 +26,7 @@ use risingwave_common::row::{self, once, OwnedRow, OwnedRow as RowData, Row}; use risingwave_common::types::{DataType, Datum, DefaultOrd, ScalarImpl, ToDatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::{ - build_func_non_strict, InfallibleExpression, InputRefExpression, LiteralExpression, + build_func_non_strict, InputRefExpression, LiteralExpression, NonStrictExpression, }; use risingwave_pb::expr::expr_node::Type as ExprNodeType; use risingwave_pb::expr::expr_node::Type::{ @@ -97,7 +97,7 @@ impl DynamicFilterExecutor, + condition: Option, ) -> Result<(Vec, Bitmap), StreamExecutorError> { let mut new_ops = Vec::with_capacity(chunk.capacity()); let mut new_visibility = BitmapBuilder::with_capacity(chunk.capacity()); diff 
--git a/src/stream/src/executor/filter.rs b/src/stream/src/executor/filter.rs index 236f203dcbe1..1a1e645e44e6 100644 --- a/src/stream/src/executor/filter.rs +++ b/src/stream/src/executor/filter.rs @@ -19,7 +19,7 @@ use risingwave_common::array::{Array, ArrayImpl, Op, StreamChunk}; use risingwave_common::buffer::BitmapBuilder; use risingwave_common::catalog::Schema; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use super::*; @@ -34,14 +34,14 @@ pub struct FilterExecutor { /// Expression of the current filter, note that the filter must always have the same output for /// the same input. - expr: InfallibleExpression, + expr: NonStrictExpression, } impl FilterExecutor { pub fn new( ctx: ActorContextRef, input: Box, - expr: InfallibleExpression, + expr: NonStrictExpression, executor_id: u64, ) -> Self { let input_info = input.info(); diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index cca953022369..4178012cb9d9 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -28,7 +28,7 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ToOwnedDatum}; use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use risingwave_expr::ExprError; use risingwave_storage::StateStore; use tokio::time::Instant; @@ -242,9 +242,9 @@ pub struct HashJoinExecutor, /// Optional non-equi join conditions - cond: Option, + cond: Option, /// Column indices of watermark output and offset expression of each inequality, respectively. 
- inequality_pairs: Vec<(Vec, Option)>, + inequality_pairs: Vec<(Vec, Option)>, /// The output watermark of each inequality condition and its value is the minimum of the /// calculation result of both side. It will be used to generate watermark into downstream /// and do state cleaning if `clean_state` field of that inequality is `true`. @@ -313,7 +313,7 @@ struct EqJoinArgs<'a, K: HashKey, S: StateStore> { side_l: &'a mut JoinSide, side_r: &'a mut JoinSide, actual_output_data_types: &'a [DataType], - cond: &'a mut Option, + cond: &'a mut Option, inequality_watermarks: &'a [Option], chunk: StreamChunk, append_only_optimize: bool, @@ -448,8 +448,8 @@ impl HashJoinExecutor, executor_id: u64, - cond: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + cond: Option, + inequality_pairs: Vec<(usize, usize, bool, Option)>, op_info: String, state_table_l: StateTable, degree_state_table_l: StateTable, @@ -1327,7 +1327,7 @@ mod tests { (state_table, degree_state_table) } - fn create_cond(condition_text: Option) -> InfallibleExpression { + fn create_cond(condition_text: Option) -> NonStrictExpression { build_from_pretty( condition_text .as_deref() @@ -1339,7 +1339,7 @@ mod tests { with_condition: bool, null_safe: bool, condition_text: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + inequality_pairs: Vec<(usize, usize, bool, Option)>, ) -> (MessageSender, MessageSender, BoxedMessageStream) { let schema = Schema { fields: vec![ diff --git a/src/stream/src/executor/hop_window.rs b/src/stream/src/executor/hop_window.rs index 7c8c44330c86..42d13d790da8 100644 --- a/src/stream/src/executor/hop_window.rs +++ b/src/stream/src/executor/hop_window.rs @@ -19,7 +19,7 @@ use futures_async_stream::try_stream; use itertools::Itertools; use risingwave_common::array::{DataChunk, Op}; use risingwave_common::types::Interval; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use risingwave_expr::ExprError; use 
super::error::StreamExecutorError; @@ -33,8 +33,8 @@ pub struct HopWindowExecutor { pub time_col_idx: usize, pub window_slide: Interval, pub window_size: Interval, - window_start_exprs: Vec, - window_end_exprs: Vec, + window_start_exprs: Vec, + window_end_exprs: Vec, pub output_indices: Vec, chunk_size: usize, } @@ -48,8 +48,8 @@ impl HopWindowExecutor { time_col_idx: usize, window_slide: Interval, window_size: Interval, - window_start_exprs: Vec, - window_end_exprs: Vec, + window_start_exprs: Vec, + window_end_exprs: Vec, output_indices: Vec, chunk_size: usize, ) -> Self { @@ -251,7 +251,7 @@ mod tests { use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{DataType, Interval}; use risingwave_expr::expr::test_utils::make_hop_window_expression; - use risingwave_expr::expr::InfallibleExpression; + use risingwave_expr::expr::NonStrictExpression; use crate::executor::test_utils::MockSource; use crate::executor::{ActorContext, Executor, ExecutorInfo, StreamChunk}; @@ -305,11 +305,11 @@ mod tests { window_size, window_start_exprs .into_iter() - .map(InfallibleExpression::for_test) + .map(NonStrictExpression::for_test) .collect(), window_end_exprs .into_iter() - .map(InfallibleExpression::for_test) + .map(NonStrictExpression::for_test) .collect(), output_indices, CHUNK_SIZE, diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs index 1ae92bf3dcd1..cd505093294f 100644 --- a/src/stream/src/executor/integration_tests.rs +++ b/src/stream/src/executor/integration_tests.rs @@ -152,7 +152,7 @@ async fn test_merger_sum_aggr() { vec![], vec![ // TODO: use the new streaming_if_null expression here, and add `None` tests - InfallibleExpression::for_test(InputRefExpression::new(DataType::Int64, 1)), + NonStrictExpression::for_test(InputRefExpression::new(DataType::Int64, 1)), ], 3, MultiMap::new(), diff --git a/src/stream/src/executor/mod.rs b/src/stream/src/executor/mod.rs index 
e3a341ba3a5d..c28d6ec8564d 100644 --- a/src/stream/src/executor/mod.rs +++ b/src/stream/src/executor/mod.rs @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair}; use risingwave_common::util::tracing::TracingContext; use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt}; use risingwave_connector::source::SplitImpl; -use risingwave_expr::expr::{Expression, InfallibleExpression}; +use risingwave_expr::expr::{Expression, NonStrictExpression}; use risingwave_pb::data::PbEpoch; use risingwave_pb::expr::PbInputRef; use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation}; @@ -641,7 +641,7 @@ impl Watermark { pub async fn transform_with_expr( self, - expr: &InfallibleExpression, + expr: &NonStrictExpression, new_col_idx: usize, ) -> Option { let Self { col_idx, val, .. } = self; diff --git a/src/stream/src/executor/project.rs b/src/stream/src/executor/project.rs index 2e40568a7167..8cfebfecd3f3 100644 --- a/src/stream/src/executor/project.rs +++ b/src/stream/src/executor/project.rs @@ -21,7 +21,7 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::ToOwnedDatum; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use super::*; @@ -38,7 +38,7 @@ struct Inner { info: ExecutorInfo, /// Expressions of the current projection. - exprs: Vec, + exprs: Vec, /// All the watermark derivations, (input_column_index, output_column_index). And the /// derivation expression is the project's expression itself. 
watermark_derivations: MultiMap, @@ -58,7 +58,7 @@ impl ProjectExecutor { ctx: ActorContextRef, input: Box, pk_indices: PkIndices, - exprs: Vec, + exprs: Vec, executor_id: u64, watermark_derivations: MultiMap, nondecreasing_expr_indices: Vec, @@ -346,7 +346,7 @@ mod tests { let a_expr = build_from_pretty("(add:int8 $0:int8 1:int8)"); let b_expr = build_from_pretty("(subtract:int8 $0:int8 1:int8)"); - let c_expr = InfallibleExpression::for_test(DummyNondecreasingExpr); + let c_expr = NonStrictExpression::for_test(DummyNondecreasingExpr); let project = Box::new(ProjectExecutor::new( ActorContext::create(123), diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 8bf94b86b69e..53e1d7a6277b 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -24,7 +24,7 @@ use risingwave_common::catalog::{Field, Schema}; use risingwave_common::row::{Row, RowExt}; use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use risingwave_expr::table_function::ProjectSetSelectItem; use super::error::StreamExecutorError; @@ -262,7 +262,7 @@ impl Inner { watermark .clone() .transform_with_expr( - &InfallibleExpression::todo(expr), + &NonStrictExpression::todo(expr), expr_idx + PROJ_ROW_ID_OFFSET, ) .await diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index bb456ed570d6..82c1e5664967 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -32,7 +32,7 @@ use risingwave_common::hash::{HashKey, NullBitmap}; use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqDebug; -use risingwave_expr::expr::InfallibleExpression; +use 
risingwave_expr::expr::NonStrictExpression; use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::table::batch_table::storage_table::StorageTable; @@ -57,7 +57,7 @@ pub struct TemporalJoinExecutor, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, output_indices: Vec, pk_indices: PkIndices, schema: Schema, @@ -338,7 +338,7 @@ impl TemporalJoinExecutor left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, pk_indices: PkIndices, output_indices: Vec, table_output_indices: Vec, diff --git a/src/stream/src/executor/test_utils.rs b/src/stream/src/executor/test_utils.rs index 7d3c48de3a5e..13a9237cf015 100644 --- a/src/stream/src/executor/test_utils.rs +++ b/src/stream/src/executor/test_utils.rs @@ -264,10 +264,10 @@ pub trait StreamExecutorTestExt: MessageStream + Unpin { impl StreamExecutorTestExt for BoxedMessageStream {} pub mod expr { - use risingwave_expr::expr::InfallibleExpression; + use risingwave_expr::expr::NonStrictExpression; - pub fn build_from_pretty(s: impl AsRef) -> InfallibleExpression { - InfallibleExpression::for_test(risingwave_expr::expr::build_from_pretty(s)) + pub fn build_from_pretty(s: impl AsRef) -> NonStrictExpression { + NonStrictExpression::for_test(risingwave_expr::expr::build_from_pretty(s)) } } diff --git a/src/stream/src/executor/values.rs b/src/stream/src/executor/values.rs index bb6017ae1441..8c09b56aa355 100644 --- a/src/stream/src/executor/values.rs +++ b/src/stream/src/executor/values.rs @@ -21,7 +21,7 @@ use risingwave_common::array::{DataChunk, Op, StreamChunk}; use risingwave_common::catalog::Schema; use risingwave_common::ensure; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use tokio::sync::mpsc::UnboundedReceiver; use super::{ @@ -40,7 +40,7 @@ pub struct ValuesExecutor { 
barrier_receiver: UnboundedReceiver, progress: CreateMviewProgress, - rows: vec::IntoIter>, + rows: vec::IntoIter>, pk_indices: PkIndices, identity: String, schema: Schema, @@ -51,7 +51,7 @@ impl ValuesExecutor { pub fn new( ctx: ActorContextRef, progress: CreateMviewProgress, - rows: Vec>, + rows: Vec>, schema: Schema, barrier_receiver: UnboundedReceiver, executor_id: u64, @@ -167,7 +167,7 @@ mod tests { }; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, ScalarImpl, StructType}; - use risingwave_expr::expr::{BoxedExpression, InfallibleExpression, LiteralExpression}; + use risingwave_expr::expr::{BoxedExpression, LiteralExpression, NonStrictExpression}; use tokio::sync::mpsc::unbounded_channel; use super::ValuesExecutor; @@ -217,7 +217,7 @@ mod tests { progress, vec![exprs .into_iter() - .map(InfallibleExpression::for_test) + .map(NonStrictExpression::for_test) .collect()], Schema { fields }, barrier_receiver, diff --git a/src/stream/src/executor/watermark_filter.rs b/src/stream/src/executor/watermark_filter.rs index 01a6744addaf..5e5454cecff9 100644 --- a/src/stream/src/executor/watermark_filter.rs +++ b/src/stream/src/executor/watermark_filter.rs @@ -23,8 +23,8 @@ use risingwave_common::row::{OwnedRow, Row}; use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl}; use risingwave_common::{bail, row}; use risingwave_expr::expr::{ - build_func_non_strict, ExpressionBoxExt, InfallibleExpression, InputRefExpression, - LiteralExpression, + build_func_non_strict, ExpressionBoxExt, InputRefExpression, LiteralExpression, + NonStrictExpression, }; use risingwave_expr::Result as ExprResult; use risingwave_pb::expr::expr_node::Type; @@ -45,7 +45,7 @@ use crate::task::ActorEvalErrorReport; pub struct WatermarkFilterExecutor { input: BoxedExecutor, /// The expression used to calculate the watermark value. 
- watermark_expr: InfallibleExpression, + watermark_expr: NonStrictExpression, /// The column we should generate watermark and filter on. event_time_col_idx: usize, ctx: ActorContextRef, @@ -56,7 +56,7 @@ pub struct WatermarkFilterExecutor { impl WatermarkFilterExecutor { pub fn new( input: BoxedExecutor, - watermark_expr: InfallibleExpression, + watermark_expr: NonStrictExpression, event_time_col_idx: usize, ctx: ActorContextRef, table: StateTable, @@ -299,7 +299,7 @@ impl WatermarkFilterExecutor { event_time_col_idx: usize, watermark: ScalarImpl, eval_error_report: ActorEvalErrorReport, - ) -> ExprResult { + ) -> ExprResult { build_func_non_strict( Type::GreaterThanOrEqual, DataType::Boolean, diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index c3eed40b1839..87174282e517 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::types::DataType; use risingwave_expr::expr::{ - build_func_non_strict, build_non_strict_from_prost, InfallibleExpression, InputRefExpression, + build_func_non_strict, build_non_strict_from_prost, InputRefExpression, NonStrictExpression, }; pub use risingwave_pb::expr::expr_node::Type as ExprType; use risingwave_pb::plan_common::JoinType as JoinTypeProto; @@ -176,8 +176,8 @@ struct HashJoinExecutorDispatcherArgs { pk_indices: PkIndices, output_indices: Vec, executor_id: u64, - cond: Option, - inequality_pairs: Vec<(usize, usize, bool, Option)>, + cond: Option, + inequality_pairs: Vec<(usize, usize, bool, Option)>, op_info: String, state_table_l: StateTable, degree_state_table_l: StateTable, diff --git a/src/stream/src/from_proto/temporal_join.rs b/src/stream/src/from_proto/temporal_join.rs index 758f1aac96b5..58699089e8c2 100644 --- a/src/stream/src/from_proto/temporal_join.rs +++ b/src/stream/src/from_proto/temporal_join.rs @@ -18,7 +18,7 
@@ use risingwave_common::catalog::{ColumnDesc, TableId, TableOption}; use risingwave_common::hash::{HashKey, HashKeyDispatcher}; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::OrderType; -use risingwave_expr::expr::{build_non_strict_from_prost, InfallibleExpression}; +use risingwave_expr::expr::{build_non_strict_from_prost, NonStrictExpression}; use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc}; use risingwave_storage::table::batch_table::storage_table::StorageTable; use risingwave_storage::table::Distribution; @@ -190,7 +190,7 @@ struct TemporalJoinExecutorDispatcherArgs { left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, - condition: Option, + condition: Option, pk_indices: PkIndices, output_indices: Vec, table_output_indices: Vec, diff --git a/src/stream/tests/integration_tests/hop_window.rs b/src/stream/tests/integration_tests/hop_window.rs index 04ccd8dbf51b..9d6d879240fc 100644 --- a/src/stream/tests/integration_tests/hop_window.rs +++ b/src/stream/tests/integration_tests/hop_window.rs @@ -15,7 +15,7 @@ use risingwave_common::types::test_utils::IntervalTestExt; use risingwave_common::types::{Interval, Timestamp}; use risingwave_expr::expr::test_utils::make_hop_window_expression; -use risingwave_expr::expr::InfallibleExpression; +use risingwave_expr::expr::NonStrictExpression; use risingwave_stream::executor::{ExecutorInfo, HopWindowExecutor}; use crate::prelude::*; @@ -58,11 +58,11 @@ fn create_executor(output_indices: Vec) -> (MessageSender, BoxedMessageSt window_size, window_start_exprs .into_iter() - .map(InfallibleExpression::for_test) + .map(NonStrictExpression::for_test) .collect(), window_end_exprs .into_iter() - .map(InfallibleExpression::for_test) + .map(NonStrictExpression::for_test) .collect(), output_indices, CHUNK_SIZE, From 95084fda96c410ff0be14f23c6e4c29c52e9a31f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 17 Oct 2023 21:06:59 +0800 Subject: [PATCH 4/8] refine todo 
behavior Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/mod.rs | 11 +++++++++-- src/expr/core/src/expr/wrapper/mod.rs | 2 +- src/expr/core/src/expr/wrapper/non_strict.rs | 10 ++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 7b65eb1fe814..1b23bb85fd40 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -123,6 +123,9 @@ impl NonStrictExpression where E: Expression, { + /// Create a non-strict expression directly wrapping the given expression. + /// + /// Should only be used in tests as evaluation may panic. pub fn for_test(inner: E) -> NonStrictExpression where E: 'static, @@ -130,8 +133,12 @@ where NonStrictExpression(inner.boxed()) } - pub fn todo(inner: E) -> Self { - Self(inner) + /// Create a non-strict expression from the given expression, where only the evaluation of the + /// top-level expression is non-strict (which is subtly different from + /// [`crate::expr::build_non_strict_from_prost`]), and error will only be simply logged. + pub fn todo(inner: E) -> NonStrictExpression { + let inner = wrapper::non_strict::NonStrict::new(inner, wrapper::LogReport); + NonStrictExpression(inner) } /// Get the return data type. diff --git a/src/expr/core/src/expr/wrapper/mod.rs b/src/expr/core/src/expr/wrapper/mod.rs index c93da021c882..16988a050ad8 100644 --- a/src/expr/core/src/expr/wrapper/mod.rs +++ b/src/expr/core/src/expr/wrapper/mod.rs @@ -15,4 +15,4 @@ pub(crate) mod checked; pub(crate) mod non_strict; -pub use non_strict::EvalErrorReport; +pub use non_strict::{EvalErrorReport, LogReport}; diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index 04213c194e62..0afbe5a0afc9 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -42,6 +42,16 @@ impl EvalErrorReport for ! 
{ } } +/// Log the error to report an error during evaluation. +#[derive(Clone)] +pub struct LogReport; + +impl EvalErrorReport for LogReport { + fn report(&self, error: ExprError) { + tracing::error!(%error, "failed to evaluate expression"); + } +} + /// A wrapper of [`Expression`] that evaluates in a non-strict way. Basically... /// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad /// with NULL for each failed row. From 9edb4cc577c06994a3fa89756080faf10493fc3f Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 17 Oct 2023 21:26:28 +0800 Subject: [PATCH 5/8] pass in error report Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/mod.rs | 13 +++++++++---- src/stream/src/executor/aggregation/mod.rs | 4 ++-- src/stream/src/executor/dynamic_filter.rs | 2 +- src/stream/src/executor/project_set.rs | 8 ++++---- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 1b23bb85fd40..528a303fdab9 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -58,7 +58,7 @@ pub use self::build::*; pub use self::expr_input_ref::InputRefExpression; pub use self::expr_literal::LiteralExpression; pub use self::value::{ValueImpl, ValueRef}; -pub use self::wrapper::EvalErrorReport; +pub use self::wrapper::*; pub use super::{ExprError, Result}; /// Interface of an expression. @@ -135,9 +135,14 @@ where /// Create a non-strict expression from the given expression, where only the evaluation of the /// top-level expression is non-strict (which is subtly different from - /// [`crate::expr::build_non_strict_from_prost`]), and error will only be simply logged. - pub fn todo(inner: E) -> NonStrictExpression { - let inner = wrapper::non_strict::NonStrict::new(inner, wrapper::LogReport); + /// [`crate::expr::build_non_strict_from_prost`]). + /// + /// This should be used as a "TODO". 
+ pub fn todo( + inner: E, + error_report: impl EvalErrorReport, + ) -> NonStrictExpression { + let inner = wrapper::non_strict::NonStrict::new(inner, error_report); NonStrictExpression(inner) } diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index b25d95855b42..9e05b050d1e0 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -21,7 +21,7 @@ use risingwave_common::bail; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; use risingwave_expr::aggregate::{AggCall, AggKind}; -use risingwave_expr::expr::NonStrictExpression; +use risingwave_expr::expr::{LogReport, NonStrictExpression}; use risingwave_storage::StateStore; use crate::common::table::state_table::StateTable; @@ -75,7 +75,7 @@ pub async fn agg_call_filter_res( } if let Some(ref filter) = agg_call.filter { - if let Bool(filter_res) = NonStrictExpression::todo(&**filter) + if let Bool(filter_res) = NonStrictExpression::todo(&**filter, LogReport) .eval_infallible(chunk) .await .as_ref() diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index 5828ae1e15c1..ccb55b75c24f 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -265,7 +265,7 @@ impl DynamicFilterExecutor, chunk_size: usize, @@ -84,7 +84,7 @@ impl ProjectSetExecutor { let inner = Inner { info, - _ctx: ctx, + ctx, select_list, chunk_size, watermark_derivations, @@ -262,7 +262,7 @@ impl Inner { watermark .clone() .transform_with_expr( - &NonStrictExpression::todo(expr), + &NonStrictExpression::todo(expr, LogReport), expr_idx + PROJ_ROW_ID_OFFSET, ) .await From b347510adf3eaa88eb08c630df85b8f8fcd69d10 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 18 Oct 2023 11:41:48 +0800 Subject: [PATCH 6/8] rename todo to topmost Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/mod.rs | 8 ++++---- 
src/expr/core/src/expr/wrapper/non_strict.rs | 2 +- src/stream/src/executor/aggregation/mod.rs | 3 ++- src/stream/src/executor/project_set.rs | 7 ++++--- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 528a303fdab9..b80bf5d4a2c6 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -134,11 +134,11 @@ where } /// Create a non-strict expression from the given expression, where only the evaluation of the - /// top-level expression is non-strict (which is subtly different from - /// [`crate::expr::build_non_strict_from_prost`]). + /// topmost level expression node is non-strict (which is subtly different from + /// [`crate::expr::build_non_strict_from_prost`] where every node is non-strict). /// - /// This should be used as a "TODO". - pub fn todo( + /// This should be used as a TODO. + pub fn new_topmost( inner: E, error_report: impl EvalErrorReport, ) -> NonStrictExpression { diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index 0afbe5a0afc9..782456023cdf 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -23,7 +23,7 @@ use crate::expr::{Expression, ValueImpl}; use crate::ExprError; /// Report an error during evaluation. -#[auto_impl(Arc)] +#[auto_impl(&, Arc)] pub trait EvalErrorReport: Clone + Send + Sync { /// Perform the error reporting. /// diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index 9e05b050d1e0..9bb111315296 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -75,7 +75,8 @@ pub async fn agg_call_filter_res( } if let Some(ref filter) = agg_call.filter { - if let Bool(filter_res) = NonStrictExpression::todo(&**filter, LogReport) + // TODO: should we build `filter` in non-strict mode? 
+ if let Bool(filter_res) = NonStrictExpression::new_topmost(&**filter, LogReport) .eval_infallible(chunk) .await .as_ref() diff --git a/src/stream/src/executor/project_set.rs b/src/stream/src/executor/project_set.rs index 13e2ae30f8f3..ff3214db88ea 100644 --- a/src/stream/src/executor/project_set.rs +++ b/src/stream/src/executor/project_set.rs @@ -46,7 +46,7 @@ pub struct ProjectSetExecutor { struct Inner { info: ExecutorInfo, - ctx: ActorContextRef, + _ctx: ActorContextRef, /// Expressions of the current project_section. select_list: Vec, chunk_size: usize, @@ -84,7 +84,7 @@ impl ProjectSetExecutor { let inner = Inner { info, - ctx, + _ctx: ctx, select_list, chunk_size, watermark_derivations, @@ -262,7 +262,8 @@ impl Inner { watermark .clone() .transform_with_expr( - &NonStrictExpression::todo(expr, LogReport), + // TODO: should we build `expr` in non-strict mode? + &NonStrictExpression::new_topmost(expr, LogReport), expr_idx + PROJ_ROW_ID_OFFSET, ) .await From 55019f83a18fb1a093119f521fe2d2dea61af69b Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 18 Oct 2023 11:56:34 +0800 Subject: [PATCH 7/8] refine docs Signed-off-by: Bugen Zhao --- src/expr/core/src/expr/mod.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index b80bf5d4a2c6..48a46f640bf7 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -108,6 +108,10 @@ pub trait Expression: std::fmt::Debug + Sync + Send { pub type BoxedExpression = Box; /// Extension trait for boxing expressions. +/// +/// This is not directly made into [`Expression`] trait because... +/// - an expression does not have to be `'static`, +/// - and for the ease of `auto_impl`. #[easy_ext::ext(ExpressionBoxExt)] impl E { /// Wrap the expression in a Box. 
@@ -116,6 +120,19 @@ impl E { } } +/// A type-safe wrapper that indicates the inner expression can be evaluated in a non-strict +/// manner, i.e., developers can directly call `eval_infallible` and `eval_row_infallible` without +/// checking the result. +/// +/// This is usually created by non-strict build functions like [`crate::expr::build_non_strict_from_prost`] +/// and [`crate::expr::build_func_non_strict`]. It can also be created directly by +/// [`NonStrictExpression::new_topmost`], where only the evaluation of the topmost level expression +/// node is non-strict and should be treated as a TODO. +/// +/// Compared to [`crate::expr::wrapper::non_strict::NonStrict`], this is more like an indicator +/// applied on the root of an expression tree, while the latter is a wrapper that can be applied on +/// each node of the tree and actually changes the behavior. As a result, [`NonStrictExpression`] +/// does not implement [`Expression`] trait and instead deals directly with developers. #[derive(Debug)] pub struct NonStrictExpression(E); From 38e3973be65064af3852fb300d59f7653d8a36ec Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Wed, 18 Oct 2023 12:00:35 +0800 Subject: [PATCH 8/8] update disallowed methods Signed-off-by: Bugen Zhao --- src/stream/clippy.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stream/clippy.toml b/src/stream/clippy.toml index a6969d5bd607..b7257c4acb98 100644 --- a/src/stream/clippy.toml +++ b/src/stream/clippy.toml @@ -3,8 +3,8 @@ disallowed-methods = [ { path = "risingwave_expr::expr::build_from_prost", reason = "Expressions in streaming must be in non-strict mode. Please use `build_non_strict_from_prost` instead." }, { path = "risingwave_expr::expr::build_func", reason = "Expressions in streaming must be in non-strict mode. Please use `build_func_non_strict` instead." }, - { path = "risingwave_expr::expr::Expression::eval", reason = "Please use `Expression::eval_infallible` instead."
}, - { path = "risingwave_expr::expr::Expression::eval_row", reason = "Please use `Expression::eval_row_infallible` instead." }, + { path = "risingwave_expr::expr::Expression::eval", reason = "Please use `NonStrictExpression::eval_infallible` instead." }, + { path = "risingwave_expr::expr::Expression::eval_row", reason = "Please use `NonStrictExpression::eval_row_infallible` instead." }, { path = "risingwave_common::error::internal_err", reason = "Please use per-crate error type instead." }, { path = "risingwave_common::error::internal_error", reason = "Please use per-crate error type instead." },