Skip to content

Commit

Permalink
feat(frontend): fearless recursion on deep plans (#16279)
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Apr 17, 2024
1 parent 4ba80c9 commit 32df9cb
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 216 deletions.
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ serde_json = "1"
serde_with = "3"
smallbitset = "0.7.1"
speedate = "0.14.0"
stacker = "0.1"
static_assertions = "1"
strum = "0.26"
strum_macros = "0.26"
Expand Down
1 change: 1 addition & 0 deletions src/common/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod pretty_bytes;
pub mod prost;
pub mod query_log;
pub use rw_resource_util as resource_util;
pub mod recursive;
pub mod row_id;
pub mod row_serde;
pub mod runtime;
Expand Down
190 changes: 190 additions & 0 deletions src/common/src/util/recursive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Track the recursion and grow the stack when necessary to enable fearless recursion.
use std::cell::RefCell;

// Tuning knobs handed to `stacker::maybe_grow`; the `stacker` documentation
// describes their exact semantics.
// TODO: determine good values or make them configurable

/// First argument to `stacker::maybe_grow` (the "red zone"): 128 KiB.
const RED_ZONE: usize = 128 << 10;
/// Second argument to `stacker::maybe_grow` (the stack size): sixteen red
/// zones, i.e. 2 MiB.
const STACK_SIZE: usize = RED_ZONE * 16;

/// Bookkeeping for how deep a tracked recursion currently is.
struct Depth {
    /// Number of tracked frames that are active right now.
    current: usize,
    /// Deepest level recorded by frames that have already exited; the frame
    /// at `current` is not folded in until it exits.
    last_max: usize,
}

impl Depth {
    /// A pristine state: no active frame and no recorded maximum.
    ///
    /// Kept `const` so it can be used in constant initializers.
    const fn new() -> Self {
        Depth {
            current: 0,
            last_max: 0,
        }
    }

    /// Return to the pristine state, forgetting the recorded maximum.
    fn reset(&mut self) {
        self.current = 0;
        self.last_max = 0;
    }
}

/// Recursion bookkeeping for one recursive function.
///
/// It counts how many calls made through it are currently nested and
/// remembers the deepest level seen during the ongoing recursion.
pub struct Tracker {
    /// Depth state, wrapped in a `RefCell` so it can be updated through `&self`.
    depth: RefCell<Depth>,
}

impl Tracker {
    /// Build a tracker with no recursion in progress.
    pub const fn new() -> Self {
        Tracker {
            depth: RefCell::new(Depth::new()),
        }
    }

    /// The current recursion depth.
    ///
    /// The outermost call of the recursive function observes `1`.
    pub fn depth(&self) -> usize {
        let state = self.depth.borrow();
        state.current
    }

    /// Whether the recursion is at `depth` right now **and** has never been
    /// this deep before during the ongoing recursion.
    ///
    /// Handy for emitting a log line or notice exactly once.
    pub fn depth_reaches(&self, depth: usize) -> bool {
        let state = self.depth.borrow();
        let is_new_record = state.current > state.last_max;
        is_new_record && state.current == depth
    }

    /// Run one level of the recursive function `f`, growing the stack first
    /// if necessary.
    fn recurse<T>(&self, f: impl FnOnce() -> T) -> T {
        /// Leaves the frame when dropped: records the depth that was reached
        /// and decrements the counter again.
        struct FrameGuard<'a>(&'a RefCell<Depth>);

        impl Drop for FrameGuard<'_> {
            fn drop(&mut self) {
                let mut state = self.0.borrow_mut();
                // Fold the depth of the frame being left into the maximum.
                if state.current > state.last_max {
                    state.last_max = state.current;
                }
                // Step back out to the caller's depth.
                state.current -= 1;
                // The outermost frame is done: start over for the next recursion.
                if state.current == 0 {
                    state.reset();
                }
            }
        }

        // Enter the frame; the guard undoes this when it goes out of scope.
        self.depth.borrow_mut().current += 1;
        let _guard = FrameGuard(&self.depth);

        if cfg!(madsim) {
            // madsim does not support stack growth
            f()
        } else {
            stacker::maybe_grow(RED_ZONE, STACK_SIZE, f)
        }
    }
}

