Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(expr): type-safe infallible evaluation #12921

Merged
merged 9 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/batch/src/executor/aggregation/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl AggregateFunction for Filter {
mod tests {
use risingwave_common::test_prelude::StreamChunkTestExt;
use risingwave_expr::aggregate::{build_append_only, AggCall};
use risingwave_expr::expr::{build_from_pretty, Expression, LiteralExpression};
use risingwave_expr::expr::{build_from_pretty, ExpressionBoxExt, LiteralExpression};

use super::*;

Expand Down
2 changes: 1 addition & 1 deletion src/batch/src/executor/project_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ mod tests {
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::test_prelude::*;
use risingwave_common::types::DataType;
use risingwave_expr::expr::{Expression, InputRefExpression, LiteralExpression};
use risingwave_expr::expr::{ExpressionBoxExt, InputRefExpression, LiteralExpression};
use risingwave_expr::table_function::repeat;

use super::*;
Expand Down
21 changes: 14 additions & 7 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@ use super::expr_in::InExpression;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use super::wrapper::{Checked, EvalErrorReport, NonStrict};
use crate::expr::{BoxedExpression, Expression, InputRefExpression, LiteralExpression};
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
use super::NonStrictExpression;
use crate::expr::{
BoxedExpression, Expression, ExpressionBoxExt, InputRefExpression, LiteralExpression,
};
use crate::sig::FUNCTION_REGISTRY;
use crate::{bail, ExprError, Result};

Expand All @@ -41,8 +46,10 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
pub fn build_non_strict_from_prost(
prost: &ExprNode,
error_report: impl EvalErrorReport + 'static,
) -> Result<BoxedExpression> {
ExprBuilder::new_non_strict(error_report).build(prost)
) -> Result<NonStrictExpression> {
ExprBuilder::new_non_strict(error_report)
.build(prost)
.map(NonStrictExpression)
}

/// Build an expression from protobuf with possibly some wrappers attached to each node.
Expand Down Expand Up @@ -153,7 +160,7 @@ impl<E: Build + 'static> BuildBoxed for E {
prost: &ExprNode,
build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
) -> Result<BoxedExpression> {
Self::build(prost, build_child).map(Expression::boxed)
Self::build(prost, build_child).map(ExpressionBoxExt::boxed)
}
}

Expand Down Expand Up @@ -217,9 +224,9 @@ pub fn build_func_non_strict(
ret_type: DataType,
children: Vec<BoxedExpression>,
error_report: impl EvalErrorReport + 'static,
) -> Result<BoxedExpression> {
) -> Result<NonStrictExpression> {
let expr = build_func(func, ret_type, children)?;
let wrapped = ExprBuilder::new_non_strict(error_report).wrap(expr);
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr));

Ok(wrapped)
}
Expand Down
106 changes: 70 additions & 36 deletions src/expr/core/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub use self::build::*;
pub use self::expr_input_ref::InputRefExpression;
pub use self::expr_literal::LiteralExpression;
pub use self::value::{ValueImpl, ValueRef};
pub use self::wrapper::EvalErrorReport;
pub use self::wrapper::*;
pub use super::{ExprError, Result};

