Skip to content

Commit

Permalink
split traits and fix todos
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Oct 17, 2023
1 parent c8df1ef commit 0058837
Show file tree
Hide file tree
Showing 13 changed files with 46 additions and 73 deletions.
2 changes: 1 addition & 1 deletion src/batch/src/executor/aggregation/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl AggregateFunction for Filter {
mod tests {
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_expr::aggregate::{build_append_only, AggCall};
use risingwave_expr::expr::{build_from_pretty, Expression, LiteralExpression};
use risingwave_expr::expr::{build_from_pretty, ExpressionBoxExt, LiteralExpression};

use super::*;

Expand Down
2 changes: 1 addition & 1 deletion src/batch/src/executor/project_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ mod tests {
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::test_prelude::*;
use risingwave_common::types::DataType;
use risingwave_expr::expr::{Expression, InputRefExpression, LiteralExpression};
use risingwave_expr::expr::{ExpressionBoxExt, InputRefExpression, LiteralExpression};
use risingwave_expr::table_function::repeat;

use super::*;
Expand Down
6 changes: 4 additions & 2 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use super::wrapper::{Checked, EvalErrorReport, NonStrict};
use super::InfallibleExpression;
use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression};
use crate::expr::{
BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::{bail, ExprError, Result};

Expand Down Expand Up @@ -156,7 +158,7 @@ impl<E: Build + 'static> BuildBoxed for E {
prost: &ExprNode,
build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
) -> Result<BoxedExpression> {
Self::build(prost, build_child).map(Expression::boxed)
Self::build(prost, build_child).map(ExpressionBoxExt::boxed)
}
}

Expand Down
75 changes: 18 additions & 57 deletions src/expr/core/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub use super::{ExprError, Result};
/// should be implemented. Prefer calling and implementing `eval_v2` instead of `eval` if possible,
/// to gain the performance benefit of scalar expression.
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Box)]
pub trait Expression: std::fmt::Debug + Sync + Send {
/// Get the return data type.
fn return_type(&self) -> DataType;
Expand Down Expand Up @@ -101,78 +102,38 @@ pub trait Expression: std::fmt::Debug + Sync + Send {
fn eval_const(&self) -> Result<Datum> {
Err(ExprError::NotConstant)
}

/// Wrap the expression in a Box.
fn boxed(self) -> BoxedExpression
where
Self: Sized + Send + 'static,
{
Box::new(self)
}
}

// TODO: make this an extension, or implement it on a `NonStrict` newtype.
impl dyn Expression {
/// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef {
self.eval(input).await.expect("evaluation failed")
}

/// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable
/// scalar.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum {
self.eval_row(input).await.expect("evaluation failed")
}
}

/// An owned dynamically typed [`Expression`].
pub type BoxedExpression = Box<dyn Expression>;

