Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(expr): don't fallback to evaluation by row on error #14174

Merged
merged 3 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/expr/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::ArrayError;
use std::fmt::Display;

use risingwave_common::array::{ArrayError, ArrayRef};
use risingwave_common::error::{ErrorCode, RwError};
use risingwave_common::types::DataType;
use risingwave_pb::PbFieldNotFound;
use thiserror::Error;
use thiserror_ext::AsReport;

/// A specialized Result type for expression operations.
pub type Result<T> = std::result::Result<T, ExprError>;
pub type Result<T, E = ExprError> = std::result::Result<T, E>;

pub struct ContextUnavailable(&'static str);

Expand All @@ -39,6 +41,10 @@ impl From<ContextUnavailable> for ExprError {
/// The error type for expression operations.
#[derive(Error, Debug)]
pub enum ExprError {
/// A collection of multiple errors in batch evaluation.
#[error("multiple errors:\n{1}")]
Multiple(ArrayRef, MultiExprError),

// Ideally "Unsupported" errors are caught by frontend. But when the match arms between
// frontend and backend are inconsistent, we do not panic with `unreachable!`.
#[error("Unsupported function: {0}")]
Expand Down Expand Up @@ -135,3 +141,38 @@ impl From<PbFieldNotFound> for ExprError {
))
}
}

/// A collection of multiple errors.
#[derive(Error, Debug)]
pub struct MultiExprError(Box<[ExprError]>);

impl MultiExprError {
/// Returns the first error.
pub fn first(self) -> ExprError {
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
self.0.into_vec().into_iter().next().expect("first error")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.0.into_vec().into_iter().next().expect("first error")
self.0.into_iter().next().expect("first error")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If removed, the item would be &ExprError. Seems to be a historical problem rust-lang/rust#59878

}
}

impl Display for MultiExprError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (i, e) in self.0.iter().enumerate() {
writeln!(f, "{i}: {e}")?;
}
Ok(())
}
}

impl From<Vec<ExprError>> for MultiExprError {
fn from(v: Vec<ExprError>) -> Self {
Self(v.into_boxed_slice())
}
}

impl IntoIterator for MultiExprError {
type IntoIter = std::vec::IntoIter<ExprError>;
type Item = ExprError;

fn into_iter(self) -> Self::IntoIter {
self.0.into_vec().into_iter()
}
}
19 changes: 7 additions & 12 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_pb::expr::ExprNode;

use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::non_strict::NonStrictNoFallback;
use super::strict::Strict;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
use super::wrapper::EvalErrorReport;
Expand All @@ -34,7 +34,8 @@ use crate::{bail, ExprError, Result};

/// Build an expression from protobuf.
pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
ExprBuilder::new_strict().build(prost)
let expr = ExprBuilder::new_strict().build(prost)?;
Ok(Strict::new(expr).boxed())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this step into the ExprBuilder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wrapper is only needed at the top level. So I think no need to move it into the builder?

}

/// Build an expression from protobuf in non-strict mode.
Expand Down Expand Up @@ -76,15 +77,11 @@ where

/// Attach wrappers to an expression.
#[expect(clippy::let_and_return)]
fn wrap(&self, expr: impl Expression + 'static, no_fallback: bool) -> BoxedExpression {
fn wrap(&self, expr: impl Expression + 'static) -> BoxedExpression {
let checked = Checked(expr);

let may_non_strict = if let Some(error_report) = &self.error_report {
if no_fallback {
NonStrictNoFallback::new(checked, error_report.clone()).boxed()
} else {
NonStrict::new(checked, error_report.clone()).boxed()
}
NonStrict::new(checked, error_report.clone()).boxed()
} else {
checked.boxed()
};
Expand All @@ -95,9 +92,7 @@ where
/// Build an expression with `build_inner` and attach some wrappers.
fn build(&self, prost: &ExprNode) -> Result<BoxedExpression> {
let expr = self.build_inner(prost)?;
// no fallback to row-based evaluation for UDF
let no_fallback = matches!(prost.get_rex_node().unwrap(), RexNode::Udf(_));
Ok(self.wrap(expr, no_fallback))
Ok(self.wrap(expr))
}

/// Build an expression from protobuf.
Expand Down Expand Up @@ -216,7 +211,7 @@ pub fn build_func_non_strict(
error_report: impl EvalErrorReport + 'static,
) -> Result<NonStrictExpression> {
let expr = build_func(func, ret_type, children)?;
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr, false));
let wrapped = NonStrictExpression(ExprBuilder::new_non_strict(error_report).wrap(expr));

Ok(wrapped)
}
Expand Down
1 change: 1 addition & 0 deletions src/expr/core/src/expr/wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

pub(crate) mod checked;
pub(crate) mod non_strict;
pub(crate) mod strict;

pub use non_strict::{EvalErrorReport, LogReport};
120 changes: 20 additions & 100 deletions src/expr/core/src/expr/wrapper/non_strict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use async_trait::async_trait;
use auto_impl::auto_impl;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::log::LogSuppresser;
use risingwave_common::row::{OwnedRow, Row};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use thiserror_ext::AsReport;

Expand Down Expand Up @@ -60,8 +60,7 @@ impl EvalErrorReport for LogReport {
}

