diff --git a/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml b/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml index 272e1ded6e364..a0183ba7907ba 100644 --- a/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml +++ b/src/frontend/planner_test/tests/testdata/input/case_when_optimization.yaml @@ -249,4 +249,47 @@ end; expected_outputs: - logical_plan - - batch_plan \ No newline at end of file + - batch_plan +- id: case_when_optimizable_pattern_basic + before: + - create_table + sql: | + select + case + when c1 = 1 then 'one' + when c1 = 2 then 'two' + when c1 = 3 then 'three' + when c1 = 4 then 'four' + when c1 = 5 then 'five' + when c1 = 6 then 'six' + when c1 = 7 then 'seven' + when c1 = 8 then 'eight' + when c1 = 9 then 'nine' + when c1 = 10 then 'ten' + when c1 = 11 then 'eleven' + when c1 = 12 then 'twelve' + when c1 = 13 then 'thirteen' + when c1 = 14 then 'fourteen' + when c1 = 15 then 'fifteen' + when c1 = 16 then 'sixteen' + when c1 = 17 then 'seventeen' + when c1 = 18 then 'eighteen' + when c1 = 19 then 'nineteen' + when c1 = 20 then 'twenty' + when c1 = 21 then 'twenty-one' + when c1 = 22 then 'twenty-two' + when c1 = 23 then 'twenty-three' + when c1 = 24 then 'twenty-four' + when c1 = 25 then 'twenty-five' + when c1 = 26 then 'twenty-six' + when c1 = 27 then 'twenty-seven' + when c1 = 28 then 'twenty-eight' + when c1 = 29 then 'twenty-nine' + when c1 = 30 then 'thirty' + when c1 = 31 then 'thirty-one' + else 'other' + end + from t1; + expected_outputs: + - logical_plan + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml b/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml index 60d4c56f99517..14156d6f5b494 100644 --- a/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml +++ b/src/frontend/planner_test/tests/testdata/output/case_when_optimization.yaml @@ -274,3 +274,50 @@ LogicalProject { exprs: [1919810:Int32] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } batch_plan: 'BatchValues { rows: [[1919810:Int32]] }' +- id: case_when_optimizable_pattern_basic + before: + - create_table + sql: | + select + case + when c1 = 1 then 'one' + when c1 = 2 then 'two' + when c1 = 3 then 'three' + when c1 = 4 then 'four' + when c1 = 5 then 'five' + when c1 = 6 then 'six' + when c1 = 7 then 'seven' + when c1 = 8 then 'eight' + when c1 = 9 then 'nine' + when c1 = 10 then 'ten' + when c1 = 11 then 'eleven' + when c1 = 12 then 'twelve' + when c1 = 13 then 'thirteen' + when c1 = 14 then 'fourteen' + when c1 = 15 then 'fifteen' + when c1 = 16 then 'sixteen' + when c1 = 17 then 'seventeen' + when c1 = 18 then 'eighteen' + when c1 = 19 then 'nineteen' + when c1 = 20 then 'twenty' + when c1 = 21 then 'twenty-one' + when c1 = 22 then 'twenty-two' + when c1 = 23 then 'twenty-three' + when c1 = 24 then 'twenty-four' + when c1 = 25 then 'twenty-five' + when c1 = 26 then 'twenty-six' + when c1 = 27 then 'twenty-seven' + when c1 = 28 then 'twenty-eight' + when c1 = 29 then 'twenty-nine' + when c1 = 30 then 'thirty' + when c1 = 31 then 'thirty-one' + else 'other' + end + from t1; + logical_plan: |- + LogicalProject { exprs: [ConstantLookup(t1.c1, 1:Int32, 'one':Varchar, 2:Int32, 'two':Varchar, 3:Int32, 'three':Varchar, 4:Int32, 'four':Varchar, 5:Int32, 'five':Varchar, 6:Int32, 'six':Varchar, 7:Int32, 'seven':Varchar, 8:Int32, 'eight':Varchar, 9:Int32, 'nine':Varchar, 10:Int32, 'ten':Varchar, 11:Int32, 'eleven':Varchar, 12:Int32, 'twelve':Varchar, 13:Int32, 'thirteen':Varchar, 14:Int32, 'fourteen':Varchar, 15:Int32, 'fifteen':Varchar, 16:Int32, 'sixteen':Varchar, 17:Int32, 'seventeen':Varchar, 18:Int32, 'eighteen':Varchar, 19:Int32, 'nineteen':Varchar, 20:Int32, 'twenty':Varchar, 21:Int32, 'twenty-one':Varchar, 22:Int32, 'twenty-two':Varchar, 23:Int32, 'twenty-three':Varchar, 24:Int32, 'twenty-four':Varchar, 25:Int32, 'twenty-five':Varchar, 26:Int32, 'twenty-six':Varchar, 27:Int32, 'twenty-seven':Varchar, 28:Int32, 'twenty-eight':Varchar, 29:Int32, 'twenty-nine':Varchar, 30:Int32, 'thirty':Varchar, 31:Int32, 'thirty-one':Varchar, 'other':Varchar) as $expr1] } + └─LogicalScan { table: t1, columns: [t1.c1, t1.c2, t1.c3, t1._row_id] } + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [ConstantLookup(t1.c1, 1:Int32, 'one':Varchar, 2:Int32, 'two':Varchar, 3:Int32, 'three':Varchar, 4:Int32, 'four':Varchar, 5:Int32, 'five':Varchar, 6:Int32, 'six':Varchar, 7:Int32, 'seven':Varchar, 8:Int32, 'eight':Varchar, 9:Int32, 'nine':Varchar, 10:Int32, 'ten':Varchar, 11:Int32, 'eleven':Varchar, 12:Int32, 'twelve':Varchar, 13:Int32, 'thirteen':Varchar, 14:Int32, 'fourteen':Varchar, 15:Int32, 'fifteen':Varchar, 16:Int32, 'sixteen':Varchar, 17:Int32, 'seventeen':Varchar, 18:Int32, 'eighteen':Varchar, 19:Int32, 'nineteen':Varchar, 20:Int32, 'twenty':Varchar, 21:Int32, 'twenty-one':Varchar, 22:Int32, 'twenty-two':Varchar, 23:Int32, 'twenty-three':Varchar, 24:Int32, 'twenty-four':Varchar, 25:Int32, 'twenty-five':Varchar, 26:Int32, 'twenty-six':Varchar, 27:Int32, 'twenty-seven':Varchar, 28:Int32, 'twenty-eight':Varchar, 29:Int32, 'twenty-nine':Varchar, 30:Int32, 'thirty':Varchar, 31:Int32, 'thirty-one':Varchar, 'other':Varchar) as $expr1] } + └─BatchScan { table: t1, columns: [t1.c1], distribution: SomeShard } diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index 382c528de1c75..8e45170f36f2d 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -532,6 +532,127 @@ impl Binder { true } + /// Helper function to compare or set column identifier + /// used in `check_convert_simple_form` + fn compare_or_set(col_expr: &mut Option, test_expr: Expr) -> bool { + let Expr::Identifier(test_ident) = test_expr else { + return false; + }; + if let Some(expr) = col_expr { + let Expr::Identifier(ident) = expr else { + return false; + }; + if ident.real_value() != test_ident.real_value() { + return false; + } + } else { + *col_expr = Some(Expr::Identifier(test_ident)); + } + true + } + + /// left expression and right expression must be either: + /// ` ` or ` ` + /// used in `check_convert_simple_form` + fn check_invariant(left: Expr, op: BinaryOperator, right: Expr) -> bool { + if op != BinaryOperator::Eq { + return false; + } + if let Expr::Identifier(_) = left { + // + let Expr::Value(_) = right else { + return false; + }; + } else { + // + let Expr::Value(_) = left else { + return false; + }; + let Expr::Identifier(_) = right else { + return false; + }; + } + true + } + + /// Helper function to extract expression out and insert + /// the corresponding bound version to `inputs` + /// used in `check_convert_simple_form` + /// Note: this function will be invoked per arm + fn try_extract_simple_form( + &mut self, + ident_expr: Expr, + constant_expr: Expr, + column_expr: &mut Option, + inputs: &mut Vec, + ) -> bool { + if !Self::compare_or_set(column_expr, ident_expr) { + return false; + } + let Ok(bound_expr) = self.bind_expr_inner(constant_expr) else { + return false; + }; + inputs.push(bound_expr); + true + } + + /// See if the case when expression in form + /// `select case when (...with same pattern...) else end;` + /// If so, this expression could also be converted to constant lookup + fn check_convert_simple_form( + &mut self, + conditions: Vec, + results_expr: Vec, + fallback: Option, + constant_lookup_inputs: &mut Vec, + ) -> bool { + let mut column_expr = None; + + for (condition, result) in zip_eq_fast(conditions, results_expr) { + if let Expr::BinaryOp { left, op, right } = condition { + if !Self::check_invariant(*(left.clone()), op.clone(), *(right.clone())) { + return false; + } + if let Expr::Identifier(_) = *(left.clone()) { + if !self.try_extract_simple_form( + *left, + *right, + &mut column_expr, + constant_lookup_inputs, + ) { + return false; + } + } else if !self.try_extract_simple_form( + *right, + *left, + &mut column_expr, + constant_lookup_inputs, + ) { + return false; + } + constant_lookup_inputs.push(result); + } else { + return false; + } + } + + // Insert operand first + let Some(operand) = column_expr else { + return false; + }; + let Ok(bound_operand) = self.bind_expr_inner(operand) else { + return false; + }; + constant_lookup_inputs.insert(0, bound_operand); + + // fallback insertion + if let Some(expr) = fallback { + constant_lookup_inputs.push(expr); + } + + true + } + /// The helper function to check if the current case-when /// expression in `bind_case` could be optimized /// into `ConstantLookupExpression` @@ -547,9 +668,6 @@ impl Binder { return false; } - // TODO(Zihao): we could possibly optimize some simple cases when - // `operand` is None in the future, the current choice is not conducting the optimization. - // e.g., select case when c1 = 1 then (...) when (same pattern) then (... ) [else (...)] end from t1; if let Some(operand) = operand { let Ok(operand) = self.bind_expr_inner(*operand) else { return false; @@ -562,7 +680,14 @@ impl Binder { } constant_lookup_inputs.push(operand); } else { - return false; + // Try converting to simple form + // see the example as illustrated in `check_convert_simple_form` + return self.check_convert_simple_form( + conditions, + results_expr, + fallback, + constant_lookup_inputs, + ); } for (condition, result) in zip_eq_fast(conditions, results_expr) {