diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 8e4cb8439e7df..1e40022fbe170 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -21,6 +21,7 @@ use risingwave_pb::expr::ExprNode; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; +use super::non_strict::NonStrictNoFallback; use super::wrapper::checked::Checked; use super::wrapper::non_strict::NonStrict; use super::wrapper::EvalErrorReport; @@ -75,11 +76,15 @@ where /// Attach wrappers to an expression. #[expect(clippy::let_and_return)] - fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression { + fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression { let checked = Checked(expr); let may_non_strict = if let Some(error_report) = &self.error_report { - NonStrict::new(checked, error_report.clone()).boxed() + if no_fallback { + NonStrictNoFallback::new(checked, error_report.clone()).boxed() + } else { + NonStrict::new(checked, error_report.clone()).boxed() + } } else { checked.boxed() }; @@ -90,7 +95,9 @@ where /// Build an expression with `build_inner` and attach some wrappers. fn build(&self, prost: &ExprNode) -> Result { let expr = self.build_inner(prost)?; - Ok(self.wrap(expr)) + // no fallback to row-based evaluation for UDF + let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_)); + Ok(self.wrap(expr, no_fallback)) } /// Build an expression from protobuf. @@ -209,7 +216,7 @@ pub fn build_func_non_strict( error_report: impl EvalErrorReport + 'static, ) -> Result { let expr = build_func(func, ret_type, children)?; - let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); + let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false)); Ok(wrapped) } diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index fa261cfabe446..e1ed69a3e3598 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -152,3 +152,87 @@ where self.inner.input_ref_index() } } + +/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs. +pub(crate) struct NonStrictNoFallback { + inner: E, + report: R, +} + +impl std::fmt::Debug for NonStrictNoFallback +where + E: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NonStrictNoFallback") + .field("inner", &self.inner) + .field("report", &std::any::type_name::()) + .finish() + } +} + +impl NonStrictNoFallback +where + E: Expression, + R: EvalErrorReport, +{ + pub fn new(inner: E, report: R) -> Self { + Self { inner, report } + } +} + +// TODO: avoid the overhead of extra boxing. +#[async_trait] +impl Expression for NonStrictNoFallback +where + E: Expression, + R: EvalErrorReport, +{ + fn return_type(&self) -> DataType { + self.inner.return_type() + } + + async fn eval(&self, input: &DataChunk) -> Result { + Ok(match self.inner.eval(input).await { + Ok(array) => array, + Err(error) => { + self.report.report(error); + // no fallback and return NULL for each row + let mut builder = self.return_type().create_array_builder(input.capacity()); + builder.append_n_null(input.capacity()); + builder.finish().into() + } + }) + } + + async fn eval_v2(&self, input: &DataChunk) -> Result { + Ok(match self.inner.eval_v2(input).await { + Ok(value) => value, + Err(error) => { + self.report.report(error); + ValueImpl::Scalar { + value: None, + capacity: input.capacity(), + } + } + }) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + Ok(match self.inner.eval_row(input).await { + Ok(datum) => datum, + Err(error) => { + self.report.report(error); + None // NULL + } + }) + } + + fn eval_const(&self) -> Result { + self.inner.eval_const() // do not handle error + } + + fn input_ref_index(&self) -> Option { + self.inner.input_ref_index() + } +}