From 33f49b8e9ba4a222c6050df3e3bc197750994f97 Mon Sep 17 00:00:00 2001 From: Michael Xu Date: Fri, 2 Feb 2024 18:33:23 -0500 Subject: [PATCH] add constant case when evaluation pass --- proto/expr.proto | 3 + src/frontend/src/binder/expr/mod.rs | 68 +++++++++++++++++++ src/frontend/src/expr/pure.rs | 1 + .../src/optimizer/logical_optimization.rs | 12 ++++ .../const_case_when_rewriter.rs | 13 +++- .../src/optimizer/plan_expr_rewriter/mod.rs | 2 + .../rule/const_case_when_eval_rule.rs | 5 ++ src/frontend/src/optimizer/rule/mod.rs | 3 + 8 files changed, 106 insertions(+), 1 deletion(-) diff --git a/proto/expr.proto b/proto/expr.proto index e26b725b49737..e991e7f017a74 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -88,6 +88,9 @@ message ExprNode { // Optimize case-when expression to constant lookup // when arms are in a large scale with simple form CONSTANT_LOOKUP = 624; + // Optimized in frontend + // should be invisible to backend + CONSTANT_CASE_WHEN_EVAL = 625; // ROUND(numeric, integer) -> numeric ROUND_DIGIT = 213; // ROUND(numeric) -> numeric diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index cfb28dee2487d..2de6db034eec4 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -469,6 +469,55 @@ impl Binder { Ok(func_call.into()) } + fn check_constant_case_when_optimization( + &mut self, + conditions: Vec, + results_expr: Vec, + operand: Option>, + fallback: Option, + constant_case_when_eval_inputs: &mut Vec, + ) -> bool { + // The operand value to be compared later + let operand_value; + + if let Some(operand) = operand { + let Ok(operand) = self.bind_expr_inner(*operand) else { + return false; + }; + if !operand.is_const() { + return false; + } + operand_value = operand; + } else { + return false; + } + + for (condition, result) in zip_eq_fast(conditions, results_expr) { + if let Expr::Value(_) = condition.clone() { + let Ok(res) = self.bind_expr_inner(condition.clone()) else { + return false; + }; + // Found a match + if res == operand_value { + constant_case_when_eval_inputs.push(result); + return true; + } + } else { + return false; + } + } + + // Otherwise this will eventually go through fallback arm + debug_assert!(constant_case_when_eval_inputs.is_empty(), "expect `inputs` to be empty"); + + let Some(fallback) = fallback else { + return false; + }; + + constant_case_when_eval_inputs.push(fallback); + true + } + /// The helper function to check if the current case-when /// expression in `bind_case` could be optimized /// into `ConstantLookupExpression` @@ -491,6 +540,12 @@ impl Binder { let Ok(operand) = self.bind_expr_inner(*operand) else { return false; }; + // This optimization should be done in subsequent optimization phase + // if the operand is const + // e.g., select case 1 when 1 then 114514 else 1919810 end; + if operand.is_const() { + return false; + } constant_lookup_inputs.push(operand); } else { return false; @@ -536,6 +591,19 @@ impl Binder { .transpose()?; let mut constant_lookup_inputs = Vec::new(); + let mut constant_case_when_eval_inputs = Vec::new(); + + let constant_case_when_flag = self.check_constant_case_when_optimization( + conditions.clone(), + results_expr.clone(), + operand.clone(), + else_result_expr.clone(), + &mut constant_case_when_eval_inputs, + ); + + if constant_case_when_flag { + return Ok(FunctionCall::new(ExprType::ConstantCaseWhenEval, constant_case_when_eval_inputs)?.into()); + } // See if the case-when expression can be optimized let optimize_flag = self.check_bind_case_optimization( diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index c50f1cc2460b8..53209eedfdacb 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -86,6 +86,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Rtrim | expr_node::Type::Case | expr_node::Type::ConstantLookup + | expr_node::Type::ConstantCaseWhenEval | expr_node::Type::RoundDigit | expr_node::Type::Round | expr_node::Type::Ascii diff --git a/src/frontend/src/optimizer/logical_optimization.rs b/src/frontend/src/optimizer/logical_optimization.rs index db5dc8ceca7d2..48062d8a52893 100644 --- a/src/frontend/src/optimizer/logical_optimization.rs +++ b/src/frontend/src/optimizer/logical_optimization.rs @@ -415,6 +415,14 @@ static COMMON_SUB_EXPR_EXTRACT: LazyLock = LazyLock::new(|| { ) }); +static CONST_CASE_WHEN_EVAL: LazyLock = LazyLock::new(|| { + OptimizationStage::new( + "Const Case When Evaluation", + vec![ConstCaseWhenEvalRule::create()], + ApplyOrder::TopDown, + ) +}); + impl LogicalOptimizer { pub fn predicate_pushdown( plan: PlanRef, @@ -623,6 +631,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT); + plan = plan.optimize_by_rules(&CONST_CASE_WHEN_EVAL); + #[cfg(debug_assertions)] InputRefValidator.validate(plan.clone()); @@ -711,6 +721,8 @@ impl LogicalOptimizer { plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT); + plan = plan.optimize_by_rules(&CONST_CASE_WHEN_EVAL); + plan = plan.optimize_by_rules(&PULL_UP_HOP); plan = plan.optimize_by_rules(&TOP_N_AGG_ON_INDEX); diff --git a/src/frontend/src/optimizer/plan_expr_rewriter/const_case_when_rewriter.rs b/src/frontend/src/optimizer/plan_expr_rewriter/const_case_when_rewriter.rs index affe23c064000..10c2cc3e1188d 100644 --- a/src/frontend/src/optimizer/plan_expr_rewriter/const_case_when_rewriter.rs +++ b/src/frontend/src/optimizer/plan_expr_rewriter/const_case_when_rewriter.rs @@ -10,4 +10,15 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. \ No newline at end of file +// limitations under the License. + +use crate::expr::{ExprImpl, ExprRewriter, FunctionCall}; + +pub struct ConstCaseWhenRewriter {} + +impl ExprRewriter for ConstCaseWhenRewriter { + fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl { + println!("Current func_call: {:#?}", func_call); + todo!() + } +} \ No newline at end of file diff --git a/src/frontend/src/optimizer/plan_expr_rewriter/mod.rs b/src/frontend/src/optimizer/plan_expr_rewriter/mod.rs index c847ece1ff048..49c214bb56883 100644 --- a/src/frontend/src/optimizer/plan_expr_rewriter/mod.rs +++ b/src/frontend/src/optimizer/plan_expr_rewriter/mod.rs @@ -14,6 +14,8 @@ mod const_eval_rewriter; mod cse_rewriter; +mod const_case_when_rewriter; pub(crate) use const_eval_rewriter::ConstEvalRewriter; pub(crate) use cse_rewriter::CseRewriter; +pub(crate) use const_case_when_rewriter::ConstCaseWhenRewriter; diff --git a/src/frontend/src/optimizer/rule/const_case_when_eval_rule.rs b/src/frontend/src/optimizer/rule/const_case_when_eval_rule.rs index 5c3ae0f42ab5a..60ec95b746d84 100644 --- a/src/frontend/src/optimizer/rule/const_case_when_eval_rule.rs +++ b/src/frontend/src/optimizer/rule/const_case_when_eval_rule.rs @@ -14,10 +14,15 @@ use super::super::plan_node::*; use super::{BoxedRule, Rule}; +use crate::optimizer::plan_expr_rewriter::ConstCaseWhenRewriter; pub struct ConstCaseWhenEvalRule {} impl Rule for ConstCaseWhenEvalRule { fn apply(&self, plan: PlanRef) -> Option { + println!("Current plan: {:#?}", plan); + let values: &LogicalValues = plan.as_logical_values()?; + println!("Current values: {:#?}", values); + let _const_case_when_rewriter = ConstCaseWhenRewriter {}; todo!() } } diff --git a/src/frontend/src/optimizer/rule/mod.rs b/src/frontend/src/optimizer/rule/mod.rs index acde2f7b72eb6..b4cf5a58c6e82 100644 --- a/src/frontend/src/optimizer/rule/mod.rs +++ b/src/frontend/src/optimizer/rule/mod.rs @@ -157,6 +157,8 @@ pub use agg_call_merge_rule::*; mod values_extract_project_rule; pub use batch::batch_push_limit_to_scan_rule::*; pub use values_extract_project_rule::*; +mod const_case_when_eval_rule; +pub use const_case_when_eval_rule::*; #[macro_export] macro_rules! for_all_rules { @@ -227,6 +229,7 @@ macro_rules! for_all_rules { , { AggCallMergeRule } , { ValuesExtractProjectRule } , { BatchPushLimitToScanRule } + , { ConstCaseWhenEvalRule } } }; }