Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): grow stack for complicated plan in more places #17224

Merged
merged 4 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 2 additions & 17 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_common::types::{Datum, ScalarImpl};
use super::statement::RewriteExprsRecursive;
use super::BoundStatement;
use crate::error::{ErrorCode, Result};
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal};

/// Rewrites parameter expressions to literals.
pub(crate) struct ParamRewriter {
Expand All @@ -47,22 +47,7 @@ impl ExprRewriter for ParamRewriter {
if self.error.is_some() {
return expr;
}
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
self.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}

fn rewrite_subquery(&mut self, mut subquery: crate::expr::Subquery) -> ExprImpl {
Expand Down
60 changes: 43 additions & 17 deletions src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,59 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::util::recursive::{tracker, Recurse};

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal,
Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction,
Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction, EXPR_DEPTH_THRESHOLD,
EXPR_TOO_DEEP_NOTICE,
};
use crate::expr::Now;
use crate::session::current::notice_to_user;

/// Fallback body for [`ExprRewriter::rewrite_expr`]: looks at the variant of `expr` and forwards
/// the unboxed payload to the matching `rewrite_*` method on `rewriter`.
///
/// Implementors that override `rewrite_expr` can call this helper for the cases they do not
/// handle specially, rather than repeating the whole dispatch table.
// TODO: This emulates the `super` call pattern from OO languages. Ideally, we should
// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
pub fn default_rewrite_expr<R>(rewriter: &mut R, expr: ExprImpl) -> ExprImpl
where
    R: ExprRewriter + ?Sized,
{
    // TODO: Recursion is tracked (and the stack grown as necessary) only when implementors
    // route through this function. Those that bypass it are not covered, so this is merely a
    // best-effort guard against stack overflow.
    tracker!().recurse(|state| {
        // Tell the user once the expression nesting hits the configured threshold.
        let too_deep = state.depth_reaches(EXPR_DEPTH_THRESHOLD);
        if too_deep {
            notice_to_user(EXPR_TOO_DEEP_NOTICE);
        }

        // One arm per `ExprImpl` variant; each hands the unboxed node to its dedicated method.
        match expr {
            ExprImpl::InputRef(e) => rewriter.rewrite_input_ref(*e),
            ExprImpl::Literal(e) => rewriter.rewrite_literal(*e),
            ExprImpl::FunctionCall(e) => rewriter.rewrite_function_call(*e),
            ExprImpl::FunctionCallWithLambda(e) => rewriter.rewrite_function_call_with_lambda(*e),
            ExprImpl::AggCall(e) => rewriter.rewrite_agg_call(*e),
            ExprImpl::Subquery(e) => rewriter.rewrite_subquery(*e),
            ExprImpl::CorrelatedInputRef(e) => rewriter.rewrite_correlated_input_ref(*e),
            ExprImpl::TableFunction(e) => rewriter.rewrite_table_function(*e),
            ExprImpl::WindowFunction(e) => rewriter.rewrite_window_function(*e),
            ExprImpl::UserDefinedFunction(e) => rewriter.rewrite_user_defined_function(*e),
            ExprImpl::Parameter(e) => rewriter.rewrite_parameter(*e),
            ExprImpl::Now(e) => rewriter.rewrite_now(*e),
        }
    })
}

/// By default, `ExprRewriter` simply traverses the expression tree and leaves nodes unchanged.
/// Implementations can override a subset of methods and perform transformation on some particular
/// types of expression.
pub trait ExprRewriter {
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => {
self.rewrite_function_call_with_lambda(*inner)
}
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}

fn rewrite_function_call(&mut self, func_call: FunctionCall) -> ExprImpl {
Expand Down
53 changes: 39 additions & 14 deletions src/frontend/src/expr/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,48 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::util::recursive::{tracker, Recurse};

use super::{
AggCall, CorrelatedInputRef, ExprImpl, FunctionCall, FunctionCallWithLambda, InputRef, Literal,
Now, Parameter, Subquery, TableFunction, UserDefinedFunction, WindowFunction,
EXPR_DEPTH_THRESHOLD, EXPR_TOO_DEEP_NOTICE,
};
use crate::session::current::notice_to_user;

/// Fallback body for [`ExprVisitor::visit_expr`]: looks at the variant of `expr` and forwards
/// the inner node to the matching `visit_*` method on `visitor`.
///
/// Implementors that override `visit_expr` can call this helper for the cases they do not
/// handle specially, rather than repeating the whole dispatch table.
// TODO: This emulates the `super` call pattern from OO languages. Ideally, we should
// adopt the style proposed in https://github.com/risingwavelabs/risingwave/issues/13477.
pub fn default_visit_expr<V>(visitor: &mut V, expr: &ExprImpl)
where
    V: ExprVisitor + ?Sized,
{
    // TODO: Recursion is tracked (and the stack grown as necessary) only when implementors
    // route through this function. Those that bypass it are not covered, so this is merely a
    // best-effort guard against stack overflow.
    tracker!().recurse(|state| {
        // Tell the user once the expression nesting hits the configured threshold.
        let too_deep = state.depth_reaches(EXPR_DEPTH_THRESHOLD);
        if too_deep {
            notice_to_user(EXPR_TOO_DEEP_NOTICE);
        }

        // One arm per `ExprImpl` variant; each hands the inner node to its dedicated method.
        match expr {
            ExprImpl::InputRef(e) => visitor.visit_input_ref(e),
            ExprImpl::Literal(e) => visitor.visit_literal(e),
            ExprImpl::FunctionCall(e) => visitor.visit_function_call(e),
            ExprImpl::FunctionCallWithLambda(e) => visitor.visit_function_call_with_lambda(e),
            ExprImpl::AggCall(e) => visitor.visit_agg_call(e),
            ExprImpl::Subquery(e) => visitor.visit_subquery(e),
            ExprImpl::CorrelatedInputRef(e) => visitor.visit_correlated_input_ref(e),
            ExprImpl::TableFunction(e) => visitor.visit_table_function(e),
            ExprImpl::WindowFunction(e) => visitor.visit_window_function(e),
            ExprImpl::UserDefinedFunction(e) => visitor.visit_user_defined_function(e),
            ExprImpl::Parameter(e) => visitor.visit_parameter(e),
            ExprImpl::Now(e) => visitor.visit_now(e),
        }
    })
}

/// Traverse an expression tree.
///
Expand All @@ -27,20 +65,7 @@ use super::{
/// subqueries are not traversed.
pub trait ExprVisitor {
fn visit_expr(&mut self, expr: &ExprImpl) {
match expr {
ExprImpl::InputRef(inner) => self.visit_input_ref(inner),
ExprImpl::Literal(inner) => self.visit_literal(inner),
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner),
ExprImpl::AggCall(inner) => self.visit_agg_call(inner),
ExprImpl::Subquery(inner) => self.visit_subquery(inner),
ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner),
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
ExprImpl::Now(inner) => self.visit_now(inner),
}
default_visit_expr(self, expr)
}
fn visit_function_call(&mut self, func_call: &FunctionCall) {
func_call
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ mod utils;
pub use agg_call::AggCall;
pub use correlated_input_ref::{CorrelatedId, CorrelatedInputRef, Depth};
pub use expr_mutator::ExprMutator;
pub use expr_rewriter::ExprRewriter;
pub use expr_visitor::ExprVisitor;
pub use expr_rewriter::{default_rewrite_expr, ExprRewriter};
pub use expr_visitor::{default_visit_expr, ExprVisitor};
pub use function_call::{is_row_function, FunctionCall, FunctionCallDisplay};
pub use function_call_with_lambda::FunctionCallWithLambda;
pub use input_ref::{input_ref_to_column_indices, InputRef, InputRefDisplay};
Expand All @@ -74,6 +74,10 @@ pub use user_defined_function::UserDefinedFunction;
pub use utils::*;
pub use window_function::WindowFunction;

/// Recursion depth threshold at which the default expression visitor/rewriter dispatch sends a
/// notice to the user.
const EXPR_DEPTH_THRESHOLD: usize = 30;
/// Notice message sent to the user when the expression depth threshold is reached.
const EXPR_TOO_DEEP_NOTICE: &str = "Some expression is too complicated. \
Consider simplifying or splitting the query if you encounter any issues.";

/// the trait of bound expressions
pub trait Expr: Into<ExprImpl> {
/// Get the return type of the expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use crate::error::RwError;
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};
use crate::expr::{default_rewrite_expr, Expr, ExprImpl, ExprRewriter, Literal};

pub(crate) struct ConstEvalRewriter {
pub(crate) error: Option<RwError>,
Expand All @@ -31,21 +31,10 @@ impl ExprRewriter for ConstEvalRewriter {
expr
}
}
} else if let ExprImpl::Parameter(_) = expr {
unreachable!("Parameter should not appear here. It will be replaced by a literal before this step.")
} else {
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::FunctionCallWithLambda(inner) => self.rewrite_function_call_with_lambda(*inner),
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(_) => unreachable!("Parameter should not appear here. It will be replaced by a literal before this step."),
ExprImpl::Now(inner) => self.rewrite_now(*inner),
}
default_rewrite_expr(self, expr)
}
}
}
17 changes: 2 additions & 15 deletions src/frontend/src/optimizer/plan_expr_visitor/expr_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use std::collections::HashMap;

