diff --git a/src/frontend/src/expr/expr_visitor.rs b/src/frontend/src/expr/expr_visitor.rs index 5bc827b43aba8..54da009871ca4 100644 --- a/src/frontend/src/expr/expr_visitor.rs +++ b/src/frontend/src/expr/expr_visitor.rs @@ -26,14 +26,7 @@ use super::{ /// Note: The default implementation for `visit_subquery` is a no-op, i.e., expressions inside /// subqueries are not traversed. pub trait ExprVisitor { - type Result: Default; - - /// This merge function is used to reduce results of expr inputs. - /// In order to always remind users to implement themselves, we don't provide an default - /// implementation. - fn merge(a: Self::Result, b: Self::Result) -> Self::Result; - - fn visit_expr(&mut self, expr: &ExprImpl) -> Self::Result { + fn visit_expr(&mut self, expr: &ExprImpl) { match expr { ExprImpl::InputRef(inner) => self.visit_input_ref(inner), ExprImpl::Literal(inner) => self.visit_literal(inner), @@ -49,71 +42,37 @@ pub trait ExprVisitor { ExprImpl::Now(inner) => self.visit_now(inner), } } - fn visit_function_call(&mut self, func_call: &FunctionCall) -> Self::Result { + fn visit_function_call(&mut self, func_call: &FunctionCall) { func_call .inputs() .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() + .for_each(|expr| self.visit_expr(expr)); } - fn visit_function_call_with_lambda( - &mut self, - func_call: &FunctionCallWithLambda, - ) -> Self::Result { + fn visit_function_call_with_lambda(&mut self, func_call: &FunctionCallWithLambda) { self.visit_function_call(func_call.base()) } - fn visit_agg_call(&mut self, agg_call: &AggCall) -> Self::Result { - let mut r = agg_call + fn visit_agg_call(&mut self, agg_call: &AggCall) { + agg_call .args() .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default(); - r = Self::merge(r, agg_call.order_by().visit_expr(self)); - r = Self::merge(r, agg_call.filter().visit_expr(self)); - r - } - fn visit_parameter(&mut self, _: &Parameter) -> Self::Result { - Self::Result::default() - } - fn visit_literal(&mut self, _: &Literal) -> Self::Result { - Self::Result::default() + .for_each(|expr| self.visit_expr(expr)); + agg_call.order_by().visit_expr(self); + agg_call.filter().visit_expr(self); } - fn visit_input_ref(&mut self, _: &InputRef) -> Self::Result { - Self::Result::default() - } - fn visit_subquery(&mut self, _: &Subquery) -> Self::Result { - Self::Result::default() - } - fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) -> Self::Result { - Self::Result::default() - } - fn visit_table_function(&mut self, func_call: &TableFunction) -> Self::Result { - func_call - .args - .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() - } - fn visit_window_function(&mut self, func_call: &WindowFunction) -> Self::Result { - func_call - .args - .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() + fn visit_parameter(&mut self, _: &Parameter) {} + fn visit_literal(&mut self, _: &Literal) {} + fn visit_input_ref(&mut self, _: &InputRef) {} + fn visit_subquery(&mut self, _: &Subquery) {} + fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) {} + + fn visit_table_function(&mut self, func_call: &TableFunction) { + func_call.args.iter().for_each(|expr| self.visit_expr(expr)); } - fn visit_user_defined_function(&mut self, func_call: &UserDefinedFunction) -> Self::Result { - func_call - .args - .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() + fn visit_window_function(&mut self, func_call: &WindowFunction) { + func_call.args.iter().for_each(|expr| self.visit_expr(expr)); } - fn visit_now(&mut self, _: &Now) -> Self::Result { - Self::Result::default() + fn visit_user_defined_function(&mut self, func_call: &UserDefinedFunction) { + func_call.args.iter().for_each(|expr| self.visit_expr(expr)); } + fn visit_now(&mut self, _: &Now) {} } diff --git a/src/frontend/src/expr/function_call.rs b/src/frontend/src/expr/function_call.rs index 1dfc86226d78f..ad0ddc8fc08a5 100644 --- a/src/frontend/src/expr/function_call.rs +++ b/src/frontend/src/expr/function_call.rs @@ -278,8 +278,9 @@ impl FunctionCall { } pub fn is_pure(&self) -> bool { - let mut a = ImpureAnalyzer {}; - !a.visit_function_call(self) + let mut a = ImpureAnalyzer { impure: false }; + a.visit_function_call(self); + !a.impure } } diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 27e781a88690b..c3aecb060afc6 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -215,7 +215,8 @@ impl ExprImpl { /// Count `Now`s in the expression. pub fn count_nows(&self) -> usize { let mut visitor = CountNow::default(); - visitor.visit_expr(self) + visitor.visit_expr(self); + visitor.count() } /// Check whether self is literal NULL. @@ -349,23 +350,17 @@ macro_rules! impl_has_variant { impl ExprImpl { $( pub fn [](&self) -> bool { - struct Has {} + struct Has { has: bool } impl ExprVisitor for Has { - - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn [](&mut self, _: &$variant) -> bool { - true + fn [](&mut self, _: &$variant) { + self.has = true; } } - let mut visitor = Has {}; - visitor.visit_expr(self) + let mut visitor = Has { has: false }; + visitor.visit_expr(self); + visitor.has } )* } @@ -422,110 +417,96 @@ impl ExprImpl { pub fn has_correlated_input_ref_by_depth(&self, depth: Depth) -> bool { struct Has { depth: usize, + has: bool, } impl ExprVisitor for Has { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn visit_correlated_input_ref( - &mut self, - correlated_input_ref: &CorrelatedInputRef, - ) -> bool { - correlated_input_ref.depth() == self.depth + fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) { + if correlated_input_ref.depth() == self.depth { + self.has = true; + } } - fn visit_subquery(&mut self, subquery: &Subquery) -> bool { + fn visit_subquery(&mut self, subquery: &Subquery) { self.depth += 1; - let has = self.visit_bound_set_expr(&subquery.query.body); + self.visit_bound_set_expr(&subquery.query.body); self.depth -= 1; - - has } } impl Has { - fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) -> bool { - let mut has = false; + fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) { match set_expr { BoundSetExpr::Select(select) => { - select.exprs().for_each(|expr| has |= self.visit_expr(expr)); - has |= match select.from.as_ref() { + select.exprs().for_each(|expr| self.visit_expr(expr)); + match select.from.as_ref() { Some(from) => from.is_correlated(self.depth), None => false, }; } BoundSetExpr::Values(values) => { - values.exprs().for_each(|expr| has |= self.visit_expr(expr)) + values.exprs().for_each(|expr| self.visit_expr(expr)) } BoundSetExpr::Query(query) => { self.depth += 1; - has = self.visit_bound_set_expr(&query.body); + self.visit_bound_set_expr(&query.body); self.depth -= 1; } BoundSetExpr::SetOperation { left, right, .. } => { - has |= self.visit_bound_set_expr(left); - has |= self.visit_bound_set_expr(right); + self.visit_bound_set_expr(left); + self.visit_bound_set_expr(right); } }; - has } } - let mut visitor = Has { depth }; - visitor.visit_expr(self) + let mut visitor = Has { depth, has: false }; + visitor.visit_expr(self); + visitor.has } pub fn has_correlated_input_ref_by_correlated_id(&self, correlated_id: CorrelatedId) -> bool { struct Has { correlated_id: CorrelatedId, + has: bool, } impl ExprVisitor for Has { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn visit_correlated_input_ref( - &mut self, - correlated_input_ref: &CorrelatedInputRef, - ) -> bool { - correlated_input_ref.correlated_id() == self.correlated_id + fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) { + if correlated_input_ref.correlated_id() == self.correlated_id { + self.has = true; + } } - fn visit_subquery(&mut self, subquery: &Subquery) -> bool { - self.visit_bound_set_expr(&subquery.query.body) + fn visit_subquery(&mut self, subquery: &Subquery) { + self.visit_bound_set_expr(&subquery.query.body); } } impl Has { - fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) -> bool { + fn visit_bound_set_expr(&mut self, set_expr: &BoundSetExpr) { match set_expr { - BoundSetExpr::Select(select) => select - .exprs() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default(), - BoundSetExpr::Values(values) => values - .exprs() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default(), + BoundSetExpr::Select(select) => { + select.exprs().for_each(|expr| self.visit_expr(expr)) + } + BoundSetExpr::Values(values) => { + values.exprs().for_each(|expr| self.visit_expr(expr)); + } BoundSetExpr::Query(query) => self.visit_bound_set_expr(&query.body), BoundSetExpr::SetOperation { left, right, .. } => { - self.visit_bound_set_expr(left) | self.visit_bound_set_expr(right) + self.visit_bound_set_expr(left); + self.visit_bound_set_expr(right); } } } } - let mut visitor = Has { correlated_id }; - visitor.visit_expr(self) + let mut visitor = Has { + correlated_id, + has: false, + }; + visitor.visit_expr(self); + visitor.has } /// Collect `CorrelatedInputRef`s in `ExprImpl` by relative `depth`, return their indices, and @@ -607,10 +588,6 @@ impl ExprImpl { has_others: bool, } impl ExprVisitor for HasOthers { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_expr(&mut self, expr: &ExprImpl) { match expr { ExprImpl::CorrelatedInputRef(_) diff --git a/src/frontend/src/expr/order_by_expr.rs b/src/frontend/src/expr/order_by_expr.rs index e7fe005256640..413b8de016445 100644 --- a/src/frontend/src/expr/order_by_expr.rs +++ b/src/frontend/src/expr/order_by_expr.rs @@ -71,12 +71,10 @@ impl OrderBy { } } - pub fn visit_expr(&self, visitor: &mut V) -> V::Result { + pub fn visit_expr(&self, visitor: &mut V) { self.sort_exprs .iter() - .map(|expr| visitor.visit_expr(&expr.expr)) - .reduce(V::merge) - .unwrap_or_default() + .for_each(|expr| visitor.visit_expr(&expr.expr)); } pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 5c1cc37b0b99e..ee4c659782731 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -16,25 +16,22 @@ use risingwave_pb::expr::expr_node; use super::{ExprImpl, ExprVisitor}; use crate::expr::FunctionCall; -pub(crate) struct ImpureAnalyzer {} -impl ExprVisitor for ImpureAnalyzer { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - // the expr will be impure if any of its input is impure - a || b - } +#[derive(Default)] +pub(crate) struct ImpureAnalyzer { + pub(crate) impure: bool, +} - fn visit_user_defined_function(&mut self, _func_call: &super::UserDefinedFunction) -> bool { - true +impl ExprVisitor for ImpureAnalyzer { + fn visit_user_defined_function(&mut self, _func_call: &super::UserDefinedFunction) { + self.impure = true; } - fn visit_now(&mut self, _: &super::Now) -> bool { - true + fn visit_now(&mut self, _: &super::Now) { + self.impure = true; } - fn visit_function_call(&mut self, func_call: &super::FunctionCall) -> bool { + fn visit_function_call(&mut self, func_call: &super::FunctionCall) { match func_call.func_type() { expr_node::Type::Unspecified => unreachable!(), expr_node::Type::Add @@ -224,13 +221,10 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Least => // expression output is deterministic(same result for the same input) { - let x = func_call + func_call .inputs() .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default(); - x + .for_each(|expr| self.visit_expr(expr)); } // expression output is not deterministic expr_node::Type::Vnode @@ -240,7 +234,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::PgSleepUntil | expr_node::Type::ColDescription | expr_node::Type::CastRegclass - | expr_node::Type::MakeTimestamptz => true, + | expr_node::Type::MakeTimestamptz => self.impure = true, } } } @@ -250,13 +244,15 @@ pub fn is_pure(expr: &ExprImpl) -> bool { } pub fn is_impure(expr: &ExprImpl) -> bool { - let mut a = ImpureAnalyzer {}; - a.visit_expr(expr) + let mut a = ImpureAnalyzer::default(); + a.visit_expr(expr); + a.impure } pub fn is_impure_func_call(func_call: &FunctionCall) -> bool { - let mut a = ImpureAnalyzer {}; - a.visit_function_call(func_call) + let mut a = ImpureAnalyzer::default(); + a.visit_function_call(func_call); + a.impure } #[cfg(test)] diff --git a/src/frontend/src/expr/utils.rs b/src/frontend/src/expr/utils.rs index 39064d1680359..2e8dcb19d9fc3 100644 --- a/src/frontend/src/expr/utils.rs +++ b/src/frontend/src/expr/utils.rs @@ -354,10 +354,6 @@ pub struct CollectInputRef { } impl ExprVisitor for CollectInputRef { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_input_ref(&mut self, expr: &InputRef) { self.input_bits.insert(expr.index()); } @@ -408,17 +404,19 @@ pub fn collect_input_refs<'a>( /// Count `Now`s in the expression. #[derive(Clone, Default)] -pub struct CountNow {} - -impl ExprVisitor for CountNow { - type Result = usize; +pub struct CountNow { + count: usize, +} - fn merge(a: usize, b: usize) -> usize { - a + b +impl CountNow { + pub fn count(&self) -> usize { + self.count } +} - fn visit_now(&mut self, _: &super::Now) -> usize { - 1 +impl ExprVisitor for CountNow { + fn visit_now(&mut self, _: &super::Now) { + self.count += 1; } } diff --git a/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs b/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs index c664016b779da..55d6898bdc384 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs @@ -24,10 +24,6 @@ pub struct CseExprCounter { } impl ExprVisitor for CseExprCounter { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_expr(&mut self, expr: &ExprImpl) { // Considering this sql, `In` expression needs to ensure its in-clauses to be const. // If we extract it into a common sub-expression (finally be a `InputRef`) which will @@ -88,8 +84,6 @@ impl ExprVisitor for CseExprCounter { func_call .inputs() .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() + .for_each(|expr| self.visit_expr(expr)); } } diff --git a/src/frontend/src/optimizer/plan_expr_visitor/input_ref_counter.rs b/src/frontend/src/optimizer/plan_expr_visitor/input_ref_counter.rs index 382dc74222c9f..f7500c9272686 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/input_ref_counter.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/input_ref_counter.rs @@ -23,10 +23,6 @@ pub struct InputRefCounter { } impl ExprVisitor for InputRefCounter { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_input_ref(&mut self, input_ref: &InputRef) { self.counter .entry(input_ref.index) diff --git a/src/frontend/src/optimizer/plan_node/logical_over_window.rs b/src/frontend/src/optimizer/plan_node/logical_over_window.rs index 665cee6f178a0..5549ae71d3d7e 100644 --- a/src/frontend/src/optimizer/plan_node/logical_over_window.rs +++ b/src/frontend/src/optimizer/plan_node/logical_over_window.rs @@ -325,10 +325,6 @@ impl<'a> OverWindowProjectBuilder<'a> { } impl<'a> ExprVisitor for OverWindowProjectBuilder<'a> { - type Result = (); - - fn merge(_a: (), _b: ()) {} - fn visit_window_function(&mut self, window_function: &WindowFunction) { if let Err(e) = self.try_visit_window_function(window_function) { self.error = Some(e); diff --git a/src/frontend/src/optimizer/plan_node/logical_scan.rs b/src/frontend/src/optimizer/plan_node/logical_scan.rs index 8c4aedf524920..7cc8c2d872010 100644 --- a/src/frontend/src/optimizer/plan_node/logical_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_scan.rs @@ -387,21 +387,25 @@ impl PredicatePushdown for LogicalScan { ) -> PlanRef { // If the predicate contains `CorrelatedInputRef` or `now()`. We don't push down. // This case could come from the predicate push down before the subquery unnesting. - struct HasCorrelated {} + struct HasCorrelated { + has: bool, + } impl ExprVisitor for HasCorrelated { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) -> bool { - true + fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) { + self.has = true; } } let non_pushable_predicate: Vec<_> = predicate .conjunctions - .extract_if(|expr| expr.count_nows() > 0 || HasCorrelated {}.visit_expr(expr)) + .extract_if(|expr| { + if expr.count_nows() > 0 { + true + } else { + let mut visitor = HasCorrelated { has: false }; + visitor.visit_expr(expr); + visitor.has + } + }) .collect(); let predicate = predicate.rewrite_expr(&mut ColIndexMapping::new( self.output_col_idx().iter().map(|i| Some(*i)).collect(), diff --git a/src/frontend/src/optimizer/plan_node/logical_sys_scan.rs b/src/frontend/src/optimizer/plan_node/logical_sys_scan.rs index 56985d81a5c27..ac1e872015a73 100644 --- a/src/frontend/src/optimizer/plan_node/logical_sys_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_sys_scan.rs @@ -273,21 +273,25 @@ impl PredicatePushdown for LogicalSysScan { ) -> PlanRef { // If the predicate contains `CorrelatedInputRef` or `now()`. We don't push down. // This case could come from the predicate push down before the subquery unnesting. - struct HasCorrelated {} + struct HasCorrelated { + has: bool, + } impl ExprVisitor for HasCorrelated { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) -> bool { - true + fn visit_correlated_input_ref(&mut self, _: &CorrelatedInputRef) { + self.has = true; } } let non_pushable_predicate: Vec<_> = predicate .conjunctions - .extract_if(|expr| expr.count_nows() > 0 || HasCorrelated {}.visit_expr(expr)) + .extract_if(|expr| { + if expr.count_nows() > 0 { + true + } else { + let mut visitor = HasCorrelated { has: false }; + visitor.visit_expr(expr); + visitor.has + } + }) .collect(); let predicate = predicate.rewrite_expr(&mut ColIndexMapping::new( self.output_col_idx().iter().map(|i| Some(*i)).collect(), diff --git a/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs b/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs index 6911c6e8ce89a..ece836f82e17b 100644 --- a/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs +++ b/src/frontend/src/optimizer/plan_visitor/input_ref_validator.rs @@ -23,25 +23,18 @@ use crate::optimizer::plan_visitor::PlanVisitor; struct ExprVis<'a> { schema: &'a Schema, + string: Option, } impl ExprVisitor for ExprVis<'_> { - type Result = Option; - - fn visit_input_ref(&mut self, input_ref: &crate::expr::InputRef) -> Option { + fn visit_input_ref(&mut self, input_ref: &crate::expr::InputRef) { if input_ref.data_type != self.schema[input_ref.index].data_type { - Some(format!( + self.string.replace(format!( "InputRef#{} has type {}, but its type is {} in the input schema", input_ref.index, input_ref.data_type, self.schema[input_ref.index].data_type - )) - } else { - None + )); } } - - fn merge(a: Option, b: Option) -> Option { - a.or(b) - } } /// Validates that input references are consistent with the input schema. @@ -71,8 +64,10 @@ macro_rules! visit_filter { let input = plan.input(); let mut vis = ExprVis { schema: input.schema(), + string: None, }; - plan.predicate().visit_expr(&mut vis).or_else(|| { + plan.predicate().visit_expr(&mut vis); + vis.string.or_else(|| { self.visit(input) }) } @@ -89,11 +84,12 @@ macro_rules! visit_project { let input = plan.input(); let mut vis = ExprVis { schema: input.schema(), + string: None, }; for expr in plan.exprs() { - let res = vis.visit_expr(expr); - if res.is_some() { - return res; + vis.visit_expr(expr); + if vis.string.is_some() { + return vis.string; } } self.visit(input) @@ -129,8 +125,10 @@ impl PlanVisitor for InputRefValidator { let input_schema = Schema { fields }; let mut vis = ExprVis { schema: &input_schema, + string: None, }; - plan.predicate().visit_expr(&mut vis) + plan.predicate().visit_expr(&mut vis); + vis.string } // TODO: add more checks diff --git a/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs b/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs index 7ff5d0adb7c0a..e1d213ba90cff 100644 --- a/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs +++ b/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs @@ -136,10 +136,6 @@ impl ExprCorrelatedIdFinder { } impl ExprVisitor for ExprCorrelatedIdFinder { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_correlated_input_ref(&mut self, correlated_input_ref: &CorrelatedInputRef) { self.correlated_id_set .insert(correlated_input_ref.correlated_id()); diff --git a/src/frontend/src/optimizer/rule/index_selection_rule.rs b/src/frontend/src/optimizer/rule/index_selection_rule.rs index 3920924d8146e..6cce59fbf030c 100644 --- a/src/frontend/src/optimizer/rule/index_selection_rule.rs +++ b/src/frontend/src/optimizer/rule/index_selection_rule.rs @@ -708,6 +708,7 @@ impl IndexSelectionRule { struct TableScanIoEstimator<'a> { table_scan: &'a LogicalScan, row_size: usize, + cost: Option, } impl<'a> TableScanIoEstimator<'a> { @@ -715,6 +716,7 @@ impl<'a> TableScanIoEstimator<'a> { Self { table_scan, row_size, + cost: None, } } @@ -766,7 +768,8 @@ impl<'a> TableScanIoEstimator<'a> { pub fn estimate(&mut self, predicate: &Condition) -> IndexCost { // try to deal with OR condition if predicate.conjunctions.len() == 1 { - self.visit_expr(&predicate.conjunctions[0]) + self.visit_expr(&predicate.conjunctions[0]); + self.cost.take().unwrap_or_default() } else { self.estimate_conjunctions(&predicate.conjunctions) } @@ -914,14 +917,16 @@ impl IndexCost { } impl ExprVisitor for TableScanIoEstimator<'_> { - type Result = IndexCost; - - fn visit_function_call(&mut self, func_call: &FunctionCall) -> IndexCost { - match func_call.func_type() { + fn visit_function_call(&mut self, func_call: &FunctionCall) { + let cost = match func_call.func_type() { ExprType::Or => func_call .inputs() .iter() - .map(|x| self.visit_expr(x)) + .map(|x| { + let mut estimator = TableScanIoEstimator::new(self.table_scan, self.row_size); + estimator.visit_expr(x); + estimator.cost.take().unwrap_or_default() + }) .reduce(|x, y| x.add(&y)) .unwrap(), ExprType::And => self.estimate_conjunctions(func_call.inputs()), @@ -929,11 +934,8 @@ impl ExprVisitor for TableScanIoEstimator<'_> { let single = vec![ExprImpl::FunctionCall(func_call.clone().into())]; self.estimate_conjunctions(&single) } - } - } - - fn merge(a: IndexCost, b: IndexCost) -> IndexCost { - a.add(&b) + }; + self.cost = Some(cost); } } @@ -943,10 +945,6 @@ struct ExprInputRefFinder { } impl ExprVisitor for ExprInputRefFinder { - type Result = (); - - fn merge(_: (), _: ()) {} - fn visit_input_ref(&mut self, input_ref: &InputRef) { self.input_ref_index_set.insert(input_ref.index); } diff --git a/src/frontend/src/optimizer/rule/rewrite_like_expr_rule.rs b/src/frontend/src/optimizer/rule/rewrite_like_expr_rule.rs index facad4a8da07c..6453b743a4f47 100644 --- a/src/frontend/src/optimizer/rule/rewrite_like_expr_rule.rs +++ b/src/frontend/src/optimizer/rule/rewrite_like_expr_rule.rs @@ -31,13 +31,11 @@ pub struct RewriteLikeExprRule {} impl Rule for RewriteLikeExprRule { fn apply(&self, plan: PlanRef) -> Option { let filter: &LogicalFilter = plan.as_logical_filter()?; - let mut has_like = HasLikeExprVisitor {}; - if filter - .predicate() - .conjunctions - .iter() - .any(|expr| has_like.visit_expr(expr)) - { + if filter.predicate().conjunctions.iter().any(|expr| { + let mut has_like = HasLikeExprVisitor { has: false }; + has_like.visit_expr(expr); + has_like.has + }) { let mut rewriter = LikeExprRewriter {}; Some(filter.rewrite_exprs(&mut rewriter)) } else { @@ -46,28 +44,22 @@ impl Rule for RewriteLikeExprRule { } } -struct HasLikeExprVisitor {} +struct HasLikeExprVisitor { + has: bool, +} impl ExprVisitor for HasLikeExprVisitor { - type Result = bool; - - fn merge(a: bool, b: bool) -> bool { - a | b - } - - fn visit_function_call(&mut self, func_call: &FunctionCall) -> bool { + fn visit_function_call(&mut self, func_call: &FunctionCall) { if func_call.func_type() == ExprType::Like && let (_, ExprImpl::InputRef(_), ExprImpl::Literal(_)) = func_call.clone().decompose_as_binary() { - true + self.has = true; } else { func_call .inputs() .iter() - .map(|expr| self.visit_expr(expr)) - .reduce(Self::merge) - .unwrap_or_default() + .for_each(|expr| self.visit_expr(expr)); } } } diff --git a/src/frontend/src/utils/condition.rs b/src/frontend/src/utils/condition.rs index 332dda739bb3a..d078e06bcea6c 100644 --- a/src/frontend/src/utils/condition.rs +++ b/src/frontend/src/utils/condition.rs @@ -844,12 +844,10 @@ impl Condition { .simplify() } - pub fn visit_expr(&self, visitor: &mut V) -> V::Result { + pub fn visit_expr(&self, visitor: &mut V) { self.conjunctions .iter() - .map(|expr| visitor.visit_expr(expr)) - .reduce(V::merge) - .unwrap_or_default() + .for_each(|expr| visitor.visit_expr(expr)); } pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {