Skip to content

Commit

Permalink
impl wrapper
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Oct 17, 2023
1 parent 7a1f371 commit c8df1ef
Show file tree
Hide file tree
Showing 18 changed files with 144 additions and 67 deletions.
11 changes: 7 additions & 4 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use super::wrapper::{Checked, EvalErrorReport, NonStrict};
use super::InfallibleExpression;
use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression};
use crate::sig::FUNCTION_REGISTRY;
use crate::{bail, ExprError, Result};
Expand All @@ -41,8 +42,10 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
pub fn build_non_strict_from_prost(
prost: &ExprNode,
error_report: impl EvalErrorReport + 'static,
) -> Result<BoxedExpression> {
ExprBuilder::new_non_strict(error_report).build(prost)
) -> Result<InfallibleExpression> {
ExprBuilder::new_non_strict(error_report)
.build(prost)
.map(InfallibleExpression)
}

/// Build an expression from protobuf with possibly some wrappers attached to each node.
Expand Down Expand Up @@ -217,9 +220,9 @@ pub fn build_func_non_strict(
ret_type: DataType,
children: Vec<BoxedExpression>,
error_report: impl EvalErrorReport + 'static,
) -> Result<BoxedExpression> {
) -> Result<InfallibleExpression> {
let expr = build_func(func, ret_type, children)?;
let wrapped = ExprBuilder::new_non_strict(error_report).wrap(expr);
let wrapped = InfallibleExpression(ExprBuilder::new_non_strict(error_report).wrap(expr));

Ok(wrapped)
}
Expand Down
44 changes: 44 additions & 0 deletions src/expr/core/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,50 @@ impl Expression for BoxedExpression {
}
}

#[derive(Debug)]
pub struct InfallibleExpression<E = BoxedExpression>(E);

impl InfallibleExpression {
pub fn for_test(inner: impl Expression + 'static) -> Self {
Self(inner.boxed())
}
}

impl<E> InfallibleExpression<E>
where
E: Expression,
{
/// Get the return data type.
pub fn return_type(&self) -> DataType {
self.0.return_type()
}

/// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef {
self.0.eval(input).await.expect("evaluation failed")
}

/// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable
/// scalar.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum {
self.0.eval_row(input).await.expect("evaluation failed")
}

/// Unwrap the inner expression.
pub fn into_inner(self) -> E {
self.0
}

/// Get a reference to the inner expression.
pub fn inner(&self) -> &E {
&self.0
}
}

