Skip to content

Commit

Permalink
feat(optimizer): grow stack for complicated plan in more places (#17224)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Jun 18, 2024
1 parent 4e5a731 commit ba80ab6
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 92 deletions.
19 changes: 2 additions & 17 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::types::{Datum, ScalarImpl};
use super::statement::RewriteExprsRecursive;
use super::BoundStatement;
use crate::error::{ErrorCode, Result};
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal};

/// Rewrites parameter expressions to literals.
pub(crate) struct ParamRewriter {
Expand All @@ -47,22 +47,7 @@ impl ExprRewriter for ParamRewriter {
if self.error.is_some() {
return expr;
}
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
self.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}

fn rewrite_subquery(&mut self, mut subquery: crate::expr::Subquery) -> ExprImpl {
Expand Down
60 changes: 43 additions & 17 deletions src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::util::recursive::{tracker, Recurse};

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal,
Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction,
Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction, EXPR_DEPTH_THRESHOLD,
EXPR_TOO_DEEP_NOTICE,
};
use crate::expr::Now;
use crate::session::current::notice_to_user;

/// The default implementation of [`ExprRewriter::rewrite_expr`] that simply dispatches to other
/// methods based on the type of the expression.
///
/// You can use this function as a helper to reduce boilerplate code when implementing the trait.
// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should
// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
pub fn default_rewrite_expr<R: ExprRewriter + ?Sized>(
rewriter: &mut R,
expr: ExprImpl,
) -> ExprImpl {
// TODO: Implementors may choose to not use this function at all, in which case we will fail
// to track the recursion and grow the stack as necessary. The current approach is only a
// best-effort attempt to prevent stack overflow.
tracker!().recurse(|t| {
if t.depth_reaches(EXPR_DEPTH_THRESHOLD) {
notice_to_user(EXPR_TOO_DEEP_NOTICE);
}

match expr {
ExprImpl::InputRef(inner) => rewriter.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => rewriter.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => rewriter.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
rewriter.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => rewriter.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => rewriter.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => rewriter.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => rewriter.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => rewriter.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => rewriter.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => rewriter.rewrite_parameter(*inner),
ExprImpl::Now(inner) => rewriter.rewrite_now(*inner),
}
})
}

/// By default, `ExprRewriter` simply traverses the expression tree and leaves nodes unchanged.
/// Implementations can override a subset of methods and perform transformation on some particular
/// types of expression.
pub trait ExprRewriter {
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
self.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}

fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
Expand Down
53 changes: 39 additions & 14 deletions src/frontend/src/expr/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,48 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::util::recursive::{tracker, Recurse};

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal,
Now, Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction,
EXPR_DEPTH_THRESHOLD, EXPR_TOO_DEEP_NOTICE,
};
use crate::session::current::notice_to_user;

/// The default implementation of [`ExprVisitor::visit_expr`] that simply dispatches to other
/// methods based on the type of the expression.
///
/// You can use this function as a helper to reduce boilerplate code when implementing the trait.
// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should
// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
pub fn default_visit_expr<V: ExprVisitor + ?Sized>(visitor: &mut V, expr: &ExprImpl) {
// TODO: Implementors may choose to not use this function at all, in which case we will fail
// to track the recursion and grow the stack as necessary. The current approach is only a
// best-effort attempt to prevent stack overflow.
tracker!().recurse(|t| {
if t.depth_reaches(EXPR_DEPTH_THRESHOLD) {
notice_to_user(EXPR_TOO_DEEP_NOTICE);
}

match expr {
ExprImpl::InputRef(inner) => visitor.visit_input_ref(inner),
ExprImpl::Literal(inner) => visitor.visit_literal(inner),
ExprImpl::FunctionCall(inner) => visitor.visit_function_call(inner),
ExprImpl::FunctionCallWithLambda(inner) => {
visitor.visit_function_call_with_lambda(inner)
}
ExprImpl::AggCall(inner) => visitor.visit_agg_call(inner),
ExprImpl::Subquery(inner) => visitor.visit_subquery(inner),
ExprImpl::CorrelatedInputRef(inner) => visitor.visit_correlated_input_ref(inner),
ExprImpl::TableFunction(inner) => visitor.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => visitor.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => visitor.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => visitor.visit_parameter(inner),
ExprImpl::Now(inner) => visitor.visit_now(inner),
}
})
}