/// Extension trait that lets a thread-local [`Tracker`] drive a recursive
/// function.
#[easy_ext::ext(Recurse)]
impl std::thread::LocalKey<Tracker> {
    /// Execute one level of a recursive function, growing the stack when
    /// needed.
    ///
    /// # Fearless Recursion
    ///
    /// Provided that no single frame needs more stack than `RED_ZONE`, the
    /// caller may recurse as deeply as it likes without worrying about
    /// overflowing the stack in most cases.
    ///
    /// # Tracker
    ///
    /// The closure receives the [`Tracker`] of the ongoing recursion. Use it
    /// to inspect the depth, to log, or to bail out with a graceful error
    /// when the recursion gets too deep.
    ///
    /// Trackers defined in different functions do not know about each other.
    /// In a cross-function recursion, the tracker passed to the closure only
    /// reflects the state of the function it was defined in.
    ///
    /// # Example
    ///
    /// Create the tracker with [`tracker!`] and invoke this method on it:
    ///
    /// ```ignore
    /// #[inline(never)]
    /// fn sum(x: u64) -> u64 {
    ///     tracker!().recurse(|t| {
    ///         if t.depth() % 100000 == 0 {
    ///             eprintln!("too deep!");
    ///         }
    ///         if x == 0 {
    ///             return 0;
    ///         }
    ///         x + sum(x - 1)
    ///     })
    /// }
    /// ```
    pub fn recurse<T>(&'static self, f: impl FnOnce(&Tracker) -> T) -> T {
        self.with(|tracker| {
            // Hand the tracker itself to the user closure so it can query the depth.
            let body = || f(tracker);
            tracker.recurse(body)
        })
    }
}

/// Define the tracker for recursion and return it.
///
/// Every expansion declares its own thread-local [`Tracker`] inside a fresh
/// block, so each call site of the macro gets a tracker of its own.
///
/// Call [`Recurse::recurse`] on it to run a recursive function. See
/// documentation there for usage.
#[macro_export]
macro_rules! __recursive_tracker {
    () => {{
        // Resolve `Tracker` through `$crate` so the macro also works when it
        // is expanded in another crate.
        use $crate::util::recursive::Tracker;
        std::thread_local! {
            // Const-initialized via the `const fn` constructor.
            static __TRACKER: Tracker = const { Tracker::new() };
        }
        // The block evaluates to the thread-local key itself.
        __TRACKER
    }};
}
// Re-export under the short name so callers can write `recursive::tracker!()`.
pub use __recursive_tracker as tracker;

// Compiled out under madsim, where `Tracker::recurse` does not grow the stack.
#[cfg(all(test, not(madsim)))]
mod tests {
    use super::*;

    /// Sums `0..=N` with one tracked frame per term and checks both the
    /// result and the depth reported at the innermost frame.
    #[test]
    fn test_fearless_recursion() {
        const N: u64 = 1_919_810;
        // N * (N + 1) / 2
        const WANT: u64 = 1_842_836_177_955;

        #[inline(never)]
        fn sum(x: u64) -> u64 {
            tracker!().recurse(|t| {
                if x > 0 {
                    x + sum(x - 1)
                } else {
                    // `N` enclosing frames plus this innermost one.
                    assert_eq!(t.depth(), N as usize + 1);
                    0
                }
            })
        }

        assert_eq!(sum(N), WANT);
    }
}
112 changes: 65 additions & 47 deletions src/frontend/src/optimizer/plan_node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use itertools::Itertools;
use paste::paste;
use pretty_xmlish::{Pretty, PrettyConfig};
use risingwave_common::catalog::Schema;
use risingwave_common::util::recursive::{self, Recurse};
use risingwave_pb::batch_plan::PlanNode as BatchPlanPb;
use risingwave_pb::stream_plan::StreamNode as StreamPlanPb;
use serde::Serialize;
Expand All @@ -51,6 +52,7 @@ use self::utils::Distill;
use super::property::{Distribution, FunctionalDependencySet, Order};
use crate::error::{ErrorCode, Result};
use crate::optimizer::ExpressionSimplifyRewriter;
use crate::session::current::notice_to_user;

/// A marker trait for different conventions, used for enforcing type safety.
///
Expand Down Expand Up @@ -694,6 +696,10 @@ impl dyn PlanNode {
}
}