use crate::expr::{ExprImpl, ExprType, ExprVisitor, FunctionCall};
use crate::expr::{default_visit_expr, ExprImpl, ExprType, ExprVisitor, FunctionCall};

/// `ExprCounter` is used by `CseRewriter`.
#[derive(Default)]
Expand All @@ -35,20 +35,7 @@ impl ExprVisitor for CseExprCounter {
return;
}

match expr {
ExprImpl::InputRef(inner) => self.visit_input_ref(inner),
ExprImpl::Literal(inner) => self.visit_literal(inner),
ExprImpl::FunctionCall(inner) => self.visit_function_call(inner),
ExprImpl::FunctionCallWithLambda(inner) => self.visit_function_call_with_lambda(inner),
ExprImpl::AggCall(inner) => self.visit_agg_call(inner),
ExprImpl::Subquery(inner) => self.visit_subquery(inner),
ExprImpl::CorrelatedInputRef(inner) => self.visit_correlated_input_ref(inner),
ExprImpl::TableFunction(inner) => self.visit_table_function(inner),
ExprImpl::WindowFunction(inner) => self.visit_window_function(inner),
ExprImpl::UserDefinedFunction(inner) => self.visit_user_defined_function(inner),
ExprImpl::Parameter(inner) => self.visit_parameter(inner),
ExprImpl::Now(inner) => self.visit_now(inner),
}
default_visit_expr(self, expr);
}