/// Traverse an expression tree.
///
Expand All @@ -27,20 +65,7 @@ use super::{
/// subqueries are not traversed.
pub trait ExprVisitor {
fn visit_expr(&mut self, expr: &ExprImpl) {
match expr {
ExprImpl::InputRef(inner) => self.visit_input_ref(inner),
ExprImpl::Literal(inner) => self.visit_literal(inner),
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner),
ExprImpl::AggCall(inner) => self.visit_agg_call(inner),
ExprImpl::Subquery(inner) => self.visit_subquery(inner),
ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner),
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
ExprImpl::Now(inner) => self.visit_now(inner),
}
default_visit_expr(self, expr)
}
fn visit_function_call(&mut self, func_call: &FunctionCall) {
func_call
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ mod utils;
pub use agg_call::AggCall;
pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth};
pub use expr_mutator::ExprMutator;
pub use expr_rewriter::ExprRewriter;
pub use expr_visitor::ExprVisitor;
pub use expr_rewriter::{default_rewrite_expr, ExprRewriter};
pub use expr_visitor::{default_visit_expr, ExprVisitor};
pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
pub use function_call_with_lambda::FunctionCallWithLambda;
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
Expand All @@ -74,6 +74,10 @@ pub use user_defined_function::UserDefinedFunction;
pub use utils::*;
pub use window_function::WindowFunction;

const EXPR_DEPTH_THRESHOLD: usize = 30;
const EXPR_TOO_DEEP_NOTICE: &str = "Some expression is too complicated. \
Consider simplifying or splitting the query if you encounter any issues.";

/// the trait of bound expressions
pub trait Expr: Into<ExprImpl> {
/// Get the return type of the expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use crate::error::RwError;
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal};

pub(crate) struct ConstEvalRewriter {
pub(crate) error: Option<RwError>,
Expand All @@ -31,21 +31,10 @@ impl ExprRewriter for ConstEvalRewriter {
expr
}
}
} else if let ExprImpl::Parameter(_) = expr {
unreachable!("Parameter should not appear here. It will be replaced by a literal before this step.")
} else {
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => self.rewrite_function_call_with_lambda(*inner),
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(_) => unreachable!("Parameter should not appear here. It will be replaced by a literal before this step."),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}
}
}
17 changes: 2 additions & 15 deletions src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use std::collections::HashMap;

use crate::expr::{ExprImpl, ExprType, ExprVisitor, FunctionCall};
use crate::expr::{default_visit_expr, ExprImpl, ExprType, ExprVisitor, FunctionCall};

/// `ExprCounter` is used by `CseRewriter`.
#[derive(Default)]
Expand All @@ -35,20 +35,7 @@ impl ExprVisitor for CseExprCounter {
return;
}

match expr {
ExprImpl::InputRef(inner) => self.visit_input_ref(inner),
ExprImpl::Literal(inner) => self.visit_literal(inner),
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner),
ExprImpl::AggCall(inner) => self.visit_agg_call(inner),
ExprImpl::Subquery(inner) => self.visit_subquery(inner),
ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner),
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
ExprImpl::Now(inner) => self.visit_now(inner),
}
default_visit_expr(self, expr);
}

fn visit_function_call(&mut self, func_call: &FunctionCall) {
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,10 @@ impl dyn PlanNode {
}
}

const PLAN_DEPTH_THRESHOLD: usize = 30;
const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
/// Recursion depth threshold for plan node visitor to send notice to user.
pub const PLAN_DEPTH_THRESHOLD: usize = 30;
/// Notice message for plan node visitor to send to user when the depth threshold is reached.
pub const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
Consider simplifying or splitting the query if you encounter any issues.";

impl dyn PlanNode {
Expand Down
19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/plan_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ macro_rules! def_rewriter {
pub trait PlanRewriter {
paste! {
fn rewrite(&mut self, plan: PlanRef) -> PlanRef{
match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<rewrite_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
use risingwave_common::util::recursive::{tracker, Recurse};
use crate::session::current::notice_to_user;

tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<rewrite_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
})
}

$(
Expand Down
19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/plan_visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,20 @@ macro_rules! def_visitor {

paste! {
fn visit(&mut self, plan: PlanRef) -> Self::Result {
match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<visit_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
use risingwave_common::util::recursive::{tracker, Recurse};
use crate::session::current::notice_to_user;

tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<visit_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
})
}

$(
Expand Down

0 comments on commit ba80ab6

Please sign in to comment.