diff --git a/src/frontend/src/binder/bind_param.rs b/src/frontend/src/binder/bind_param.rs index 6c3be04d4ee90..b4bbaf420e0c9 100644 --- a/src/frontend/src/binder/bind_param.rs +++ b/src/frontend/src/binder/bind_param.rs @@ -21,7 +21,7 @@ use risingwave_common::types::{Datum, ScalarImpl}; use super::statement::RewriteExprsRecursive; use super::BoundStatement; use crate::error::{ErrorCode, Result}; -use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal}; +use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal}; /// Rewrites parameter expressions to literals. pub(crate) struct ParamRewriter { @@ -47,22 +47,7 @@ impl ExprRewriter for ParamRewriter { if self.error.is_some() { return expr; } - match expr { - ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner), - ExprImpl::Literal(inner) => self.rewrite_literal(*inner), - ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner), - ExprImpl::FunctionCallWithLambda(inner) => { - self.rewrite_function_call_with_lambda(*inner) - } - ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner), - ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner), - ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner), - ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner), - ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner), - ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner), - ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner), - ExprImpl::Now(inner) => self.rewrite_now(*inner), - } + default_rewrite_expr(self, expr) } fn rewrite_subquery(&mut self, mut subquery: crate::expr::Subquery) -> ExprImpl { diff --git a/src/frontend/src/expr/expr_rewriter.rs b/src/frontend/src/expr/expr_rewriter.rs index 4d5b960d654dc..6300f9d5e8858 100644 --- a/src/frontend/src/expr/expr_rewriter.rs +++ b/src/frontend/src/expr/expr_rewriter.rs @@ -12,33 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::util::recursive::{tracker, Recurse}; + use super::{ AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal, - Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction, + Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction, EXPR_DEPTH_THRESHOLD, + EXPR_TOO_DEEP_NOTICE, }; use crate::expr::Now; +use crate::session::current::notice_to_user; + +/// The default implementation of [`ExprRewriter::rewrite_expr`] that simply dispatches to other +/// methods based on the type of the expression. +/// +/// You can use this function as a helper to reduce boilerplate code when implementing the trait. +// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should +// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477. +pub fn default_rewrite_expr( + rewriter: &mut R, + expr: ExprImpl, +) -> ExprImpl { + // TODO: Implementors may choose to not use this function at all, in which case we will fail + // to track the recursion and grow the stack as necessary. The current approach is only a + // best-effort attempt to prevent stack overflow. + tracker!().recurse(|t| { + if t.depth_reaches(EXPR_DEPTH_THRESHOLD) { + notice_to_user(EXPR_TOO_DEEP_NOTICE); + } + + match expr { + ExprImpl::InputRef(inner) => rewriter.rewrite_input_ref(*inner), + ExprImpl::Literal(inner) => rewriter.rewrite_literal(*inner), + ExprImpl::FunctionCall(inner) => rewriter.rewrite_function_call(*inner), + ExprImpl::FunctionCallWithLambda(inner) => { + rewriter.rewrite_function_call_with_lambda(*inner) + } + ExprImpl::AggCall(inner) => rewriter.rewrite_agg_call(*inner), + ExprImpl::Subquery(inner) => rewriter.rewrite_subquery(*inner), + ExprImpl::CorrelatedInputRef(inner) => rewriter.rewrite_correlated_input_ref(*inner), + ExprImpl::TableFunction(inner) => rewriter.rewrite_table_function(*inner), + ExprImpl::WindowFunction(inner) => rewriter.rewrite_window_function(*inner), + ExprImpl::UserDefinedFunction(inner) => rewriter.rewrite_user_defined_function(*inner), + ExprImpl::Parameter(inner) => rewriter.rewrite_parameter(*inner), + ExprImpl::Now(inner) => rewriter.rewrite_now(*inner), + } + }) +} /// By default, `ExprRewriter` simply traverses the expression tree and leaves nodes unchanged. /// Implementations can override a subset of methods and perform transformation on some particular /// types of expression. pub trait ExprRewriter { fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl { - match expr { - ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner), - ExprImpl::Literal(inner) => self.rewrite_literal(*inner), - ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner), - ExprImpl::FunctionCallWithLambda(inner) => { - self.rewrite_function_call_with_lambda(*inner) - } - ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner), - ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner), - ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner), - ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner), - ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner), - ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner), - ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner), - ExprImpl::Now(inner) => self.rewrite_now(*inner), - } + default_rewrite_expr(self, expr) } fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl { diff --git a/src/frontend/src/expr/expr_visitor.rs b/src/frontend/src/expr/expr_visitor.rs index 4e0484397ab9e..64b5c61b565dd 100644 --- a/src/frontend/src/expr/expr_visitor.rs +++ b/src/frontend/src/expr/expr_visitor.rs @@ -12,10 +12,48 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::util::recursive::{tracker, Recurse}; + use super::{ AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal, Now, Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction, + EXPR_DEPTH_THRESHOLD, EXPR_TOO_DEEP_NOTICE, }; +use crate::session::current::notice_to_user; + +/// The default implementation of [`ExprVisitor::visit_expr`] that simply dispatches to other +/// methods based on the type of the expression. +/// +/// You can use this function as a helper to reduce boilerplate code when implementing the trait. +// TODO: This is essentially a mimic of `super` pattern from OO languages. Ideally, we should +// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477. +pub fn default_visit_expr(visitor: &mut V, expr: &ExprImpl) { + // TODO: Implementors may choose to not use this function at all, in which case we will fail + // to track the recursion and grow the stack as necessary. The current approach is only a + // best-effort attempt to prevent stack overflow. + tracker!().recurse(|t| { + if t.depth_reaches(EXPR_DEPTH_THRESHOLD) { + notice_to_user(EXPR_TOO_DEEP_NOTICE); + } + + match expr { + ExprImpl::InputRef(inner) => visitor.visit_input_ref(inner), + ExprImpl::Literal(inner) => visitor.visit_literal(inner), + ExprImpl::FunctionCall(inner) => visitor.visit_function_call(inner), + ExprImpl::FunctionCallWithLambda(inner) => { + visitor.visit_function_call_with_lambda(inner) + } + ExprImpl::AggCall(inner) => visitor.visit_agg_call(inner), + ExprImpl::Subquery(inner) => visitor.visit_subquery(inner), + ExprImpl::CorrelatedInputRef(inner) => visitor.visit_correlated_input_ref(inner), + ExprImpl::TableFunction(inner) => visitor.visit_table_function(inner), + ExprImpl::WindowFunction(inner) => visitor.visit_window_function(inner), + ExprImpl::UserDefinedFunction(inner) => visitor.visit_user_defined_function(inner), + ExprImpl::Parameter(inner) => visitor.visit_parameter(inner), + ExprImpl::Now(inner) => visitor.visit_now(inner), + } + }) +} /// Traverse an expression tree. /// @@ -27,20 +65,7 @@ use super::{ /// subqueries are not traversed. pub trait ExprVisitor { fn visit_expr(&mut self, expr: &ExprImpl) { - match expr { - ExprImpl::InputRef(inner) => self.visit_input_ref(inner), - ExprImpl::Literal(inner) => self.visit_literal(inner), - ExprImpl::FunctionCall(inner) => self.visit_function_call(inner), - ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner), - ExprImpl::AggCall(inner) => self.visit_agg_call(inner), - ExprImpl::Subquery(inner) => self.visit_subquery(inner), - ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner), - ExprImpl::TableFunction(inner) => self.visit_table_function(inner), - ExprImpl::WindowFunction(inner) => self.visit_window_function(inner), - ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner), - ExprImpl::Parameter(inner) => self.visit_parameter(inner), - ExprImpl::Now(inner) => self.visit_now(inner), - } + default_visit_expr(self, expr) } fn visit_function_call(&mut self, func_call: &FunctionCall) { func_call diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 03be40f955d79..d14d99766bcc4 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -53,8 +53,8 @@ mod utils; pub use agg_call::AggCall; pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth}; pub use expr_mutator::ExprMutator; -pub use expr_rewriter::ExprRewriter; -pub use expr_visitor::ExprVisitor; +pub use expr_rewriter::{default_rewrite_expr, ExprRewriter}; +pub use expr_visitor::{default_visit_expr, ExprVisitor}; pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay}; pub use function_call_with_lambda::FunctionCallWithLambda; pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay}; @@ -74,6 +74,10 @@ pub use user_defined_function::UserDefinedFunction; pub use utils::*; pub use window_function::WindowFunction; +const EXPR_DEPTH_THRESHOLD: usize = 30; +const EXPR_TOO_DEEP_NOTICE: &str = "Some expression is too complicated. \ +Consider simplifying or splitting the query if you encounter any issues."; + /// the trait of bound expressions pub trait Expr: Into { /// Get the return type of the expr diff --git a/src/frontend/src/optimizer/plan_expr_rewriter/const_eval_rewriter.rs b/src/frontend/src/optimizer/plan_expr_rewriter/const_eval_rewriter.rs index 43b76891f3566..0844bdb33a85b 100644 --- a/src/frontend/src/optimizer/plan_expr_rewriter/const_eval_rewriter.rs +++ b/src/frontend/src/optimizer/plan_expr_rewriter/const_eval_rewriter.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::error::RwError; -use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal}; +use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal}; pub(crate) struct ConstEvalRewriter { pub(crate) error: Option, @@ -31,21 +31,10 @@ impl ExprRewriter for ConstEvalRewriter { expr } } + } else if let ExprImpl::Parameter(_) = expr { + unreachable!("Parameter should not appear here. It will be replaced by a literal before this step.") } else { - match expr { - ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner), - ExprImpl::Literal(inner) => self.rewrite_literal(*inner), - ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner), - ExprImpl::FunctionCallWithLambda(inner) => self.rewrite_function_call_with_lambda(*inner), - ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner), - ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner), - ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner), - ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner), - ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner), - ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner), - ExprImpl::Parameter(_) => unreachable!("Parameter should not appear here. It will be replaced by a literal before this step."), - ExprImpl::Now(inner) => self.rewrite_now(*inner), - } + default_rewrite_expr(self, expr) } } } diff --git a/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs b/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs index 68ce5b93b0441..b636218338c2d 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs @@ -14,7 +14,7 @@ use std::collections::HashMap; -use crate::expr::{ExprImpl, ExprType, ExprVisitor, FunctionCall}; +use crate::expr::{default_visit_expr, ExprImpl, ExprType, ExprVisitor, FunctionCall}; /// `ExprCounter` is used by `CseRewriter`. #[derive(Default)] @@ -35,20 +35,7 @@ impl ExprVisitor for CseExprCounter { return; } - match expr { - ExprImpl::InputRef(inner) => self.visit_input_ref(inner), - ExprImpl::Literal(inner) => self.visit_literal(inner), - ExprImpl::FunctionCall(inner) => self.visit_function_call(inner), - ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner), - ExprImpl::AggCall(inner) => self.visit_agg_call(inner), - ExprImpl::Subquery(inner) => self.visit_subquery(inner), - ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner), - ExprImpl::TableFunction(inner) => self.visit_table_function(inner), - ExprImpl::WindowFunction(inner) => self.visit_window_function(inner), - ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner), - ExprImpl::Parameter(inner) => self.visit_parameter(inner), - ExprImpl::Now(inner) => self.visit_now(inner), - } + default_visit_expr(self, expr); } fn visit_function_call(&mut self, func_call: &FunctionCall) { diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index d7187357e3fad..2567cbf01c6e8 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -696,8 +696,10 @@ impl dyn PlanNode { } } -const PLAN_DEPTH_THRESHOLD: usize = 30; -const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \ +/// Recursion depth threshold for plan node visitor to send notice to user. +pub const PLAN_DEPTH_THRESHOLD: usize = 30; +/// Notice message for plan node visitor to send to user when the depth threshold is reached. +pub const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \ Consider simplifying or splitting the query if you encounter any issues."; impl dyn PlanNode { diff --git a/src/frontend/src/optimizer/plan_rewriter/mod.rs b/src/frontend/src/optimizer/plan_rewriter/mod.rs index 81c0809bae86d..360c61d3121b0 100644 --- a/src/frontend/src/optimizer/plan_rewriter/mod.rs +++ b/src/frontend/src/optimizer/plan_rewriter/mod.rs @@ -56,11 +56,20 @@ macro_rules! def_rewriter { pub trait PlanRewriter { paste! { fn rewrite(&mut self, plan: PlanRef) -> PlanRef{ - match plan.node_type() { - $( - PlanNodeType::[<$convention $name>] => self.[](plan.downcast_ref::<[<$convention $name>]>().unwrap()), - )* - } + use risingwave_common::util::recursive::{tracker, Recurse}; + use crate::session::current::notice_to_user; + + tracker!().recurse(|t| { + if t.depth_reaches(PLAN_DEPTH_THRESHOLD) { + notice_to_user(PLAN_TOO_DEEP_NOTICE); + } + + match plan.node_type() { + $( + PlanNodeType::[<$convention $name>] => self.[](plan.downcast_ref::<[<$convention $name>]>().unwrap()), + )* + } + }) } $( diff --git a/src/frontend/src/optimizer/plan_visitor/mod.rs b/src/frontend/src/optimizer/plan_visitor/mod.rs index 6156454fd3e80..63a0484cfdfd5 100644 --- a/src/frontend/src/optimizer/plan_visitor/mod.rs +++ b/src/frontend/src/optimizer/plan_visitor/mod.rs @@ -93,11 +93,20 @@ macro_rules! def_visitor { paste! { fn visit(&mut self, plan: PlanRef) -> Self::Result { - match plan.node_type() { - $( - PlanNodeType::[<$convention $name>] => self.[](plan.downcast_ref::<[<$convention $name>]>().unwrap()), - )* - } + use risingwave_common::util::recursive::{tracker, Recurse}; + use crate::session::current::notice_to_user; + + tracker!().recurse(|t| { + if t.depth_reaches(PLAN_DEPTH_THRESHOLD) { + notice_to_user(PLAN_TOO_DEEP_NOTICE); + } + + match plan.node_type() { + $( + PlanNodeType::[<$convention $name>] => self.[](plan.downcast_ref::<[<$convention $name>]>().unwrap()), + )* + } + }) } $(