/// An optional context that can be used in a function.
///
/// # Example
Expand Down
4 changes: 2 additions & 2 deletions src/stream/src/executor/dynamic_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use risingwave_common::row::{self, once, OwnedRow, OwnedRow as RowData, Row};
use risingwave_common::types::{DataType, Datum, DefaultOrd, ScalarImpl, ToDatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_expr::expr::{
build_func_non_strict, BoxedExpression, InputRefExpression, LiteralExpression,
build_func_non_strict, InfallibleExpression, InputRefExpression, LiteralExpression,
};
use risingwave_pb::expr::expr_node::Type as ExprNodeType;
use risingwave_pb::expr::expr_node::Type::{
Expand Down Expand Up @@ -97,7 +97,7 @@ impl<S: StateStore, const USE_WATERMARK_CACHE: bool> DynamicFilterExecutor<S, US
async fn apply_batch(
&mut self,
chunk: &StreamChunk,
condition: Option<BoxedExpression>,
condition: Option<InfallibleExpression>,
) -> Result<(Vec<Op>, Bitmap), StreamExecutorError> {
let mut new_ops = Vec::with_capacity(chunk.capacity());
let mut new_visibility = BitmapBuilder::with_capacity(chunk.capacity());
Expand Down
8 changes: 4 additions & 4 deletions src/stream/src/executor/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use risingwave_common::array::{Array, ArrayImpl, Op, StreamChunk};
use risingwave_common::buffer::BitmapBuilder;
use risingwave_common::catalog::Schema;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::InfallibleExpression;

use super::*;

Expand All @@ -34,14 +34,14 @@ pub struct FilterExecutor {

/// Expression of the current filter, note that the filter must always have the same output for
/// the same input.
expr: BoxedExpression,
expr: InfallibleExpression,
}

impl FilterExecutor {
pub fn new(
ctx: ActorContextRef,
input: Box<dyn Executor>,
expr: BoxedExpression,
expr: InfallibleExpression,
executor_id: u64,
) -> Self {
let input_info = input.info();
Expand Down Expand Up @@ -190,8 +190,8 @@ mod tests {
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_expr::expr::build_from_pretty;

use super::super::test_utils::expr::build_from_pretty;
use super::super::test_utils::MockSource;
use super::super::*;
use super::*;
Expand Down
20 changes: 10 additions & 10 deletions src/stream/src/executor/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, DefaultOrd, ToOwnedDatum};
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::InfallibleExpression;
use risingwave_expr::ExprError;
use risingwave_storage::StateStore;
use tokio::time::Instant;
Expand Down Expand Up @@ -242,9 +242,9 @@ pub struct HashJoinExecutor<K: HashKey, S: StateStore, const T: JoinTypePrimitiv
/// The parameters of the right join executor
side_r: JoinSide<K, S>,
/// Optional non-equi join conditions
cond: Option<BoxedExpression>,
cond: Option<InfallibleExpression>,
/// Column indices of watermark output and offset expression of each inequality, respectively.
inequality_pairs: Vec<(Vec<usize>, Option<BoxedExpression>)>,
inequality_pairs: Vec<(Vec<usize>, Option<InfallibleExpression>)>,
/// The output watermark of each inequality condition and its value is the minimum of the
/// calculation result of both side. It will be used to generate watermark into downstream
/// and do state cleaning if `clean_state` field of that inequality is `true`.
Expand Down Expand Up @@ -313,7 +313,7 @@ struct EqJoinArgs<'a, K: HashKey, S: StateStore> {
side_l: &'a mut JoinSide<K, S>,
side_r: &'a mut JoinSide<K, S>,
actual_output_data_types: &'a [DataType],
cond: &'a mut Option<BoxedExpression>,
cond: &'a mut Option<InfallibleExpression>,
inequality_watermarks: &'a [Option<Watermark>],
chunk: StreamChunk,
append_only_optimize: bool,
Expand Down Expand Up @@ -448,8 +448,8 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
pk_indices: PkIndices,
output_indices: Vec<usize>,
executor_id: u64,
cond: Option<BoxedExpression>,
inequality_pairs: Vec<(usize, usize, bool, Option<BoxedExpression>)>,
cond: Option<InfallibleExpression>,
inequality_pairs: Vec<(usize, usize, bool, Option<InfallibleExpression>)>,
op_info: String,
state_table_l: StateTable<S>,
degree_state_table_l: StateTable<S>,
Expand Down Expand Up @@ -912,7 +912,7 @@ impl<K: HashKey, S: StateStore, const T: JoinTypePrimitive> HashJoinExecutor<K,
// allow since we will handle error manually.
#[allow(clippy::disallowed_methods)]
let eval_result = delta_expression
.eval_row(&OwnedRow::new(vec![Some(input_watermark.val)]))
.inner().eval_row(&OwnedRow::new(vec![Some(input_watermark.val)]))
.await;
match eval_result {
Ok(value) => input_watermark.val = value.unwrap(),
Expand Down Expand Up @@ -1275,11 +1275,11 @@ mod tests {
use risingwave_common::hash::{Key128, Key64};
use risingwave_common::types::ScalarImpl;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::expr::build_from_pretty;
use risingwave_storage::memory::MemoryStateStore;

use super::*;
use crate::common::table::state_table::StateTable;
use crate::executor::test_utils::expr::build_from_pretty;
use crate::executor::test_utils::{MessageSender, MockSource, StreamExecutorTestExt};
use crate::executor::{ActorContext, Barrier, EpochPair};

Expand Down Expand Up @@ -1327,7 +1327,7 @@ mod tests {
(state_table, degree_state_table)
}

fn create_cond(condition_text: Option<String>) -> BoxedExpression {
fn create_cond(condition_text: Option<String>) -> InfallibleExpression {
build_from_pretty(
condition_text
.as_deref()
Expand All @@ -1339,7 +1339,7 @@ mod tests {
with_condition: bool,
null_safe: bool,
condition_text: Option<String>,
inequality_pairs: Vec<(usize, usize, bool, Option<BoxedExpression>)>,
inequality_pairs: Vec<(usize, usize, bool, Option<InfallibleExpression>)>,
) -> (MessageSender, MessageSender, BoxedMessageStream) {
let schema = Schema {
fields: vec![
Expand Down
21 changes: 14 additions & 7 deletions src/stream/src/executor/hop_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{DataChunk, Op};
use risingwave_common::types::Interval;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::InfallibleExpression;
use risingwave_expr::ExprError;

use super::error::StreamExecutorError;
Expand All @@ -33,8 +33,8 @@ pub struct HopWindowExecutor {
pub time_col_idx: usize,
pub window_slide: Interval,
pub window_size: Interval,
window_start_exprs: Vec<BoxedExpression>,
window_end_exprs: Vec<BoxedExpression>,
window_start_exprs: Vec<InfallibleExpression>,
window_end_exprs: Vec<InfallibleExpression>,
pub output_indices: Vec<usize>,
chunk_size: usize,
}
Expand All @@ -48,8 +48,8 @@ impl HopWindowExecutor {
time_col_idx: usize,
window_slide: Interval,
window_size: Interval,
window_start_exprs: Vec<BoxedExpression>,
window_end_exprs: Vec<BoxedExpression>,
window_start_exprs: Vec<InfallibleExpression>,
window_end_exprs: Vec<InfallibleExpression>,
output_indices: Vec<usize>,
chunk_size: usize,
) -> Self {
Expand Down Expand Up @@ -251,6 +251,7 @@ mod tests {
use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::{DataType, Interval};
use risingwave_expr::expr::test_utils::make_hop_window_expression;
use risingwave_expr::expr::InfallibleExpression;

use crate::executor::test_utils::MockSource;
use crate::executor::{ActorContext, Executor, ExecutorInfo, StreamChunk};
Expand Down Expand Up @@ -302,8 +303,14 @@ mod tests {
2,
window_slide,
window_size,
window_start_exprs,
window_end_exprs,
window_start_exprs
.into_iter()
.map(InfallibleExpression::for_test)
.collect(),
window_end_exprs
.into_iter()
.map(InfallibleExpression::for_test)
.collect(),
output_indices,
CHUNK_SIZE,
)
Expand Down
2 changes: 1 addition & 1 deletion src/stream/src/executor/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ async fn test_merger_sum_aggr() {
vec![],
vec![
// TODO: use the new streaming_if_null expression here, and add `None` tests
Box::new(InputRefExpression::new(DataType::Int64, 1)),
InfallibleExpression::for_test(InputRefExpression::new(DataType::Int64, 1)),
],
3,
MultiMap::new(),
Expand Down
6 changes: 3 additions & 3 deletions src/stream/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair};
use risingwave_common::util::tracing::TracingContext;
use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt};
use risingwave_connector::source::SplitImpl;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::{InfallibleExpression, Expression};
use risingwave_pb::data::PbEpoch;
use risingwave_pb::expr::PbInputRef;
use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation};
Expand Down Expand Up @@ -641,7 +641,7 @@ impl Watermark {

pub async fn transform_with_expr(
self,
expr: &BoxedExpression,
expr: &InfallibleExpression,
new_col_idx: usize,
) -> Option<Self> {
let Self { col_idx, val, .. } = self;
Expand All @@ -651,7 +651,7 @@ impl Watermark {
OwnedRow::new(row)
};
let val = expr.eval_row_infallible(&row).await?;
Some(Self::new(new_col_idx, expr.return_type(), val))
Some(Self::new(new_col_idx, expr.inner().return_type(), val))
}

/// Transform the watermark with the given output indices. If this watermark is not in the
Expand Down
11 changes: 6 additions & 5 deletions src/stream/src/executor/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::catalog::{Field, Schema};
use risingwave_common::row::{Row, RowExt};
use risingwave_common::types::ToOwnedDatum;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::InfallibleExpression;

use super::*;

Expand All @@ -38,7 +38,7 @@ struct Inner {
info: ExecutorInfo,

/// Expressions of the current projection.
exprs: Vec<BoxedExpression>,
exprs: Vec<InfallibleExpression>,
/// All the watermark derivations, (input_column_index, output_column_index). And the
/// derivation expression is the project's expression itself.
watermark_derivations: MultiMap<usize, usize>,
Expand All @@ -58,7 +58,7 @@ impl ProjectExecutor {
ctx: ActorContextRef,
input: Box<dyn Executor>,
pk_indices: PkIndices,
exprs: Vec<BoxedExpression>,
exprs: Vec<InfallibleExpression>,
executor_id: u64,
watermark_derivations: MultiMap<usize, usize>,
nondecreasing_expr_indices: Vec<usize>,
Expand Down Expand Up @@ -233,11 +233,12 @@ mod tests {
use risingwave_common::array::{DataChunk, StreamChunk};
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::{DataType, Datum};
use risingwave_expr::expr::{self, build_from_pretty, Expression, ValueImpl};
use risingwave_expr::expr::{self, Expression, ValueImpl};

use super::super::test_utils::MockSource;
use super::super::*;
use super::*;
use crate::executor::test_utils::expr::build_from_pretty;
use crate::executor::test_utils::StreamExecutorTestExt;

#[tokio::test]
Expand Down Expand Up @@ -345,7 +346,7 @@ mod tests {

let a_expr = build_from_pretty("(add:int8 $0:int8 1:int8)");
let b_expr = build_from_pretty("(subtract:int8 $0:int8 1:int8)");
let c_expr = DummyNondecreasingExpr.boxed();
let c_expr = InfallibleExpression::for_test(DummyNondecreasingExpr);

let project = Box::new(ProjectExecutor::new(
ActorContext::create(123),
Expand Down
7 changes: 6 additions & 1 deletion src/stream/src/executor/project_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ use risingwave_common::catalog::{Field, Schema};
use risingwave_common::row::{Row, RowExt};
use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::{BoxedExpression, InfallibleExpression};
use risingwave_expr::table_function::ProjectSetSelectItem;

use super::error::StreamExecutorError;
use super::test_utils::expr::build_from_pretty;
use super::{
ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, PkIndices, PkIndicesRef,
StreamExecutorResult, Watermark,
Expand Down Expand Up @@ -260,7 +262,10 @@ impl Inner {
ProjectSetSelectItem::Expr(expr) => {
watermark
.clone()
.transform_with_expr(expr, expr_idx + PROJ_ROW_ID_OFFSET)
.transform_with_expr(
&build_from_pretty(""), // TODO
expr_idx + PROJ_ROW_ID_OFFSET,
)
.await
}
ProjectSetSelectItem::TableFunction(_) => {
Expand Down
Loading

0 comments on commit c8df1ef

Please sign in to comment.