Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(binder): add const case-when evaluation optimization during binding #14965

Merged
merged 15 commits into from
Feb 23, 2024
79 changes: 78 additions & 1 deletion src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,60 @@ impl Binder {
Ok(func_call.into())
}

/// The optimization check for the following case-when expression pattern
/// e.g., select case 1 when (...) then (...) else (...) end;
fn check_constant_case_when_optimization(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I create another check function is that we can't check all the optimizable case-when patterns within only one pass, and they indeed have subtle but significant difference, which may not be generalized.

&mut self,
conditions: Vec<Expr>,
results_expr: Vec<ExprImpl>,
operand: Option<Box<Expr>>,
fallback: Option<ExprImpl>,
constant_case_when_eval_inputs: &mut Vec<ExprImpl>,
) -> bool {
// The operand value to be compared later
let operand_value;

if let Some(operand) = operand {
let Ok(operand) = self.bind_expr_inner(*operand) else {
return false;
};
if !operand.is_const() {
return false;
}
operand_value = operand;
} else {
return false;
}

for (condition, result) in zip_eq_fast(conditions, results_expr) {
if let Expr::Value(_) = condition.clone() {
let Ok(res) = self.bind_expr_inner(condition.clone()) else {
return false;
};
// Found a match
if res == operand_value {
constant_case_when_eval_inputs.push(result);
return true;
}
} else {
return false;
}
}

// Otherwise this will eventually go through fallback arm
debug_assert!(
constant_case_when_eval_inputs.is_empty(),
"expect `inputs` to be empty"
);

let Some(fallback) = fallback else {
return false;
};

constant_case_when_eval_inputs.push(fallback);
true
}

/// The helper function to check if the current case-when
/// expression in `bind_case` could be optimized
/// into `ConstantLookupExpression`
Expand All @@ -491,6 +545,12 @@ impl Binder {
let Ok(operand) = self.bind_expr_inner(*operand) else {
return false;
};
// This optimization should be done in subsequent optimization phase
// if the operand is const
// e.g., select case 1 when 1 then 114514 else 1919810 end;
if operand.is_const() {
return false;
}
constant_lookup_inputs.push(operand);
} else {
return false;
Expand All @@ -504,7 +564,7 @@ impl Binder {
constant_lookup_inputs.push(input);
} else {
// If at least one condition is not in the simple form / not constant,
// we can NOT do the subsequent optimization then
// we can NOT do the subsequent optimization pass
return false;
}

Expand Down Expand Up @@ -536,6 +596,23 @@ impl Binder {
.transpose()?;

let mut constant_lookup_inputs = Vec::new();
let mut constant_case_when_eval_inputs = Vec::new();

let constant_case_when_flag = self.check_constant_case_when_optimization(
conditions.clone(),
results_expr.clone(),
operand.clone(),
else_result_expr.clone(),
&mut constant_case_when_eval_inputs,
);

if constant_case_when_flag {
return Ok(FunctionCall::new(
ExprType::ConstantLookup,
constant_case_when_eval_inputs,
)?
.into());
}

// See if the case-when expression can be optimized
let optimize_flag = self.check_bind_case_optimization(
Expand Down
12 changes: 12 additions & 0 deletions src/frontend/src/optimizer/logical_optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,14 @@ static COMMON_SUB_EXPR_EXTRACT: LazyLock<OptimizationStage> = LazyLock::new(|| {
)
});

static CONST_CASE_WHEN_EVAL: LazyLock<OptimizationStage> = LazyLock::new(|| {
OptimizationStage::new(
"Const Case When Evaluation",
vec![ConstCaseWhenEvalRule::create()],
ApplyOrder::TopDown,
)
});

impl LogicalOptimizer {
pub fn predicate_pushdown(
plan: PlanRef,
Expand Down Expand Up @@ -623,6 +631,8 @@ impl LogicalOptimizer {

plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT);

plan = plan.optimize_by_rules(&CONST_CASE_WHEN_EVAL);

#[cfg(debug_assertions)]
InputRefValidator.validate(plan.clone());

Expand Down Expand Up @@ -711,6 +721,8 @@ impl LogicalOptimizer {

plan = plan.optimize_by_rules(&COMMON_SUB_EXPR_EXTRACT);

plan = plan.optimize_by_rules(&CONST_CASE_WHEN_EVAL);

plan = plan.optimize_by_rules(&PULL_UP_HOP);

plan = plan.optimize_by_rules(&TOP_N_AGG_ON_INDEX);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::error::RwError;

use crate::expr::{ExprImpl, ExprRewriter, ExprType, FunctionCall};

pub struct ConstCaseWhenRewriter {
pub error: Option<RwError>,
}

impl ExprRewriter for ConstCaseWhenRewriter {
fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
if func_call.func_type() != ExprType::ConstantLookup {
return func_call.into();
}
if func_call.inputs().len() != 1 {
// Normal constant lookup pass
return func_call.into();
}
func_call.inputs()[0].clone().into()
}
}
2 changes: 2 additions & 0 deletions src/frontend/src/optimizer/plan_expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod const_case_when_rewriter;
mod const_eval_rewriter;
mod cse_rewriter;

pub(crate) use const_case_when_rewriter::ConstCaseWhenRewriter;
pub(crate) use const_eval_rewriter::ConstEvalRewriter;
pub(crate) use cse_rewriter::CseRewriter;
32 changes: 32 additions & 0 deletions src/frontend/src/optimizer/rule/const_case_when_eval_rule.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::super::plan_node::*;
use super::{BoxedRule, Rule};
use crate::optimizer::plan_expr_rewriter::ConstCaseWhenRewriter;

pub struct ConstCaseWhenEvalRule {}
impl Rule for ConstCaseWhenEvalRule {
fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
let values: &LogicalValues = plan.as_logical_values()?;
let mut const_case_when_rewriter = ConstCaseWhenRewriter { error: None };
Some(values.rewrite_exprs(&mut const_case_when_rewriter))
}
}

impl ConstCaseWhenEvalRule {
pub fn create() -> BoxedRule {
Box::new(ConstCaseWhenEvalRule {})
}
}
3 changes: 3 additions & 0 deletions src/frontend/src/optimizer/rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ pub use agg_call_merge_rule::*;
mod values_extract_project_rule;
pub use batch::batch_push_limit_to_scan_rule::*;
pub use values_extract_project_rule::*;
mod const_case_when_eval_rule;
pub use const_case_when_eval_rule::*;

#[macro_export]
macro_rules! for_all_rules {
Expand Down Expand Up @@ -227,6 +229,7 @@ macro_rules! for_all_rules {
, { AggCallMergeRule }
, { ValuesExtractProjectRule }
, { BatchPushLimitToScanRule }
, { ConstCaseWhenEvalRule }
}
};
}
Expand Down
Loading