Skip to content

Commit

Permalink
fix: cherry-pick #14147 to release-1.5 (#14185)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: Huangjw <[email protected]>
  • Loading branch information
wangrunji0408 and huangjw806 authored Dec 25, 2023
1 parent af650d7 commit b0b1ab6
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 5 deletions.
15 changes: 11 additions & 4 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use super::expr_in::InExpression;
use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use super::non_strict::NonStrictNoFallback;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
Expand Down Expand Up @@ -80,11 +81,15 @@ where

/// Attach wrappers to an expression.
#[expect(clippy::let_and_return)]
fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression {
fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression {
let checked = Checked(expr);

let may_non_strict = if let Some(error_report) = &self.error_report {
NonStrict::new(checked, error_report.clone()).boxed()
if no_fallback {
NonStrictNoFallback::new(checked, error_report.clone()).boxed()
} else {
NonStrict::new(checked, error_report.clone()).boxed()
}
} else {
checked.boxed()
};
Expand All @@ -95,7 +100,9 @@ where
/// Build an expression with `build_inner` and attach some wrappers.
fn build(&self, prost: &ExprNode) -> Result<BoxedExpression> {
let expr = self.build_inner(prost)?;
Ok(self.wrap(expr))
// no fallback to row-based evaluation for UDF
let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_));
Ok(self.wrap(expr, no_fallback))
}

/// Build an expression from protobuf.
Expand Down Expand Up @@ -224,7 +231,7 @@ pub fn build_func_non_strict(
error_report: impl EvalErrorReport + 'static,
) -> Result<NonStrictExpression> {
let expr = build_func(func, ret_type, children)?;
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr));
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false));

Ok(wrapped)
}
Expand Down
82 changes: 81 additions & 1 deletion src/expr/core/src/expr/wrapper/non_strict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use async_trait::async_trait;
use auto_impl::auto_impl;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::types::{DataType, Datum};
use risingwave_common::types::{DataType, Datum, ScalarImpl};

use crate::error::Result;
use crate::expr::{Expression, ValueImpl};
Expand Down Expand Up @@ -141,3 +141,83 @@ where
self.inner.eval_const() // do not handle error
}
}

/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs.
pub(crate) struct NonStrictNoFallback<E, R> {
inner: E,
report: R,
}

impl<E, R> std::fmt::Debug for NonStrictNoFallback<E, R>
where
E: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NonStrictNoFallback")
.field("inner", &self.inner)
.field("report", &std::any::type_name::<R>())
.finish()
}
}

impl<E, R> NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
pub fn new(inner: E, report: R) -> Self {
Self { inner, report }
}
}

// TODO: avoid the overhead of extra boxing.
#[async_trait]
impl<E, R> Expression for NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
fn return_type(&self) -> DataType {
self.inner.return_type()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
Ok(match self.inner.eval(input).await {
Ok(array) => array,
Err(error) => {
self.report.report(error);
// no fallback and return NULL for each row
let mut builder = self.return_type().create_array_builder(input.capacity());
builder.append_n(input.capacity(), Option::<ScalarImpl>::None);
builder.finish().into()
}
})
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
Ok(match self.inner.eval_v2(input).await {
Ok(value) => value,
Err(error) => {
self.report.report(error);
ValueImpl::Scalar {
value: None,
capacity: input.capacity(),
}
}
})
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
Ok(match self.inner.eval_row(input).await {
Ok(datum) => datum,
Err(error) => {
self.report.report(error);
None // NULL
}
})
}

fn eval_const(&self) -> Result<Datum> {
self.inner.eval_const() // do not handle error
}
}

0 comments on commit b0b1ab6

Please sign in to comment.