/// A wrapper of [`Expression`] that evaluates in a non-strict way. Basically...
/// - When an error occurs during chunk-level evaluation, recompute in row-based execution and pad
/// with NULL for each failed row.
/// - When an error occurs during chunk-level evaluation, pad with NULL for each failed row.
/// - Report all error occurred during row-level evaluation to the [`EvalErrorReport`].
pub(crate) struct NonStrict<E, R> {
inner: E,
Expand All @@ -88,31 +87,6 @@ where
pub fn new(inner: E, report: R) -> Self {
Self { inner, report }
}

/// Evaluate expression in row-based execution with `eval_row_infallible`.
async fn eval_chunk_infallible_by_row(&self, input: &DataChunk) -> ArrayRef {
let mut array_builder = self.return_type().create_array_builder(input.capacity());
for row in input.rows_with_holes() {
if let Some(row) = row {
let datum = self.eval_row_infallible(&row.into_owned_row()).await; // TODO: use `Row` trait
array_builder.append(&datum);
} else {
array_builder.append_null();
}
}
array_builder.finish().into()
}

/// Evaluate expression on a single row, report error and return NULL if failed.
async fn eval_row_infallible(&self, input: &OwnedRow) -> Datum {
match self.inner.eval_row(input).await {
Ok(datum) => datum,
Err(error) => {
self.report.report(error);
None // NULL
}
}
}
}

// TODO: avoid the overhead of extra boxing.
Expand All @@ -129,75 +103,14 @@ where
async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
Ok(match self.inner.eval(input).await {
Ok(array) => array,
Err(_e) => self.eval_chunk_infallible_by_row(input).await,
})
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
Ok(match self.inner.eval_v2(input).await {
Ok(value) => value,
Err(_e) => self.eval_chunk_infallible_by_row(input).await.into(),
})
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
Ok(self.eval_row_infallible(input).await)
}

fn eval_const(&self) -> Result<Datum> {
self.inner.eval_const() // do not handle error
}

fn input_ref_index(&self) -> Option<usize> {
self.inner.input_ref_index()
}
}

/// Similar to [`NonStrict`] wrapper, but does not fallback to row-based evaluation when an error occurs.
pub(crate) struct NonStrictNoFallback<E, R> {
inner: E,
report: R,
}

impl<E, R> std::fmt::Debug for NonStrictNoFallback<E, R>
where
E: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NonStrictNoFallback")
.field("inner", &self.inner)
.field("report", &std::any::type_name::<R>())
.finish()
}
}

impl<E, R> NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
pub fn new(inner: E, report: R) -> Self {
Self { inner, report }
}
}

// TODO: avoid the overhead of extra boxing.
#[async_trait]
impl<E, R> Expression for NonStrictNoFallback<E, R>
where
E: Expression,
R: EvalErrorReport,
{
fn return_type(&self) -> DataType {
self.inner.return_type()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
Ok(match self.inner.eval(input).await {
Ok(array) => array,
Err(error) => {
self.report.report(error);
// no fallback and return NULL for each row
Err(ExprError::Multiple(array, errors)) => {
for error in errors {
self.report.report(error);
}
array
}
Err(e) => {
self.report.report(e);
Comment on lines +106 to +113
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ensure that (almost) all Expressions return Multiple? Otherwise, all rows will be padded with NULL once there's an error.

Copy link
Contributor Author

@wangrunji0408 wangrunji0408 Dec 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I have checked that. The remaining expressions not generated by #[function] are not evaluated row by row, except for UDF, which I will fix later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks so fragile to me. 🥵

let mut builder = self.return_type().create_array_builder(input.capacity());
builder.append_n_null(input.capacity());
builder.finish().into()
Expand All @@ -207,9 +120,15 @@ where

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
Ok(match self.inner.eval_v2(input).await {
Ok(value) => value,
Err(error) => {
self.report.report(error);
Ok(array) => array,
Err(ExprError::Multiple(array, errors)) => {
for error in errors {
self.report.report(error);
}
array.into()
}
Err(e) => {
self.report.report(e);
ValueImpl::Scalar {
value: None,
capacity: input.capacity(),
Expand All @@ -218,6 +137,7 @@ where
})
}

/// Evaluate expression on a single row, report error and return NULL if failed.
async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
Ok(match self.inner.eval_row(input).await {
Ok(datum) => datum,
Expand Down
83 changes: 83 additions & 0 deletions src/expr/core/src/expr/wrapper/strict.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use async_trait::async_trait;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};

use crate::error::Result;
use crate::expr::{Expression, ValueImpl};
use crate::ExprError;

/// A wrapper of [`Expression`] that only keeps the first error if multiple errors are returned.
pub(crate) struct Strict<E> {
inner: E,
}

impl<E> std::fmt::Debug for Strict<E>
where
E: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Strict")
.field("inner", &self.inner)
.finish()
}
}

impl<E> Strict<E>
where
E: Expression,
{
pub fn new(inner: E) -> Self {
Self { inner }
}
}

#[async_trait]
impl<E> Expression for Strict<E>
where
E: Expression,
{
fn return_type(&self) -> DataType {
self.inner.return_type()
}

async fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
match self.inner.eval(input).await {
Err(ExprError::Multiple(_, errors)) => Err(errors.first()),
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
res => res,
}
}

async fn eval_v2(&self, input: &DataChunk) -> Result<ValueImpl> {
match self.inner.eval_v2(input).await {
Err(ExprError::Multiple(_, errors)) => Err(errors.first()),
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
res => res,
}
}

async fn eval_row(&self, input: &OwnedRow) -> Result<Datum> {
self.inner.eval_row(input).await
}

fn eval_const(&self) -> Result<Datum> {
self.inner.eval_const()
}

fn input_ref_index(&self) -> Option<usize> {
self.inner.input_ref_index()
}
}
4 changes: 2 additions & 2 deletions src/expr/impl/src/scalar/arithmetic_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ where
}

#[function("abs(decimal) -> decimal")]
pub fn decimal_abs(decimal: Decimal) -> Result<Decimal> {
Ok(Decimal::abs(&decimal))
pub fn decimal_abs(decimal: Decimal) -> Decimal {
Decimal::abs(&decimal)
}

fn err_pow_zero_negative() -> ExprError {
Expand Down
Loading
Loading