diff --git a/e2e_test/batch/basic/case_when_optimization.slt.part b/e2e_test/batch/basic/case_when_optimization.slt.part new file mode 100644 index 000000000000..7e01be030a91 --- /dev/null +++ b/e2e_test/batch/basic/case_when_optimization.slt.part @@ -0,0 +1,242 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +CREATE TABLE t1 (c1 INT, c2 INT, c3 INT); + +statement ok +INSERT INTO t1 VALUES (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4), (5, 5, 5), + (6, 6, 6), (7, 7, 7), (8, 8, 8), (9, 9, 9), (10, 10, 10), + (11, 11, 11), (12, 12, 12), (13, 13, 13), (14, 14, 14), (15, 15, 15), + (16, 16, 16), (17, 17, 17), (18, 18, 18), (19, 19, 19), (20, 20, 20), + (21, 21, 21), (22, 22, 22), (23, 23, 23), (24, 24, 24), (25, 25, 25), + (26, 26, 26), (27, 27, 27), (28, 28, 28), (29, 29, 29), (30, 30, 30), + (31, 31, 31), (32, 32, 32), (33, 33, 33), (34, 34, 34), (35, 35, 35), + (36, 36, 36), (37, 37, 37), (38, 38, 38), (39, 39, 39), (40, 40, 40), + (41, 41, 41), (42, 42, 42), (43, 43, 43), (44, 44, 44), (45, 45, 45), + (46, 46, 46), (47, 47, 47), (48, 48, 48), (49, 49, 49), (50, 50, 50), + (51, 51, 51), (52, 52, 52), (53, 53, 53), (54, 54, 54), (55, 55, 55), + (56, 56, 56), (57, 57, 57), (58, 58, 58), (59, 59, 59), (60, 60, 60), + (61, 61, 61), (62, 62, 62), (63, 63, 63), (64, 64, 64), (65, 65, 65), + (66, 66, 66), (67, 67, 67), (68, 68, 68), (69, 69, 69), (70, 70, 70), + (71, 71, 71), (72, 72, 72), (73, 73, 73), (74, 74, 74), (75, 75, 75), + (76, 76, 76), (77, 77, 77), (78, 78, 78), (79, 79, 79), (80, 80, 80), + (81, 81, 81), (82, 82, 82), (83, 83, 83), (84, 84, 84), (85, 85, 85), + (86, 86, 86), (87, 87, 87), (88, 88, 88), (89, 89, 89), (90, 90, 90), + (91, 91, 91), (92, 92, 92), (93, 93, 93), (94, 94, 94), (95, 95, 95), + (96, 96, 96), (97, 97, 97), (98, 98, 98), (99, 99, 99), (100, 100, 100); + + +# 101 arms case-when expression, with optimizable pattern +query I +SELECT + CASE c1 + WHEN 1 THEN 'one' + WHEN 2 THEN 'two' + WHEN 3 THEN 'three' + WHEN 4 THEN 'four' + WHEN 5 THEN 'five' + WHEN 6 THEN 'six' + WHEN 7 THEN 'seven' + WHEN 8 THEN 'eight' + WHEN 9 THEN 'nine' + WHEN 10 THEN 'ten' + WHEN 11 THEN 'eleven' + WHEN 12 THEN 'twelve' + WHEN 13 THEN 'thirteen' + WHEN 14 THEN 'fourteen' + WHEN 15 THEN 'fifteen' + WHEN 16 THEN 'sixteen' + WHEN 17 THEN 'seventeen' + WHEN 18 THEN 'eighteen' + WHEN 19 THEN 'nineteen' + WHEN 20 THEN 'twenty' + WHEN 21 THEN 'twenty-one' + WHEN 22 THEN 'twenty-two' + WHEN 23 THEN 'twenty-three' + WHEN 24 THEN 'twenty-four' + WHEN 25 THEN 'twenty-five' + WHEN 26 THEN 'twenty-six' + WHEN 27 THEN 'twenty-seven' + WHEN 28 THEN 'twenty-eight' + WHEN 29 THEN 'twenty-nine' + WHEN 30 THEN 'thirty' + WHEN 31 THEN 'thirty-one' + WHEN 32 THEN 'thirty-two' + WHEN 33 THEN 'thirty-three' + WHEN 34 THEN 'thirty-four' + WHEN 35 THEN 'thirty-five' + WHEN 36 THEN 'thirty-six' + WHEN 37 THEN 'thirty-seven' + WHEN 38 THEN 'thirty-eight' + WHEN 39 THEN 'thirty-nine' + WHEN 40 THEN 'forty' + WHEN 41 THEN 'forty-one' + WHEN 42 THEN 'forty-two' + WHEN 43 THEN 'forty-three' + WHEN 44 THEN 'forty-four' + WHEN 45 THEN 'forty-five' + WHEN 46 THEN 'forty-six' + WHEN 47 THEN 'forty-seven' + WHEN 48 THEN 'forty-eight' + WHEN 49 THEN 'forty-nine' + WHEN 50 THEN 'fifty' + WHEN 51 THEN 'fifty-one' + WHEN 52 THEN 'fifty-two' + WHEN 53 THEN 'fifty-three' + WHEN 54 THEN 'fifty-four' + WHEN 55 THEN 'fifty-five' + WHEN 56 THEN 'fifty-six' + WHEN 57 THEN 'fifty-seven' + WHEN 58 THEN 'fifty-eight' + WHEN 59 THEN 'fifty-nine' + WHEN 60 THEN 'sixty' + WHEN 61 THEN 'sixty-one' + WHEN 62 THEN 'sixty-two' + WHEN 63 THEN 'sixty-three' + WHEN 64 THEN 'sixty-four' + WHEN 65 THEN 'sixty-five' + WHEN 66 THEN 'sixty-six' + WHEN 67 THEN 'sixty-seven' + WHEN 68 THEN 'sixty-eight' + WHEN 69 THEN 'sixty-nine' + WHEN 70 THEN 'seventy' + WHEN 71 THEN 'seventy-one' + WHEN 72 THEN 'seventy-two' + WHEN 73 THEN 'seventy-three' + WHEN 74 THEN 'seventy-four' + WHEN 75 THEN 'seventy-five' + WHEN 76 THEN 'seventy-six' + WHEN 77 THEN 'seventy-seven' + WHEN 78 THEN 'seventy-eight' + WHEN 79 THEN 'seventy-nine' + WHEN 80 THEN 'eighty' + WHEN 81 THEN 'eighty-one' + WHEN 82 THEN 'eighty-two' + WHEN 83 THEN 'eighty-three' + WHEN 84 THEN 'eighty-four' + WHEN 85 THEN 'eighty-five' + WHEN 86 THEN 'eighty-six' + WHEN 87 THEN 'eighty-seven' + WHEN 88 THEN 'eighty-eight' + WHEN 89 THEN 'eighty-nine' + WHEN 90 THEN 'ninety' + WHEN 91 THEN 'ninety-one' + WHEN 92 THEN 'ninety-two' + WHEN 93 THEN 'ninety-three' + WHEN 94 THEN 'ninety-four' + WHEN 95 THEN 'ninety-five' + WHEN 96 THEN 'ninety-six' + WHEN 97 THEN 'ninety-seven' + WHEN 98 THEN 'ninety-eight' + WHEN 99 THEN 'ninety-nine' + WHEN 100 THEN 'one hundred' + ELSE + '114514' + END +FROM t1 +ORDER BY c1 ASC; +---- +one +two +three +four +five +six +seven +eight +nine +ten +eleven +twelve +thirteen +fourteen +fifteen +sixteen +seventeen +eighteen +nineteen +twenty +twenty-one +twenty-two +twenty-three +twenty-four +twenty-five +twenty-six +twenty-seven +twenty-eight +twenty-nine +thirty +thirty-one +thirty-two +thirty-three +thirty-four +thirty-five +thirty-six +thirty-seven +thirty-eight +thirty-nine +forty +forty-one +forty-two +forty-three +forty-four +forty-five +forty-six +forty-seven +forty-eight +forty-nine +fifty +fifty-one +fifty-two +fifty-three +fifty-four +fifty-five +fifty-six +fifty-seven +fifty-eight +fifty-nine +sixty +sixty-one +sixty-two +sixty-three +sixty-four +sixty-five +sixty-six +sixty-seven +sixty-eight +sixty-nine +seventy +seventy-one +seventy-two +seventy-three +seventy-four +seventy-five +seventy-six +seventy-seven +seventy-eight +seventy-nine +eighty +eighty-one +eighty-two +eighty-three +eighty-four +eighty-five +eighty-six +eighty-seven +eighty-eight +eighty-nine +ninety +ninety-one +ninety-two +ninety-three +ninety-four +ninety-five +ninety-six +ninety-seven +ninety-eight +ninety-nine +one hundred + +statement ok +drop table t1; \ No newline at end of file diff --git a/proto/expr.proto b/proto/expr.proto index f62ee2936d11..7ab48a405d13 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -82,6 +82,9 @@ message ExprNode { LTRIM = 210; RTRIM = 211; CASE = 212; + // Optimize case-when expression to constant lookup + // when arms are in a large scale with simple form + CONSTANT_LOOKUP = 624; // ROUND(numeric, integer) -> numeric ROUND_DIGIT = 213; // ROUND(numeric) -> numeric diff --git a/src/expr/impl/src/scalar/case.rs b/src/expr/impl/src/scalar/case.rs index d950cd60af55..64a68b987f86 100644 --- a/src/expr/impl/src/scalar/case.rs +++ b/src/expr/impl/src/scalar/case.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::sync::Arc; use risingwave_common::array::{ArrayRef, DataChunk}; use risingwave_common::bail; -use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, Datum}; +use risingwave_common::row::{OwnedRow, Row}; +use risingwave_common::types::{DataType, Datum, ScalarImpl}; use risingwave_expr::expr::{BoxedExpression, Expression}; use risingwave_expr::{build_function, Result}; @@ -106,6 +107,132 @@ impl Expression for CaseExpression { } } +/// With large scale of simple form match arms in case-when expression, +/// we could optimize the `CaseExpression` to `ConstantLookupExpression`, +/// which could significantly facilitate the evaluation of case-when. +#[derive(Debug)] +struct ConstantLookupExpression { + return_type: DataType, + arms: HashMap, + fallback: Option, + /// `operand` must exist at present + operand: BoxedExpression, +} + +impl ConstantLookupExpression { + fn new( + return_type: DataType, + arms: HashMap, + fallback: Option, + operand: BoxedExpression, + ) -> Self { + Self { + return_type, + arms, + fallback, + operand, + } + } +} + +#[async_trait::async_trait] +impl Expression for ConstantLookupExpression { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn eval(&self, input: &DataChunk) -> Result { + let input_len = input.capacity(); + let mut builder = self.return_type().create_array_builder(input_len); + + // Evaluate the input DataChunk at first + let eval_result = self.operand.eval(input).await?; + + for i in 0..input_len { + let datum = eval_result.datum_at(i); + let (row, vis) = input.row_at(i); + + // Check for visibility + if !vis { + builder.append_null(); + continue; + } + + // Note that the `owned_row` here is extracted from input + // rather than from `eval_result` + let owned_row = row.into_owned_row(); + + if let Some(expr) = self.arms.get(datum.as_ref().unwrap()) { + builder.append(expr.eval_row(&owned_row).await.unwrap().as_ref()); + } else { + // Otherwise this should goes to the fallback arm + // The fallback arm should also be const + if let Some(ref fallback) = self.fallback { + builder.append(fallback.eval_row(&owned_row).await.unwrap().as_ref()); + } else { + builder.append_null(); + } + } + } + + Ok(Arc::new(builder.finish())) + } + + async fn eval_row(&self, input: &OwnedRow) -> Result { + let datum = self.operand.eval_row(input).await?; + + if let Some(expr) = self.arms.get(datum.as_ref().unwrap()) { + expr.eval_row(input).await + } else { + let Some(ref expr) = self.fallback else { + return Ok(None); + }; + expr.eval_row(input).await + } + } +} + +#[build_function("constant_lookup(...) -> any", type_infer = "panic")] +fn build_constant_lookup_expr( + return_type: DataType, + children: Vec, +) -> Result { + if children.is_empty() { + bail!("children expression must not be empty for constant lookup expression"); + } + + let mut children = children; + + let operand = children.remove(0); + + let mut arms = HashMap::new(); + + // Build the `arms` with iterating over `when` & `then` clauses + let mut iter = children.into_iter().array_chunks(); + for [when, then] in iter.by_ref() { + let Ok(Some(s)) = when.eval_const() else { + bail!("expect when expression to be const"); + }; + arms.insert(s, then); + } + + let fallback = if let Some(else_clause) = iter.into_remainder().unwrap().next() { + if else_clause.return_type() != return_type { + bail!("Type mismatched between else and case."); + } + Some(else_clause) + } else { + None + }; + + Ok(Box::new(ConstantLookupExpression::new( + return_type, + arms, + fallback, + operand, + ))) +} + #[build_function("case(...) -> any", type_infer = "panic")] fn build_case_expr( return_type: DataType, @@ -132,6 +259,7 @@ fn build_case_expr( } else { None }; + Ok(Box::new(CaseExpression::new( return_type, when_clauses, diff --git a/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml b/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml new file mode 100644 index 000000000000..d37f3395e02d --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml @@ -0,0 +1,117 @@ +- id: create_table + sql: | + CREATE TABLE t1 (c1 INT, c2 INT, c3 INT); + expected_outputs: [] +- id: basic_optimization_pattern + before: + - create_table + sql: | + SELECT + CASE c1 + WHEN 1 THEN 'one' + WHEN 2 THEN 'two' + WHEN 3 THEN 'three' + WHEN 4 THEN 'four' + WHEN 5 THEN 'five' + WHEN 6 THEN 'six' + WHEN 7 THEN 'seven' + WHEN 8 THEN 'eight' + WHEN 9 THEN 'nine' + WHEN 10 THEN 'ten' + WHEN 11 THEN 'eleven' + WHEN 12 THEN 'twelve' + WHEN 13 THEN 'thirteen' + WHEN 14 THEN 'fourteen' + WHEN 15 THEN 'fifteen' + WHEN 16 THEN 'sixteen' + WHEN 17 THEN 'seventeen' + WHEN 18 THEN 'eighteen' + WHEN 19 THEN 'nineteen' + WHEN 20 THEN 'twenty' + WHEN 21 THEN 'twenty-one' + WHEN 22 THEN 'twenty-two' + WHEN 23 THEN 'twenty-three' + WHEN 24 THEN 'twenty-four' + WHEN 25 THEN 'twenty-five' + WHEN 26 THEN 'twenty-six' + WHEN 27 THEN 'twenty-seven' + WHEN 28 THEN 'twenty-eight' + WHEN 29 THEN 'twenty-nine' + WHEN 30 THEN 'thirty' + WHEN 31 THEN 'thirty-one' + WHEN 32 THEN 'thirty-two' + WHEN 33 THEN 'thirty-three' + WHEN 34 THEN 'thirty-four' + WHEN 35 THEN 'thirty-five' + WHEN 36 THEN 'thirty-six' + WHEN 37 THEN 'thirty-seven' + WHEN 38 THEN 'thirty-eight' + WHEN 39 THEN 'thirty-nine' + WHEN 40 THEN 'forty' + WHEN 41 THEN 'forty-one' + WHEN 42 THEN 'forty-two' + WHEN 43 THEN 'forty-three' + WHEN 44 THEN 'forty-four' + WHEN 45 THEN 'forty-five' + WHEN 46 THEN 'forty-six' + WHEN 47 THEN 'forty-seven' + WHEN 48 THEN 'forty-eight' + WHEN 49 THEN 'forty-nine' + WHEN 50 THEN 'fifty' + WHEN 51 THEN 'fifty-one' + WHEN 52 THEN 'fifty-two' + WHEN 53 THEN 'fifty-three' + WHEN 54 THEN 'fifty-four' + WHEN 55 THEN 'fifty-five' + WHEN 56 THEN 'fifty-six' + WHEN 57 THEN 'fifty-seven' + WHEN 58 THEN 'fifty-eight' + WHEN 59 THEN 'fifty-nine' + WHEN 60 THEN 'sixty' + WHEN 61 THEN 'sixty-one' + WHEN 62 THEN 'sixty-two' + WHEN 63 THEN 'sixty-three' + WHEN 64 THEN 'sixty-four' + WHEN 65 THEN 'sixty-five' + WHEN 66 THEN 'sixty-six' + WHEN 67 THEN 'sixty-seven' + WHEN 68 THEN 'sixty-eight' + WHEN 69 THEN 'sixty-nine' + WHEN 70 THEN 'seventy' + WHEN 71 THEN 'seventy-one' + WHEN 72 THEN 'seventy-two' + WHEN 73 THEN 'seventy-three' + WHEN 74 THEN 'seventy-four' + WHEN 75 THEN 'seventy-five' + WHEN 76 THEN 'seventy-six' + WHEN 77 THEN 'seventy-seven' + WHEN 78 THEN 'seventy-eight' + WHEN 79 THEN 'seventy-nine' + WHEN 80 THEN 'eighty' + WHEN 81 THEN 'eighty-one' + WHEN 82 THEN 'eighty-two' + WHEN 83 THEN 'eighty-three' + WHEN 84 THEN 'eighty-four' + WHEN 85 THEN 'eighty-five' + WHEN 86 THEN 'eighty-six' + WHEN 87 THEN 'eighty-seven' + WHEN 88 THEN 'eighty-eight' + WHEN 89 THEN 'eighty-nine' + WHEN 90 THEN 'ninety' + WHEN 91 THEN 'ninety-one' + WHEN 92 THEN 'ninety-two' + WHEN 93 THEN 'ninety-three' + WHEN 94 THEN 'ninety-four' + WHEN 95 THEN 'ninety-five' + WHEN 96 THEN 'ninety-six' + WHEN 97 THEN 'ninety-seven' + WHEN 98 THEN 'ninety-eight' + WHEN 99 THEN 'ninety-nine' + WHEN 100 THEN 'one hundred' + ELSE + '114514' + END + FROM t1; + expected_outputs: + - logical_plan + - batch_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml b/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml new file mode 100644 index 000000000000..82f0ee33aaeb --- /dev/null +++ b/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml @@ -0,0 +1,123 @@ +# This file is automatically generated. See `src/frontend/planner_test/README.md` for more information. +- id: create_table + sql: | + CREATE TABLE t1 (c1 INT, c2 INT, c3 INT); +- id: basic_optimization_pattern + before: + - create_table + sql: | + SELECT + CASE c1 + WHEN 1 THEN 'one' + WHEN 2 THEN 'two' + WHEN 3 THEN 'three' + WHEN 4 THEN 'four' + WHEN 5 THEN 'five' + WHEN 6 THEN 'six' + WHEN 7 THEN 'seven' + WHEN 8 THEN 'eight' + WHEN 9 THEN 'nine' + WHEN 10 THEN 'ten' + WHEN 11 THEN 'eleven' + WHEN 12 THEN 'twelve' + WHEN 13 THEN 'thirteen' + WHEN 14 THEN 'fourteen' + WHEN 15 THEN 'fifteen' + WHEN 16 THEN 'sixteen' + WHEN 17 THEN 'seventeen' + WHEN 18 THEN 'eighteen' + WHEN 19 THEN 'nineteen' + WHEN 20 THEN 'twenty' + WHEN 21 THEN 'twenty-one' + WHEN 22 THEN 'twenty-two' + WHEN 23 THEN 'twenty-three' + WHEN 24 THEN 'twenty-four' + WHEN 25 THEN 'twenty-five' + WHEN 26 THEN 'twenty-six' + WHEN 27 THEN 'twenty-seven' + WHEN 28 THEN 'twenty-eight' + WHEN 29 THEN 'twenty-nine' + WHEN 30 THEN 'thirty' + WHEN 31 THEN 'thirty-one' + WHEN 32 THEN 'thirty-two' + WHEN 33 THEN 'thirty-three' + WHEN 34 THEN 'thirty-four' + WHEN 35 THEN 'thirty-five' + WHEN 36 THEN 'thirty-six' + WHEN 37 THEN 'thirty-seven' + WHEN 38 THEN 'thirty-eight' + WHEN 39 THEN 'thirty-nine' + WHEN 40 THEN 'forty' + WHEN 41 THEN 'forty-one' + WHEN 42 THEN 'forty-two' + WHEN 43 THEN 'forty-three' + WHEN 44 THEN 'forty-four' + WHEN 45 THEN 'forty-five' + WHEN 46 THEN 'forty-six' + WHEN 47 THEN 'forty-seven' + WHEN 48 THEN 'forty-eight' + WHEN 49 THEN 'forty-nine' + WHEN 50 THEN 'fifty' + WHEN 51 THEN 'fifty-one' + WHEN 52 THEN 'fifty-two' + WHEN 53 THEN 'fifty-three' + WHEN 54 THEN 'fifty-four' + WHEN 55 THEN 'fifty-five' + WHEN 56 THEN 'fifty-six' + WHEN 57 THEN 'fifty-seven' + WHEN 58 THEN 'fifty-eight' + WHEN 59 THEN 'fifty-nine' + WHEN 60 THEN 'sixty' + WHEN 61 THEN 'sixty-one' + WHEN 62 THEN 'sixty-two' + WHEN 63 THEN 'sixty-three' + WHEN 64 THEN 'sixty-four' + WHEN 65 THEN 'sixty-five' + WHEN 66 THEN 'sixty-six' + WHEN 67 THEN 'sixty-seven' + WHEN 68 THEN 'sixty-eight' + WHEN 69 THEN 'sixty-nine' + WHEN 70 THEN 'seventy' + WHEN 71 THEN 'seventy-one' + WHEN 72 THEN 'seventy-two' + WHEN 73 THEN 'seventy-three' + WHEN 74 THEN 'seventy-four' + WHEN 75 THEN 'seventy-five' + WHEN 76 THEN 'seventy-six' + WHEN 77 THEN 'seventy-seven' + WHEN 78 THEN 'seventy-eight' + WHEN 79 THEN 'seventy-nine' + WHEN 80 THEN 'eighty' + WHEN 81 THEN 'eighty-one' + WHEN 82 THEN 'eighty-two' + WHEN 83 THEN 'eighty-three' + WHEN 84 THEN 'eighty-four' + WHEN 85 THEN 'eighty-five' + WHEN 86 THEN 'eighty-six' + WHEN 87 THEN 'eighty-seven' + WHEN 88 THEN 'eighty-eight' + WHEN 89 THEN 'eighty-nine' + WHEN 90 THEN 'ninety' + WHEN 91 THEN 'ninety-one' + WHEN 92 THEN 'ninety-two' + WHEN 93 THEN 'ninety-three' + WHEN 94 THEN 'ninety-four' + WHEN 95 THEN 'ninety-five' + WHEN 96 THEN 'ninety-six' + WHEN 97 THEN 'ninety-seven' + WHEN 98 THEN 'ninety-eight' + WHEN 99 THEN 'ninety-nine' + WHEN 100 THEN 'one hundred' + ELSE + '114514' + END + FROM t1; + logical_plan: |- + LogicalProject + ├─exprs: ConstantLookup(t1.c1, 1:Int32, 'one':Varchar, 2:Int32, 'two':Varchar, 3:Int32, 'three':Varchar, 4:Int32, 'four':Varchar, 5:Int32, 'five':Varchar, 6:Int32, 'six':Varchar, 7:Int32, 'seven':Varchar, 8:Int32, 'eight':Varchar, 9:Int32, 'nine':Varchar, 10:Int32, 'ten':Varchar, 11:Int32, 'eleven':Varchar, 12:Int32, 'twelve':Varchar, 13:Int32, 'thirteen':Varchar, 14:Int32, 'fourteen':Varchar, 15:Int32, 'fifteen':Varchar, 16:Int32, 'sixteen':Varchar, 17:Int32, 'seventeen':Varchar, 18:Int32, 'eighteen':Varchar, 19:Int32, 'nineteen':Varchar, 20:Int32, 'twenty':Varchar, 21:Int32, 'twenty-one':Varchar, 22:Int32, 'twenty-two':Varchar, 23:Int32, 'twenty-three':Varchar, 24:Int32, 'twenty-four':Varchar, 25:Int32, 'twenty-five':Varchar, 26:Int32, 'twenty-six':Varchar, 27:Int32, 'twenty-seven':Varchar, 28:Int32, 'twenty-eight':Varchar, 29:Int32, 'twenty-nine':Varchar, 30:Int32, 'thirty':Varchar, 31:Int32, 'thirty-one':Varchar, 32:Int32, 'thirty-two':Varchar, 33:Int32, 'thirty-three':Varchar, 34:Int32, 'thirty-four':Varchar, 35:Int32, 'thirty-five':Varchar, 36:Int32, 'thirty-six':Varchar, 37:Int32, 'thirty-seven':Varchar, 38:Int32, 'thirty-eight':Varchar, 39:Int32, 'thirty-nine':Varchar, 40:Int32, 'forty':Varchar, 41:Int32, 'forty-one':Varchar, 42:Int32, 'forty-two':Varchar, 43:Int32, 'forty-three':Varchar, 44:Int32, 'forty-four':Varchar, 45:Int32, 'forty-five':Varchar, 46:Int32, 'forty-six':Varchar, 47:Int32, 'forty-seven':Varchar, 48:Int32, 'forty-eight':Varchar, 49:Int32, 'forty-nine':Varchar, 50:Int32, 'fifty':Varchar, 51:Int32, 'fifty-one':Varchar, 52:Int32, 'fifty-two':Varchar, 53:Int32, 'fifty-three':Varchar, 54:Int32, 'fifty-four':Varchar, 55:Int32, 'fifty-five':Varchar, 56:Int32, 'fifty-six':Varchar, 57:Int32, 'fifty-seven':Varchar, 58:Int32, 'fifty-eight':Varchar, 59:Int32, 'fifty-nine':Varchar, 60:Int32, 'sixty':Varchar, 61:Int32, 'sixty-one':Varchar, 62:Int32, 'sixty-two':Varchar, 63:Int32, 'sixty-three':Varchar, 64:Int32, 'sixty-four':Varchar, 65:Int32, 'sixty-five':Varchar, 66:Int32, 'sixty-six':Varchar, 67:Int32, 'sixty-seven':Varchar, 68:Int32, 'sixty-eight':Varchar, 69:Int32, 'sixty-nine':Varchar, 70:Int32, 'seventy':Varchar, 71:Int32, 'seventy-one':Varchar, 72:Int32, 'seventy-two':Varchar, 73:Int32, 'seventy-three':Varchar, 74:Int32, 'seventy-four':Varchar, 75:Int32, 'seventy-five':Varchar, 76:Int32, 'seventy-six':Varchar, 77:Int32, 'seventy-seven':Varchar, 78:Int32, 'seventy-eight':Varchar, 79:Int32, 'seventy-nine':Varchar, 80:Int32, 'eighty':Varchar, 81:Int32, 'eighty-one':Varchar, 82:Int32, 'eighty-two':Varchar, 83:Int32, 'eighty-three':Varchar, 84:Int32, 'eighty-four':Varchar, 85:Int32, 'eighty-five':Varchar, 86:Int32, 'eighty-six':Varchar, 87:Int32, 'eighty-seven':Varchar, 88:Int32, 'eighty-eight':Varchar, 89:Int32, 'eighty-nine':Varchar, 90:Int32, 'ninety':Varchar, 91:Int32, 'ninety-one':Varchar, 92:Int32, 'ninety-two':Varchar, 93:Int32, 'ninety-three':Varchar, 94:Int32, 'ninety-four':Varchar, 95:Int32, 'ninety-five':Varchar, 96:Int32, 'ninety-six':Varchar, 97:Int32, 'ninety-seven':Varchar, 98:Int32, 'ninety-eight':Varchar, 99:Int32, 'ninety-nine':Varchar, 100:Int32, 'one hundred':Varchar, '114514':Varchar) as $expr1 + └─LogicalScan { table: t1, columns: [t1.c1, t1.c2, t1.c3, t1._row_id] } + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject + ├─exprs: ConstantLookup(t1.c1, 1:Int32, 'one':Varchar, 2:Int32, 'two':Varchar, 3:Int32, 'three':Varchar, 4:Int32, 'four':Varchar, 5:Int32, 'five':Varchar, 6:Int32, 'six':Varchar, 7:Int32, 'seven':Varchar, 8:Int32, 'eight':Varchar, 9:Int32, 'nine':Varchar, 10:Int32, 'ten':Varchar, 11:Int32, 'eleven':Varchar, 12:Int32, 'twelve':Varchar, 13:Int32, 'thirteen':Varchar, 14:Int32, 'fourteen':Varchar, 15:Int32, 'fifteen':Varchar, 16:Int32, 'sixteen':Varchar, 17:Int32, 'seventeen':Varchar, 18:Int32, 'eighteen':Varchar, 19:Int32, 'nineteen':Varchar, 20:Int32, 'twenty':Varchar, 21:Int32, 'twenty-one':Varchar, 22:Int32, 'twenty-two':Varchar, 23:Int32, 'twenty-three':Varchar, 24:Int32, 'twenty-four':Varchar, 25:Int32, 'twenty-five':Varchar, 26:Int32, 'twenty-six':Varchar, 27:Int32, 'twenty-seven':Varchar, 28:Int32, 'twenty-eight':Varchar, 29:Int32, 'twenty-nine':Varchar, 30:Int32, 'thirty':Varchar, 31:Int32, 'thirty-one':Varchar, 32:Int32, 'thirty-two':Varchar, 33:Int32, 'thirty-three':Varchar, 34:Int32, 'thirty-four':Varchar, 35:Int32, 'thirty-five':Varchar, 36:Int32, 'thirty-six':Varchar, 37:Int32, 'thirty-seven':Varchar, 38:Int32, 'thirty-eight':Varchar, 39:Int32, 'thirty-nine':Varchar, 40:Int32, 'forty':Varchar, 41:Int32, 'forty-one':Varchar, 42:Int32, 'forty-two':Varchar, 43:Int32, 'forty-three':Varchar, 44:Int32, 'forty-four':Varchar, 45:Int32, 'forty-five':Varchar, 46:Int32, 'forty-six':Varchar, 47:Int32, 'forty-seven':Varchar, 48:Int32, 'forty-eight':Varchar, 49:Int32, 'forty-nine':Varchar, 50:Int32, 'fifty':Varchar, 51:Int32, 'fifty-one':Varchar, 52:Int32, 'fifty-two':Varchar, 53:Int32, 'fifty-three':Varchar, 54:Int32, 'fifty-four':Varchar, 55:Int32, 'fifty-five':Varchar, 56:Int32, 'fifty-six':Varchar, 57:Int32, 'fifty-seven':Varchar, 58:Int32, 'fifty-eight':Varchar, 59:Int32, 'fifty-nine':Varchar, 60:Int32, 'sixty':Varchar, 61:Int32, 'sixty-one':Varchar, 62:Int32, 'sixty-two':Varchar, 63:Int32, 'sixty-three':Varchar, 64:Int32, 'sixty-four':Varchar, 65:Int32, 'sixty-five':Varchar, 66:Int32, 'sixty-six':Varchar, 67:Int32, 'sixty-seven':Varchar, 68:Int32, 'sixty-eight':Varchar, 69:Int32, 'sixty-nine':Varchar, 70:Int32, 'seventy':Varchar, 71:Int32, 'seventy-one':Varchar, 72:Int32, 'seventy-two':Varchar, 73:Int32, 'seventy-three':Varchar, 74:Int32, 'seventy-four':Varchar, 75:Int32, 'seventy-five':Varchar, 76:Int32, 'seventy-six':Varchar, 77:Int32, 'seventy-seven':Varchar, 78:Int32, 'seventy-eight':Varchar, 79:Int32, 'seventy-nine':Varchar, 80:Int32, 'eighty':Varchar, 81:Int32, 'eighty-one':Varchar, 82:Int32, 'eighty-two':Varchar, 83:Int32, 'eighty-three':Varchar, 84:Int32, 'eighty-four':Varchar, 85:Int32, 'eighty-five':Varchar, 86:Int32, 'eighty-six':Varchar, 87:Int32, 'eighty-seven':Varchar, 88:Int32, 'eighty-eight':Varchar, 89:Int32, 'eighty-nine':Varchar, 90:Int32, 'ninety':Varchar, 91:Int32, 'ninety-one':Varchar, 92:Int32, 'ninety-two':Varchar, 93:Int32, 'ninety-three':Varchar, 94:Int32, 'ninety-four':Varchar, 95:Int32, 'ninety-five':Varchar, 96:Int32, 'ninety-six':Varchar, 97:Int32, 'ninety-seven':Varchar, 98:Int32, 'ninety-eight':Varchar, 99:Int32, 'ninety-nine':Varchar, 100:Int32, 'one hundred':Varchar, '114514':Varchar) as $expr1 + └─BatchScan { table: t1, columns: [t1.c1], distribution: SomeShard } diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 93356cd1efec..7da1518ceb55 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -35,6 +35,13 @@ mod order_by; mod subquery; mod value; +/// The limit arms for case-when expression +/// When the number of condition arms exceed +/// this limit, we will try optimize the case-when +/// expression to `ConstantLookupExpression` +/// Check `case.rs` for details. +const CASE_WHEN_ARMS_OPTIMIZE_LIMIT: usize = 30; + impl Binder { /// Bind an expression with `bind_expr_inner`, attach the original expression /// to the error message. @@ -462,6 +469,56 @@ impl Binder { Ok(func_call.into()) } + /// The helper function to check if the current case-when + /// expression in `bind_case` could be optimized + /// into `ConstantLookupExpression` + fn check_bind_case_optimization( + &mut self, + conditions: Vec, + results_expr: Vec, + operand: Option>, + fallback: Option, + constant_lookup_inputs: &mut Vec, + ) -> bool { + if conditions.len() < CASE_WHEN_ARMS_OPTIMIZE_LIMIT { + return false; + } + + // TODO(Zihao): we could possibly optimize some simple cases when + // `operand` is None in the future, the current choice is not conducting the optimization. + // e.g., select case when c1 = 1 then (...) when (same pattern) then (... ) [else (...)] end from t1; + if let Some(operand) = operand { + let Ok(operand) = self.bind_expr_inner(*operand) else { + return false; + }; + constant_lookup_inputs.push(operand); + } else { + return false; + } + + for (condition, result) in zip_eq_fast(conditions, results_expr) { + if let Expr::Value(_) = condition.clone() { + let Ok(input) = self.bind_expr_inner(condition.clone()) else { + return false; + }; + constant_lookup_inputs.push(input); + } else { + // If at least one condition is not in the simple form / not constant, + // we can NOT do the subsequent optimization then + return false; + } + + constant_lookup_inputs.push(result); + } + + // The fallback arm for case-when expression + if let Some(expr) = fallback { + constant_lookup_inputs.push(expr); + } + + true + } + pub(super) fn bind_case( &mut self, operand: Option>, @@ -478,6 +535,21 @@ impl Binder { .map(|expr| self.bind_expr_inner(*expr)) .transpose()?; + let mut constant_lookup_inputs = Vec::new(); + + // See if the case-when expression can be optimized + let optimize_flag = self.check_bind_case_optimization( + conditions.clone(), + results_expr.clone(), + operand.clone(), + else_result_expr.clone(), + &mut constant_lookup_inputs, + ); + + if optimize_flag { + return Ok(FunctionCall::new(ExprType::ConstantLookup, constant_lookup_inputs)?.into()); + } + for (condition, result) in zip_eq_fast(conditions, results_expr) { let condition = match operand { Some(ref t) => Expr::BinaryOp { @@ -493,14 +565,18 @@ impl Binder { ); inputs.push(result); } + + // The fallback arm for case-when expression if let Some(expr) = else_result_expr { inputs.push(expr); } + if inputs.iter().any(ExprImpl::has_table_function) { return Err( ErrorCode::BindError("table functions are not allowed in CASE".into()).into(), ); } + Ok(FunctionCall::new(ExprType::Case, inputs)?.into()) } diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 7e7378a65752..b87c6d18feb1 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -82,6 +82,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Ltrim | expr_node::Type::Rtrim | expr_node::Type::Case + | expr_node::Type::ConstantLookup | expr_node::Type::RoundDigit | expr_node::Type::Round | expr_node::Type::Ascii diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 1ccbaa28e9da..337c901a47ed 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -327,6 +327,21 @@ fn infer_type_for_special( .map(Some) .map_err(Into::into) } + ExprType::ConstantLookup => { + let len = inputs.len(); + align_types(inputs.iter_mut().enumerate().filter_map(|(i, e)| { + // This optimized `ConstantLookup` organize `inputs` as + // [dummy_expression] (cond, res) [else / fallback]? pairs. + // So we align exprs at even indices as well as the last one + // when length is odd. + match i != 0 && i.is_even() || i == len - 1 { + true => Some(e), + false => None, + } + })) + .map(Some) + .map_err(Into::into) + } ExprType::In => { align_types(inputs.iter_mut())?; Ok(Some(DataType::Boolean))