/// Interface of an expression.
Expand All @@ -67,6 +67,7 @@ pub use super::{ExprError, Result};
/// should be implemented. Prefer calling and implementing `eval_v2` instead of `eval` if possible,
/// to gain the performance benefit of scalar expression.
#[async_trait::async_trait]
#[auto_impl::auto_impl(&, Box)]
pub trait Expression: std::fmt::Debug + Sync + Send {
/// Get the return data type.
fn return_type(&self) -> DataType;
Expand Down Expand Up @@ -101,62 +102,95 @@ pub trait Expression: std::fmt::Debug + Sync + Send {
fn eval_const(&self) -> Result<Datum> {
Err(ExprError::NotConstant)
}
}

/// An owned dynamically typed [`Expression`].
pub type BoxedExpression = Box<dyn Expression>;

/// Extension trait for boxing expressions.
///
/// This is not directly made into [`Expression`] trait because...
/// - an expression does not have to be `'static`,
/// - and for the ease of `auto_impl`.
#[easy_ext::ext(ExpressionBoxExt)]
impl<E: Expression + 'static> E {
/// Wrap the expression in a Box.
fn boxed(self) -> BoxedExpression
where
Self: Sized + Send + 'static,
{
pub fn boxed(self) -> BoxedExpression {
Box::new(self)
}
}

// TODO: make this an extension, or implement it on a `NonStrict` newtype.
impl dyn Expression {
/// An type-safe wrapper that indicates the inner expression can be evaluated in a non-strict
/// manner, i.e., developers can directly call `eval_infallible` and `eval_row_infallible` without
/// checking the result.
///
/// This is usually created by non-strict build functions like [`crate::expr::build_non_strict_from_prost`]
/// and [`crate::expr::build_func_non_strict`]. It can also be created directly by
/// [`NonStrictExpression::new_topmost`], where only the evaluation of the topmost level expression
/// node is non-strict and should be treated as a TODO.
///
/// Compared to [`crate::expr::wrapper::non_strict::NonStrict`], this is more like an indicator
/// applied on the root of an expression tree, while the latter is a wrapper that can be applied on
/// each node of the tree and actually changes the behavior. As a result, [`NonStrictExpression`]
/// does not implement [`Expression`] trait and instead deals directly with developers.
#[derive(Debug)]
pub struct NonStrictExpression<E = BoxedExpression>(E);
stdrc marked this conversation as resolved.
Show resolved Hide resolved

impl<E> NonStrictExpression<E>
where
E: Expression,
{
/// Create a non-strict expression directly wrapping the given expression.
///
/// Should only be used in tests as evaluation may panic.
pub fn for_test(inner: E) -> NonStrictExpression
where
E: 'static,
{
NonStrictExpression(inner.boxed())
}

/// Create a non-strict expression from the given expression, where only the evaluation of the
/// topmost level expression node is non-strict (which is subtly different from
/// [`crate::expr::build_non_strict_from_prost`] where every node is non-strict).
///
/// This should be used as a TODO.
stdrc marked this conversation as resolved.
Show resolved Hide resolved
pub fn new_topmost(
inner: E,
error_report: impl EvalErrorReport,
) -> NonStrictExpression<impl Expression> {
let inner = wrapper::non_strict::NonStrict::new(inner, error_report);
NonStrictExpression(inner)
}

/// Get the return data type.
pub fn return_type(&self) -> DataType {
self.0.return_type()
}

/// Evaluate the expression in vectorized execution and assert it succeeds. Returns an array.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_infallible(&self, input: &DataChunk) -> ArrayRef {
self.eval(input).await.expect("evaluation failed")
self.0.eval(input).await.expect("evaluation failed")
stdrc marked this conversation as resolved.
Show resolved Hide resolved
}

/// Evaluate the expression in row-based execution and assert it succeeds. Returns a nullable
/// scalar.
///
/// Use with expressions built in non-strict mode.
pub async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum {
self.eval_row(input).await.expect("evaluation failed")
}
}

/// An owned dynamically typed [`Expression`].
pub type BoxedExpression = Box<dyn Expression>;

// TODO: avoid the overhead of extra boxing.
#[async_trait::async_trait]
impl Expression for BoxedExpression {
fn return_type(&self) -> DataType {
(**self).return_type()
self.0.eval_row(input).await.expect("evaluation failed")
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
(**self).eval(input).await
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
(**self).eval_v2(input).await
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
(**self).eval_row(input).await
}

fn eval_const(&self) -> Result<Datum> {
(**self).eval_const()
/// Unwrap the inner expression.
pub fn into_inner(self) -> E {
self.0
}

fn boxed(self) -> BoxedExpression {
self
/// Get a reference to the inner expression.
pub fn inner(&self) -> &E {
&self.0
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/expr/core/src/expr/wrapper/checked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::expr::{Expression, ValueImpl};

/// A wrapper of [`Expression`] that does extra checks after evaluation.
#[derive(Debug)]
pub struct Checked<E>(pub E);
pub(crate) struct Checked<E>(pub E);

// TODO: avoid the overhead of extra boxing.
#[async_trait]
Expand Down
7 changes: 3 additions & 4 deletions src/expr/core/src/expr/wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod checked;
mod non_strict;
pub(crate) mod checked;
pub(crate) mod non_strict;

pub use checked::Checked;
pub use non_strict::{EvalErrorReport, NonStrict};
pub use non_strict::{EvalErrorReport, LogReport};
14 changes: 12 additions & 2 deletions src/expr/core/src/expr/wrapper/non_strict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::expr::{Expression, ValueImpl};
use crate::ExprError;

/// Report an error during evaluation.
#[auto_impl(Arc)]
#[auto_impl(&, Arc)]
pub trait EvalErrorReport: Clone + Send + Sync {
/// Perform the error reporting.
///
Expand All @@ -42,11 +42,21 @@ impl EvalErrorReport for ! {
}
}

/// Log the error to report an error during evaluation.
#[derive(Clone)]
pub struct LogReport;

impl EvalErrorReport for LogReport {
fn report(&self, error: ExprError) {
tracing::error!(%error, "failed to evaluate expression");
}
}

/// A wrapper of [`Expression`] that evaluates in a non-strict way. Basically...
/// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad
/// with NULL for each failed row.
/// - Report all error occurred during row-level evaluation to the [`EvalErrorReport`].
pub struct NonStrict<E, R> {
pub(crate) struct NonStrict<E, R> {
inner: E,
report: R,
}
Expand Down
4 changes: 3 additions & 1 deletion src/expr/impl/src/scalar/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use risingwave_common::cast;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Int256, IntoOrdered, JsonbRef, ToText, F64};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::{build_func, Context, Expression, InputRefExpression};
use risingwave_expr::expr::{
build_func, Context, Expression, ExpressionBoxExt, InputRefExpression,
};
use risingwave_expr::{function, ExprError, Result};
use risingwave_pb::expr::expr_node::PbType;

Expand Down
2 changes: 1 addition & 1 deletion src/expr/impl/src/table_function/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ mod tests {
use risingwave_common::array::DataChunk;
use risingwave_common::types::test_utils::IntervalTestExt;
use risingwave_common::types::{DataType, Decimal, Interval, ScalarImpl, Timestamp};
use risingwave_expr::expr::{BoxedExpression, Expression, LiteralExpression};
use risingwave_expr::expr::{BoxedExpression, ExpressionBoxExt, LiteralExpression};
use risingwave_expr::table_function::build;
use risingwave_expr::ExprError;
use risingwave_pb::expr::table_function::PbType;
Expand Down
3 changes: 2 additions & 1 deletion src/storage/src/row_serde/value_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ impl ValueRowSerdeNew for ColumnAwareSerde {
// It's okay since we previously banned impure expressions in default columns.
build_from_prost(&expr.expect("expr should not be none"))
.expect("build_from_prost error")
.eval_row_infallible(&OwnedRow::empty())
.eval_row(&OwnedRow::empty())
.now_or_never()
.expect("constant expression should not be async")
.expect("eval_row failed")
stdrc marked this conversation as resolved.
Show resolved Hide resolved
};
Some((i, value))
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/stream/clippy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ disallowed-methods = [

{ path = "risingwave_expr::expr::build_from_prost", reason = "Expressions in streaming must be in non-strict mode. Please use `build_non_strict_from_prost` instead." },
{ path = "risingwave_expr::expr::build_func", reason = "Expressions in streaming must be in non-strict mode. Please use `build_func_non_strict` instead." },
{ path = "risingwave_expr::expr::Expression::eval", reason = "Please use `Expression::eval_infallible` instead." },
{ path = "risingwave_expr::expr::Expression::eval_row", reason = "Please use `Expression::eval_row_infallible` instead." },
{ path = "risingwave_expr::expr::Expression::eval", reason = "Please use `NonStrictExpression::eval_infallible` instead." },
{ path = "risingwave_expr::expr::Expression::eval_row", reason = "Please use `NonStrictExpression::eval_row_infallible` instead." },

{ path = "risingwave_common::error::internal_err", reason = "Please use per-crate error type instead." },
{ path = "risingwave_common::error::internal_error", reason = "Please use per-crate error type instead." },
Expand Down
8 changes: 7 additions & 1 deletion src/stream/src/executor/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use risingwave_common::bail;
use risingwave_common::buffer::Bitmap;
use risingwave_common::catalog::{Field, Schema};
use risingwave_expr::aggregate::{AggCall, AggKind};
use risingwave_expr::expr::{LogReport, NonStrictExpression};
use risingwave_storage::StateStore;

use crate::common::table::state_table::StateTable;
Expand Down Expand Up @@ -74,7 +75,12 @@ pub async fn agg_call_filter_res(
}

if let Some(ref filter) = agg_call.filter {
if let Bool(filter_res) = filter.eval_infallible(chunk).await.as_ref() {
// TODO: should we build `filter` in non-strict mode?
if let Bool(filter_res) = NonStrictExpression::new_topmost(&**filter, LogReport)
.eval_infallible(chunk)
.await
.as_ref()
{
vis &= filter_res.to_bitmap();
} else {
bail!("Filter can only receive bool array");
Expand Down
6 changes: 3 additions & 3 deletions src/stream/src/executor/dynamic_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use risingwave_common::row::{self, once, OwnedRow, OwnedRow as RowData, Row};
use risingwave_common::types::{DataType, Datum, DefaultOrd, ScalarImpl, ToDatumRef, ToOwnedDatum};
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_expr::expr::{
build_func_non_strict, BoxedExpression, InputRefExpression, LiteralExpression,
build_func_non_strict, InputRefExpression, LiteralExpression, NonStrictExpression,
};
use risingwave_pb::expr::expr_node::Type as ExprNodeType;
use risingwave_pb::expr::expr_node::Type::{
Expand Down Expand Up @@ -97,7 +97,7 @@ impl<S: StateStore, const USE_WATERMARK_CACHE: bool> DynamicFilterExecutor<S, US
async fn apply_batch(
&mut self,
chunk: &StreamChunk,
condition: Option<BoxedExpression>,
condition: Option<NonStrictExpression>,
) -> Result<(Vec<Op>, Bitmap), StreamExecutorError> {
let mut new_ops = Vec::with_capacity(chunk.capacity());
let mut new_visibility = BitmapBuilder::with_capacity(chunk.capacity());
Expand Down Expand Up @@ -265,7 +265,7 @@ impl<S: StateStore, const USE_WATERMARK_CACHE: bool> DynamicFilterExecutor<S, US
let dynamic_cond = {
let eval_error_report = ActorEvalErrorReport {
actor_context: self.ctx.clone(),
identity: self.identity.as_str().into(),
identity: Arc::from(self.identity.as_str()),
};
move |literal: Datum| {
literal.map(|scalar| {
Expand Down
8 changes: 4 additions & 4 deletions src/stream/src/executor/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use risingwave_common::array::{Array, ArrayImpl, Op, StreamChunk};
use risingwave_common::buffer::BitmapBuilder;
use risingwave_common::catalog::Schema;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_expr::expr::BoxedExpression;
use risingwave_expr::expr::NonStrictExpression;

use super::*;

Expand All @@ -34,14 +34,14 @@ pub struct FilterExecutor {

/// Expression of the current filter, note that the filter must always have the same output for
/// the same input.
expr: BoxedExpression,
expr: NonStrictExpression,
}

impl FilterExecutor {
pub fn new(
ctx: ActorContextRef,
input: Box<dyn Executor>,
expr: BoxedExpression,
expr: NonStrictExpression,
executor_id: u64,
) -> Self {
let input_info = input.info();
Expand Down Expand Up @@ -190,8 +190,8 @@ mod tests {
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_expr::expr::build_from_pretty;

use super::super::test_utils::expr::build_from_pretty;
use super::super::test_utils::MockSource;
use super::super::*;
use super::*;
Expand Down
Loading
Loading