fn visit_function_call(&mut self, func_call: &FunctionCall) {
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,10 @@ impl dyn PlanNode {
}
}

const PLAN_DEPTH_THRESHOLD: usize = 30;
const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
/// Recursion depth threshold for plan node visitor to send notice to user.
pub const PLAN_DEPTH_THRESHOLD: usize = 30;
/// Notice message for plan node visitor to send to user when the depth threshold is reached.
pub const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
Consider simplifying or splitting the query if you encounter any issues.";

impl dyn PlanNode {
Expand Down
19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/plan_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ macro_rules! def_rewriter {
pub trait PlanRewriter {
paste! {
fn rewrite(&mut self, plan: PlanRef) -> PlanRef{
match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<rewrite_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
use risingwave_common::util::recursive::{tracker, Recurse};
use crate::session::current::notice_to_user;

tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<rewrite_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
})
}

$(
Expand Down
19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/plan_visitor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,20 @@ macro_rules! def_visitor {

paste! {
fn visit(&mut self, plan: PlanRef) -> Self::Result {
match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<visit_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
use risingwave_common::util::recursive::{tracker, Recurse};
use crate::session::current::notice_to_user;

tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

match plan.node_type() {
$(
PlanNodeType::[<$convention $name>] => self.[<visit_ $convention:snake _ $name:snake>](plan.downcast_ref::<[<$convention $name>]>().unwrap()),
)*
}
})
}

$(
Expand Down
Loading