diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index 46c672d6da521..19854f47e0501 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -26,6 +26,7 @@ use super::expr_in::InExpression; use super::expr_some_all::SomeAllExpression; use super::expr_udf::UdfExpression; use super::expr_vnode::VnodeExpression; +use super::non_strict::NonStrictNoFallback; use super::wrapper::checked::Checked; use super::wrapper::non_strict::NonStrict; use super::wrapper::EvalErrorReport; @@ -80,11 +81,15 @@ where /// Attach wrappers to an expression. #[expect(clippy::let_and_return)] - fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression { + fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression { let checked = Checked(expr); let may_non_strict = if let Some(error_report) = &self.error_report { - NonStrict::new(checked, error_report.clone()).boxed() + if no_fallback { + NonStrictNoFallback::new(checked, error_report.clone()).boxed() + } else { + NonStrict::new(checked, error_report.clone()).boxed() + } } else { checked.boxed() }; @@ -95,7 +100,9 @@ where /// Build an expression with `build_inner` and attach some wrappers. fn build(&self, prost: &ExprNode) -> Result { let expr = self.build_inner(prost)?; - Ok(self.wrap(expr)) + // no fallback to row-based evaluation for UDF + let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_)); + Ok(self.wrap(expr, no_fallback)) } /// Build an expression from protobuf. @@ -224,7 +231,7 @@ pub fn build_func_non_strict( error_report: impl EvalErrorReport + 'static, ) -> Result { let expr = build_func(func, ret_type, children)?; - let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr)); + let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false)); Ok(wrapped) } diff --git a/src/expr/core/src/expr/wrapper/non_strict.rs b/src/expr/core/src/expr/wrapper/non_strict.rs index 782456023cdf7..c965b2a4bf5ce 100644 --- a/src/expr/core/src/expr/wrapper/non_strict.rs +++ b/src/expr/core/src/expr/wrapper/non_strict.rs @@ -16,7 +16,7 @@ use async_trait::async_trait; use auto_impl::auto_impl; use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::row::{OwnedRow, Row}; -use risingwave_common::types::{DataType, Datum}; +use risingwave_common::types::{DataType, Datum, ScalarImpl}; use crate::error::Result; use crate::expr::{Expression, ValueImpl}; @@ -141,3 +141,83 @@ where self.inner.eval_const() // do not handle error } } + +/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs. +pub(crate) struct NonStrictNoFallback { + inner: E, + report: R, +} + +impl std::fmt::Debug for NonStrictNoFallback +where + E: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NonStrictNoFallback") + .field("inner", &self.inner) + .field("report", &std::any::type_name::()) + .finish() + } +} + +impl NonStrictNoFallback +where + E: Expression, + R: EvalErrorReport, +{ + pub fn new(inner: E, report: R) -> Self { + Self { inner, report } + } +} + +// TODO: avoid the overhead of extra boxing. +#[async_trait] +impl Expression for NonStrictNoFallback +where + E: Expression, + R: EvalErrorReport, +{ + fn return_type(&self) -> DataType { + self.inner.return_type() + } + + async fn eval(&self, input: &DataChunk) -> Result { + Ok(match self.inner.eval(input).await { + Ok(array) => array, + Err(error) => { + self.report.report(error); + // no fallback and return NULL for each row + let mut builder = self.return_type().create_array_builder(input.capacity()); + builder.append_n(input.capacity(), Option::::None); + builder.finish().into() + } + }) + } + + async fn eval_v2(&self, input: &DataChunk) -> Result { + Ok(match self.inner.eval_v2(input).await { + Ok(value) => value, + Err(error) => { + self.report.report(error); + ValueImpl::Scalar { + value: None, + capacity: input.capacity(), + } + } + }) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + Ok(match self.inner.eval_row(input).await { + Ok(datum) => datum, + Err(error) => { + self.report.report(error); + None // NULL + } + }) + } + + fn eval_const(&self) -> Result { + self.inner.eval_const() // do not handle error + } +}