// TODO: avoid the overhead of extra boxing.
#[async_trait::async_trait]
impl Expression for BoxedExpression {
fn return_type(&self) -> DataType {
(**self).return_type()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
(**self).eval(input).await
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
(**self).eval_v2(input).await
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
(**self).eval_row(input).await
}

fn eval_const(&self) -> Result<Datum> {
(**self).eval_const()
}

fn boxed(self) -> BoxedExpression {
self
/// Extension trait for boxing expressions.
#[easy_ext::ext(ExpressionBoxExt)]
impl<E: Expression + 'static> E {
/// Wrap the expression in a Box.
pub fn boxed(self) -> BoxedExpression {
Box::new(self)
}
}

#[derive(Debug)]
pub struct InfallibleExpression<E = BoxedExpression>(E);

impl InfallibleExpression {
pub fn for_test(inner: impl Expression + 'static) -> Self {
Self(inner.boxed())
}
}

impl<E> InfallibleExpression<E>
where
E: Expression,
{
pub fn for_test(inner: E) -> InfallibleExpression
where
E: 'static,
{
InfallibleExpression(inner.boxed())
}

pub fn todo(inner: E) -> Self {
Self(inner)
}

/// Get the return data type.
pub fn return_type(&self) -> DataType {
self.0.return_type()
Expand Down
4 changes: 3 additions & 1 deletion src/expr/impl/src/scalar/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use risingwave_common::cast;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Int256, IntoOrdered, JsonbRef, ToText, F64};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::{build_func, Context, Expression, InputRefExpression};
use risingwave_expr::expr::{
build_func, Context, Expression, ExpressionBoxExt, InputRefExpression,
};
use risingwave_expr::{function, ExprError, Result};
use risingwave_pb::expr::expr_node::PbType;

Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/src/table_function/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ mod tests {
use risingwave_common::array::DataChunk;
use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp};
use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression};
use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression};
use risingwave_expr::table_function::build;
use risingwave_expr::ExprError;
use risingwave_pb::expr::table_function::PbType;
Expand Down
3 changes: 2 additions & 1 deletion src/storage/src/row_serde/value_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ impl ValueRowSerdeNew for ColumnAwareSerde {
// It's okay since we previously banned impure expressions in default columns.
build_from_prost(&expr.expect("expr should not be none"))
.expect("build_from_prost error")
.eval_row_infallible(&OwnedRow::empty())
.eval_row(&OwnedRow::empty())
.now_or_never()
.expect("constant expression should not be async")
.expect("eval_row failed")
};
Some((i, value))
} else {
Expand Down
7 changes: 6 additions & 1 deletion src/stream/src/executor/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use risingwave_common::bail;
use risingwave_common::buffer::Bitmap;
use risingwave_common::catalog::{Field, Schema};
use risingwave_expr::aggregate::{AggCall, AggKind};
use risingwave_expr::expr::InfallibleExpression;
use risingwave_storage::StateStore;

use crate::common::table::state_table::StateTable;
Expand Down Expand Up @@ -74,7 +75,11 @@ pub async fn agg_call_filter_res(
}

if let Some(ref filter) = agg_call.filter {
if let Bool(filter_res) = filter.eval_infallible(chunk).await.as_ref() {
if let Bool(filter_res) = InfallibleExpression::todo(&**filter)
.eval_infallible(chunk)
.await
.as_ref()
{
vis &= filter_res.to_bitmap();
} else {
bail!("Filter can only receive bool array");
Expand Down
4 changes: 2 additions & 2 deletions src/stream/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use risingwave_common::util::epoch::{Epoch, EpochPair};
use risingwave_common::util::tracing::TracingContext;
use risingwave_common::util::value_encoding::{DatumFromProtoExt, DatumToProtoExt};
use risingwave_connector::source::SplitImpl;
use risingwave_expr::expr::{InfallibleExpression, Expression};
use risingwave_expr::expr::{Expression, InfallibleExpression};
use risingwave_pb::data::PbEpoch;
use risingwave_pb::expr::PbInputRef;
use risingwave_pb::stream_plan::barrier::{BarrierKind, PbMutation};
Expand Down Expand Up @@ -641,7 +641,7 @@ impl Watermark {

pub async fn transform_with_expr(
self,
expr: &InfallibleExpression,
expr: &InfallibleExpression<impl Expression>,
new_col_idx: usize,
) -> Option<Self> {
let Self { col_idx, val, .. } = self;
Expand Down
5 changes: 2 additions & 3 deletions src/stream/src/executor/project_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@ use risingwave_common::catalog::{Field, Schema};
use risingwave_common::row::{Row, RowExt};
use risingwave_common::types::{DataType, Datum, DatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::{BoxedExpression, InfallibleExpression};
use risingwave_expr::expr::InfallibleExpression;
use risingwave_expr::table_function::ProjectSetSelectItem;

use super::error::StreamExecutorError;
use super::test_utils::expr::build_from_pretty;
use super::{
ActorContextRef, BoxedExecutor, Executor, ExecutorInfo, Message, PkIndices, PkIndicesRef,
StreamExecutorResult, Watermark,
Expand Down Expand Up @@ -263,7 +262,7 @@ impl Inner {
watermark
.clone()
.transform_with_expr(
&build_from_pretty(""), // TODO
&InfallibleExpression::todo(expr),
expr_idx + PROJ_ROW_ID_OFFSET,
)
.await
Expand Down
3 changes: 2 additions & 1 deletion src/stream/src/executor/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ mod tests {
let actor_id = progress.actor_id();
let (tx, barrier_receiver) = unbounded_channel();
let value = StructValue::new(vec![Some(1.into()), Some(2.into()), Some(3.into())]);
let exprs = vec![Box::new(LiteralExpression::new(
let exprs = vec![
Box::new(LiteralExpression::new(
DataType::Int16,
Some(ScalarImpl::Int16(1)),
)) as BoxedExpression,
Expand Down
3 changes: 2 additions & 1 deletion src/stream/src/executor/watermark_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, DefaultOrd, ScalarImpl};
use risingwave_common::{bail, row};
use risingwave_expr::expr::{
build_func_non_strict, Expression, InfallibleExpression, InputRefExpression, LiteralExpression,
build_func_non_strict, ExpressionBoxExt, InfallibleExpression, InputRefExpression,
LiteralExpression,
};
use risingwave_expr::Result as ExprResult;
use risingwave_pb::expr::expr_node::Type;
Expand Down
3 changes: 2 additions & 1 deletion src/stream/src/from_proto/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ impl ExecutorBuilder for HashJoinExecutorBuilder {
build_non_strict_from_prost(
delta_expression.delta.as_ref().unwrap(),
params.eval_error_report.clone(),
)?.into_inner(),
)?
.into_inner(),
],
params.eval_error_report.clone(),
)?)
Expand Down

0 comments on commit 0058837

Please sign in to comment.