/// Recursion depth at which the plan serializers below send
/// [`PLAN_TOO_DEEP_NOTICE`] to the user (checked with `depth_reaches`, so it
/// fires only when that depth is first reached).
const PLAN_DEPTH_THRESHOLD: usize = 30;
/// Text passed to `notice_to_user` once [`PLAN_DEPTH_THRESHOLD`] is reached.
// The trailing `\` continues the literal on the next line without inserting a
// line break into the message.
const PLAN_TOO_DEEP_NOTICE: &str = "The plan is too deep. \
Consider simplifying or splitting the query if you encounter any issues.";

impl dyn PlanNode {
/// Serialize the plan node and its children to a stream plan proto.
///
Expand All @@ -703,41 +709,47 @@ impl dyn PlanNode {
&self,
state: &mut BuildFragmentGraphState,
) -> SchedulerResult<StreamPlanPb> {
use stream::prelude::*;
recursive::tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

if let Some(stream_table_scan) = self.as_stream_table_scan() {
return stream_table_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_cdc_table_scan) = self.as_stream_cdc_table_scan() {
return stream_cdc_table_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_source_scan) = self.as_stream_source_scan() {
return stream_source_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_share) = self.as_stream_share() {
return stream_share.adhoc_to_stream_prost(state);
}
use stream::prelude::*;

let node = Some(self.try_to_stream_prost_body(state)?);
let input = self
.inputs()
.into_iter()
.map(|plan| plan.to_stream_prost(state))
.try_collect()?;
// TODO: support pk_indices and operator_id
Ok(StreamPlanPb {
input,
identity: self.explain_myself_to_string(),
node_body: node,
operator_id: self.id().0 as _,
stream_key: self
.stream_key()
.unwrap_or_default()
.iter()
.map(|x| *x as u32)
.collect(),
fields: self.schema().to_prost(),
append_only: self.plan_base().append_only(),
if let Some(stream_table_scan) = self.as_stream_table_scan() {
return stream_table_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_cdc_table_scan) = self.as_stream_cdc_table_scan() {
return stream_cdc_table_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_source_scan) = self.as_stream_source_scan() {
return stream_source_scan.adhoc_to_stream_prost(state);
}
if let Some(stream_share) = self.as_stream_share() {
return stream_share.adhoc_to_stream_prost(state);
}

let node = Some(self.try_to_stream_prost_body(state)?);
let input = self
.inputs()
.into_iter()
.map(|plan| plan.to_stream_prost(state))
.try_collect()?;
// TODO: support pk_indices and operator_id
Ok(StreamPlanPb {
input,
identity: self.explain_myself_to_string(),
node_body: node,
operator_id: self.id().0 as _,
stream_key: self
.stream_key()
.unwrap_or_default()
.iter()
.map(|x| *x as u32)
.collect(),
fields: self.schema().to_prost(),
append_only: self.plan_base().append_only(),
})
})
}

Expand All @@ -749,20 +761,26 @@ impl dyn PlanNode {
/// Serialize the plan node and its children to a batch plan proto without the identity field
/// (for testing).
pub fn to_batch_prost_identity(&self, identity: bool) -> SchedulerResult<BatchPlanPb> {
let node_body = Some(self.try_to_batch_prost_body()?);
let children = self
.inputs()
.into_iter()
.map(|plan| plan.to_batch_prost_identity(identity))
.try_collect()?;
Ok(BatchPlanPb {
children,
identity: if identity {
self.explain_myself_to_string()
} else {
"".into()
},
node_body,
recursive::tracker!().recurse(|t| {
if t.depth_reaches(PLAN_DEPTH_THRESHOLD) {
notice_to_user(PLAN_TOO_DEEP_NOTICE);
}

let node_body = Some(self.try_to_batch_prost_body()?);
let children = self
.inputs()
.into_iter()
.map(|plan| plan.to_batch_prost_identity(identity))
.try_collect()?;
Ok(BatchPlanPb {
children,
identity: if identity {
self.explain_myself_to_string()
} else {
"".into()
},
node_body,
})
})
}

Expand Down
Loading

0 comments on commit 32df9cb

Please sign in to comment.