Skip to content

Commit

Permalink
fix(expr): fix a critical bug in case expression (#13890)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Dec 9, 2023
1 parent 8e6818e commit 5596207
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 74 deletions.
2 changes: 0 additions & 2 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::expr_array_transform::ArrayTransformExpression;
use super::expr_case::CaseExpression;
use super::expr_coalesce::CoalesceExpression;
use super::expr_field::FieldExpression;
use super::expr_in::InExpression;
Expand Down Expand Up @@ -114,7 +113,6 @@ where
// Dedicated types
E::All | E::Some => SomeAllExpression::build_boxed(prost, build_child),
E::In => InExpression::build_boxed(prost, build_child),
E::Case => CaseExpression::build_boxed(prost, build_child),
E::Coalesce => CoalesceExpression::build_boxed(prost, build_child),
E::Field => FieldExpression::build_boxed(prost, build_child),
E::Vnode => VnodeExpression::build_boxed(prost, build_child),
Expand Down
1 change: 0 additions & 1 deletion src/expr/core/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
// These modules define concrete expression structures.
mod and_or;
mod expr_array_transform;
mod expr_case;
mod expr_coalesce;
mod expr_field;
mod expr_in;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,27 @@
use std::sync::Arc;

use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::bail;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_common::{bail, ensure};
use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::Build;
use crate::expr::{BoxedExpression, Expression};
use crate::Result;
use risingwave_expr::expr::{BoxedExpression, Expression};
use risingwave_expr::{build_function, Result};

/// One `WHEN <condition> THEN <result>` arm of a `CASE` expression.
#[derive(Debug)]
pub struct WhenClause {
struct WhenClause {
/// Condition of the arm; the builder rejects it unless its return type is `DataType::Boolean`.
when: BoxedExpression,
/// Result of the arm; the builder rejects it unless its return type equals the case's return type.
then: BoxedExpression,
}

/// Expression node for `CASE WHEN .. THEN .. [ELSE ..] END`.
///
/// Holds the `(when, then)` arms in the order they were given, plus an optional `else` branch.
#[derive(Debug)]
pub struct CaseExpression {
struct CaseExpression {
/// Data type of the whole expression; every `then` branch and the `else` branch are checked
/// against it when the expression is built.
return_type: DataType,
/// The `(when, then)` arms, in the order they appear in the children list.
when_clauses: Vec<WhenClause>,
/// The `else` branch; `None` when the children list has no trailing unpaired element.
else_clause: Option<BoxedExpression>,
}

impl CaseExpression {
pub fn new(
fn new(
return_type: DataType,
when_clauses: Vec<WhenClause>,
else_clause: Option<BoxedExpression>,
Expand All @@ -65,8 +61,10 @@ impl Expression for CaseExpression {
let when_len = self.when_clauses.len();
let mut result_array = Vec::with_capacity(when_len + 1);
for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() {
let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap();
let input_vis = input.visibility().clone();
// Note that the bitmap evaluated from the `when` clause may have bits set for rows that
// are not visible in the input, so we must mask it with the input visibility.
let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap() & &input_vis;
input.set_visibility(calc_then_vis.clone());
let then_res = then.eval(&input).await?;
calc_then_vis
Expand Down Expand Up @@ -108,49 +106,37 @@ impl Expression for CaseExpression {
}
}

// NOTE(review): this span interleaves two versions of the builder without +/- markers —
// presumably the removed `impl Build for CaseExpression` and the added `build_case_expr`;
// confirm against the commit before reading it as a single piece of code.

/// Builds a `CaseExpression` from a protobuf `ExprNode`, using `build_child` to construct
/// each child expression.
///
/// # Errors
///
/// Fails if the node is not a `Case` function call, if a `when` child is not boolean, or if
/// a `then`/`else` child's return type differs from the node's return type.
impl Build for CaseExpression {
fn build(
prost: &ExprNode,
build_child: impl Fn(&ExprNode) -> Result<BoxedExpression>,
) -> Result<Self> {
ensure!(prost.get_function_type().unwrap() == PbType::Case);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
bail!("Expected RexNode::FuncCall");
};
let children = &func_call_node.children;
// children: (when, then)+, (else_clause)?
let len = children.len();
// An odd number of children means the last one is the else branch.
let else_clause = if len % 2 == 1 {
let else_clause = build_child(&children[len - 1])?;
if else_clause.return_type() != ret_type {
bail!("Type mismatched between else and case.");
}
Some(else_clause)
} else {
None
};
let mut when_clauses = vec![];
// The remaining children come in consecutive (when, then) pairs.
for i in 0..len / 2 {
let when_index = i * 2;
let then_index = i * 2 + 1;
let when_expr = build_child(&children[when_index])?;
let then_expr = build_child(&children[then_index])?;
if when_expr.return_type() != DataType::Boolean {
bail!("Type mismatched between when clause and condition");
}
if then_expr.return_type() != ret_type {
bail!("Type mismatched between then clause and case");
}
let when_clause = WhenClause {
when: when_expr,
then: then_expr,
};
when_clauses.push(when_clause);
/// Builds a boxed `CaseExpression` from already-built `children`, laid out as
/// `(when, then)+, (else_clause)?`.
///
/// # Errors
///
/// Fails if a `when` child's return type is not `DataType::Boolean`, or if a `then`/`else`
/// child's return type differs from `return_type`.
#[build_function("case(...) -> any", type_infer = "panic")]
fn build_case_expr(
return_type: DataType,
children: Vec<BoxedExpression>,
) -> Result<BoxedExpression> {
// children: (when, then)+, (else_clause)?
let len = children.len();
let mut when_clauses = Vec::with_capacity(len / 2);
// Consume the children two at a time; `by_ref` keeps `iter` alive so the unpaired
// leftover (if any) can be read after the loop.
let mut iter = children.into_iter().array_chunks();
for [when, then] in iter.by_ref() {
if when.return_type() != DataType::Boolean {
bail!("Type mismatched between when clause and condition");
}
// NOTE(review): this `Ok(..)` line uses `ret_type`, which only the `Build::build` body
// above defines — it appears to be that version's return, displaced here by the diff
// rendering; confirm.
Ok(CaseExpression::new(ret_type, when_clauses, else_clause))
if then.return_type() != return_type {
bail!("Type mismatched between then clause and case");
}
when_clauses.push(WhenClause { when, then });
}
// A trailing unpaired child, if present, is the else branch.
// NOTE(review): `into_remainder()` is unwrapped — presumably always `Some` once the chunk
// iterator has been driven to completion; confirm against the `ArrayChunks` docs.
let else_clause = if let Some(else_clause) = iter.into_remainder().unwrap().next() {
if else_clause.return_type() != return_type {
bail!("Type mismatched between else and case.");
}
Some(else_clause)
} else {
None
};
Ok(Box::new(CaseExpression::new(
return_type,
when_clauses,
else_clause,
)))
}

#[cfg(test)]
Expand All @@ -159,19 +145,14 @@ mod tests {
use risingwave_common::test_prelude::DataChunkTestExt;
use risingwave_common::types::ToOwnedDatum;
use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_expr::expr::build_from_pretty;

use super::*;
use crate::expr::build_from_pretty;

#[tokio::test]
async fn test_eval_searched_case() {
// when x then 1 else 2
let when_clauses = vec![WhenClause {
when: build_from_pretty("$0:boolean"),
then: build_from_pretty("1:int4"),
}];
let els = build_from_pretty("2:int4");
let case = CaseExpression::new(DataType::Int32, when_clauses, Some(els));
let case = build_from_pretty("(case:int4 $0:boolean 1:int4 2:int4)");
let (input, expected) = DataChunk::from_pretty(
"B i
t 1
Expand All @@ -195,20 +176,16 @@ mod tests {

#[tokio::test]
async fn test_eval_without_else() {
// when x then 1
let when_clauses = vec![WhenClause {
when: build_from_pretty("$0:boolean"),
then: build_from_pretty("1:int4"),
}];
let case = CaseExpression::new(DataType::Int32, when_clauses, None);
// when x then 1 when y then 2
let case = build_from_pretty("(case:int4 $0:boolean 1:int4 $1:boolean 2:int4)");
let (input, expected) = DataChunk::from_pretty(
"B i
t 1
f .
t 1
f .",
"B B i
f f .
f t 2
t f 1
t t 1",
)
.split_column_at(1);
.split_column_at(2);

// test eval
let output = case.eval(&input).await.unwrap();
Expand Down
1 change: 1 addition & 0 deletions src/expr/impl/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ mod array_to_string;
mod ascii;
mod bitwise_op;
mod cardinality;
mod case;
mod cast;
mod cmp;
mod concat_op;
Expand Down

0 comments on commit 5596207

Please sign in to comment.