From 5596207f4d28e2f4b1ebbff5128f8b7ce6d23cca Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Sat, 9 Dec 2023 18:43:20 +0800 Subject: [PATCH] fix(expr): fix a critical bug in case expression (#13890) Signed-off-by: Runji Wang --- src/expr/core/src/expr/build.rs | 2 - src/expr/core/src/expr/mod.rs | 1 - .../expr_case.rs => impl/src/scalar/case.rs} | 119 +++++++----------- src/expr/impl/src/scalar/mod.rs | 1 + 4 files changed, 49 insertions(+), 74 deletions(-) rename src/expr/{core/src/expr/expr_case.rs => impl/src/scalar/case.rs} (64%) diff --git a/src/expr/core/src/expr/build.rs b/src/expr/core/src/expr/build.rs index f0fd3397c4fa..46c672d6da52 100644 --- a/src/expr/core/src/expr/build.rs +++ b/src/expr/core/src/expr/build.rs @@ -20,7 +20,6 @@ use risingwave_pb::expr::expr_node::{PbType, RexNode}; use risingwave_pb::expr::ExprNode; use super::expr_array_transform::ArrayTransformExpression; -use super::expr_case::CaseExpression; use super::expr_coalesce::CoalesceExpression; use super::expr_field::FieldExpression; use super::expr_in::InExpression; @@ -114,7 +113,6 @@ where // Dedicated types E::All | E::Some => SomeAllExpression::build_boxed(prost, build_child), E::In => InExpression::build_boxed(prost, build_child), - E::Case => CaseExpression::build_boxed(prost, build_child), E::Coalesce => CoalesceExpression::build_boxed(prost, build_child), E::Field => FieldExpression::build_boxed(prost, build_child), E::Vnode => VnodeExpression::build_boxed(prost, build_child), diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index 48a46f640bf7..efbba9e66846 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -34,7 +34,6 @@ // These modules define concrete expression structures. mod and_or; mod expr_array_transform; -mod expr_case; mod expr_coalesce; mod expr_field; mod expr_in; diff --git a/src/expr/core/src/expr/expr_case.rs b/src/expr/impl/src/scalar/case.rs similarity index 64% rename from src/expr/core/src/expr/expr_case.rs rename to src/expr/impl/src/scalar/case.rs index 49f11298d3e0..f7b1dc677d04 100644 --- a/src/expr/core/src/expr/expr_case.rs +++ b/src/expr/impl/src/scalar/case.rs @@ -15,31 +15,27 @@ use std::sync::Arc; use risingwave_common::array::{ArrayRef, DataChunk}; +use risingwave_common::bail; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; -use risingwave_common::{bail, ensure}; -use risingwave_pb::expr::expr_node::{PbType, RexNode}; -use risingwave_pb::expr::ExprNode; - -use super::Build; -use crate::expr::{BoxedExpression, Expression}; -use crate::Result; +use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::{build_function, Result}; #[derive(Debug)] -pub struct WhenClause { +struct WhenClause { when: BoxedExpression, then: BoxedExpression, } #[derive(Debug)] -pub struct CaseExpression { +struct CaseExpression { return_type: DataType, when_clauses: Vec, else_clause: Option, } impl CaseExpression { - pub fn new( + fn new( return_type: DataType, when_clauses: Vec, else_clause: Option, @@ -65,8 +61,10 @@ impl Expression for CaseExpression { let when_len = self.when_clauses.len(); let mut result_array = Vec::with_capacity(when_len + 1); for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() { - let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap(); let input_vis = input.visibility().clone(); + // note that evaluated result from when clause may contain bits that are not visible, + // so we need to mask it with input visibility. + let calc_then_vis = when.eval(&input).await?.as_bool().to_bitmap() & &input_vis; input.set_visibility(calc_then_vis.clone()); let then_res = then.eval(&input).await?; calc_then_vis @@ -108,49 +106,37 @@ impl Expression for CaseExpression { } } -impl Build for CaseExpression { - fn build( - prost: &ExprNode, - build_child: impl Fn(&ExprNode) -> Result, - ) -> Result { - ensure!(prost.get_function_type().unwrap() == PbType::Case); - - let ret_type = DataType::from(prost.get_return_type().unwrap()); - let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else { - bail!("Expected RexNode::FuncCall"); - }; - let children = &func_call_node.children; - // children: (when, then)+, (else_clause)? - let len = children.len(); - let else_clause = if len % 2 == 1 { - let else_clause = build_child(&children[len - 1])?; - if else_clause.return_type() != ret_type { - bail!("Type mismatched between else and case."); - } - Some(else_clause) - } else { - None - }; - let mut when_clauses = vec![]; - for i in 0..len / 2 { - let when_index = i * 2; - let then_index = i * 2 + 1; - let when_expr = build_child(&children[when_index])?; - let then_expr = build_child(&children[then_index])?; - if when_expr.return_type() != DataType::Boolean { - bail!("Type mismatched between when clause and condition"); - } - if then_expr.return_type() != ret_type { - bail!("Type mismatched between then clause and case"); - } - let when_clause = WhenClause { - when: when_expr, - then: then_expr, - }; - when_clauses.push(when_clause); +#[build_function("case(...) -> any", type_infer = "panic")] +fn build_case_expr( + return_type: DataType, + children: Vec, +) -> Result { + // children: (when, then)+, (else_clause)? + let len = children.len(); + let mut when_clauses = Vec::with_capacity(len / 2); + let mut iter = children.into_iter().array_chunks(); + for [when, then] in iter.by_ref() { + if when.return_type() != DataType::Boolean { + bail!("Type mismatched between when clause and condition"); } - Ok(CaseExpression::new(ret_type, when_clauses, else_clause)) + if then.return_type() != return_type { + bail!("Type mismatched between then clause and case"); + } + when_clauses.push(WhenClause { when, then }); } + let else_clause = if let Some(else_clause) = iter.into_remainder().unwrap().next() { + if else_clause.return_type() != return_type { + bail!("Type mismatched between else and case."); + } + Some(else_clause) + } else { + None + }; + Ok(Box::new(CaseExpression::new( + return_type, + when_clauses, + else_clause, + ))) } #[cfg(test)] @@ -159,19 +145,14 @@ mod tests { use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::ToOwnedDatum; use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; use super::*; - use crate::expr::build_from_pretty; #[tokio::test] async fn test_eval_searched_case() { // when x then 1 else 2 - let when_clauses = vec![WhenClause { - when: build_from_pretty("$0:boolean"), - then: build_from_pretty("1:int4"), - }]; - let els = build_from_pretty("2:int4"); - let case = CaseExpression::new(DataType::Int32, when_clauses, Some(els)); + let case = build_from_pretty("(case:int4 $0:boolean 1:int4 2:int4)"); let (input, expected) = DataChunk::from_pretty( "B i t 1 @@ -195,20 +176,16 @@ mod tests { #[tokio::test] async fn test_eval_without_else() { - // when x then 1 - let when_clauses = vec![WhenClause { - when: build_from_pretty("$0:boolean"), - then: build_from_pretty("1:int4"), - }]; - let case = CaseExpression::new(DataType::Int32, when_clauses, None); + // when x then 1 when y then 2 + let case = build_from_pretty("(case:int4 $0:boolean 1:int4 $1:boolean 2:int4)"); let (input, expected) = DataChunk::from_pretty( - "B i - t 1 - f . - t 1 - f .", + "B B i + f f . + f t 2 + t f 1 + t t 1", ) - .split_column_at(1); + .split_column_at(2); // test eval let output = case.eval(&input).await.unwrap(); diff --git a/src/expr/impl/src/scalar/mod.rs b/src/expr/impl/src/scalar/mod.rs index 21620d5a6a4e..900e38cdd9ce 100644 --- a/src/expr/impl/src/scalar/mod.rs +++ b/src/expr/impl/src/scalar/mod.rs @@ -30,6 +30,7 @@ mod array_to_string; mod ascii; mod bitwise_op; mod cardinality; +mod case; mod cast; mod cmp; mod concat_op;