Skip to content

Commit

Permalink
feat(binder): bind RCTE
Browse files Browse the repository at this point in the history
Signed-off-by: TennyZhuang <[email protected]>
  • Loading branch information
TennyZhuang committed Mar 7, 2024
1 parent ecf95c5 commit 27b5439
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 37 deletions.
29 changes: 27 additions & 2 deletions src/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::rc::Rc;

use parse_display::Display;
use risingwave_common::catalog::Field;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::TableAlias;

Expand Down Expand Up @@ -66,6 +67,30 @@ pub struct LateralBindContext {
pub context: BindContext,
}

/// If the CTE a recursive one, we may need store it in `cte_to_relation` first, and bind it step by step.
///
/// ```sql
/// WITH RECURSIVE t(n) AS (
/// ---------------^ Init
/// VALUES (1)
/// UNION ALL
/// SELECT n+1 FROM t WHERE n < 100
/// # ------------------^BaseResolved
/// )
/// SELECT sum(n) FROM t;
/// # -----------------^Bound
/// ```
#[derive(Default, Debug, Clone)]
pub enum BindingCteState {
/// We know nothing about the CTE before resolve the body.
#[default]
Init,
/// We know the schema from after the base term resolved.
BaseResolved { schema: Schema },
/// We get the whole bound result.
Bound { query: BoundQuery },
}

#[derive(Default, Debug, Clone)]
pub struct BindContext {
// Columns of all tables.
Expand All @@ -80,7 +105,7 @@ pub struct BindContext {
pub column_group_context: ColumnGroupContext,
/// Map the cte's name to its Relation::Subquery.
/// The `ShareId` of the value is used to help the planner identify the share plan.
pub cte_to_relation: HashMap<String, Rc<(ShareId, BoundQuery, TableAlias)>>,
pub cte_to_relation: HashMap<String, Rc<RefCell<(ShareId, BindingCteState, TableAlias)>>>,
/// Current lambda functions's arguments
pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
}
Expand Down
113 changes: 100 additions & 13 deletions src/frontend/src/binder/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::Schema;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};
use risingwave_sqlparser::ast::{
Cte, Expr, Fetch, OrderByExpr, Query, SetExpr, SetOperator, Value, With,
};
use thiserror_ext::AsReport;

use super::bind_context::BindingCteState;
use super::statement::RewriteExprsRecursive;
use super::BoundValues;
use crate::binder::{Binder, BoundSetExpr};
Expand Down Expand Up @@ -279,20 +282,104 @@ impl Binder {
}

fn bind_with(&mut self, with: With) -> Result<()> {
if with.recursive {
bail_not_implemented!("recursive cte");
} else {
for cte_table in with.cte_tables {
let Cte { alias, query, .. } = cte_table;
let table_name = alias.name.real_value();
let bound_query = self.bind_query(query)?;
let share_id = self.next_share_id();
self.context
for cte_table in with.cte_tables {
let share_id = self.next_share_id();
let Cte { alias, query, .. } = cte_table;
let table_name = alias.name.real_value();
if with.recursive {
let Query {
with,
body,
order_by,
limit,
offset,
fetch,
} = query;
fn should_be_empty<T>(v: Option<T>, clause: &str) -> Result<()> {
if !v.is_none() {
return Err(ErrorCode::BindError(format!(
"`{clause}` is not supported in recursive CTE"
))
.into());
}
Ok(())
}
should_be_empty(order_by.first(), "ORDER BY")?;
should_be_empty(limit, "LIMIT")?;
should_be_empty(offset, "OFFSET")?;
should_be_empty(fetch, "FETCH")?;

let SetExpr::SetOperation {
op: SetOperator::Union,
all,
left,
right,
} = body
else {
return Err(ErrorCode::BindError(format!(
"`UNION` is required in recursive CTE"
))
.into());
};

if !all {
return Err(ErrorCode::BindError(format!(
"only `UNION ALL` is supported in recursive CTE now"
))
.into());
}

let entry = self
.context
.cte_to_relation
.insert(table_name, Rc::new((share_id, bound_query, alias)));
.entry(table_name)
.insert_entry(Rc::new(RefCell::new((
share_id,
BindingCteState::Init,
alias,
))))
.get()
.clone();

if let Some(with) = with {
self.bind_with(with)?;
}

// We assume `left` is base term, otherwise the implementation may be very hard.
let bound_base = self.bind_set_expr(*left)?;

entry.borrow_mut().1 = BindingCteState::BaseResolved {
schema: bound_base.schema().clone(),
};

let bound_recursive = self.bind_set_expr(*right)?;

let bound_query = BoundQuery {
body: BoundSetExpr::RecursiveUnion {
base: Box::new(bound_base),
recursive: Box::new(bound_recursive),
},
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
};

entry.borrow_mut().1 = BindingCteState::Bound { query: bound_query };
} else {
let bound_query = self.bind_query(query)?;
self.context.cte_to_relation.insert(
table_name,
Rc::new(RefCell::new((
share_id,
BindingCteState::Bound { query: bound_query },
alias,
))),
);
}
Ok(())
}
Ok(())
}
}

Expand Down
28 changes: 28 additions & 0 deletions src/frontend/src/binder/relation/cte_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::binder::statement::RewriteExprsRecursive;
use crate::binder::ShareId;

/// A CTE reference, currently only used in the back reference of recursive CTE.
/// For the non-recursive one, see [`BoundShare`](super::BoundShare).
#[derive(Debug, Clone)]
pub struct BoundBackCteRef {
#[expect(dead_code)]
pub(crate) share_id: ShareId,
}

impl RewriteExprsRecursive for BoundBackCteRef {
fn rewrite_exprs_recursive(&mut self, _rewriter: &mut impl crate::expr::ExprRewriter) {}
}
63 changes: 41 additions & 22 deletions src/frontend/src/binder/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ use risingwave_sqlparser::ast::{
use thiserror::Error;
use thiserror_ext::AsReport;

use self::cte_ref::BoundBackCteRef;
use super::bind_context::ColumnBinding;
use super::statement::RewriteExprsRecursive;
use crate::binder::bind_context::BindingCteState;
use crate::binder::Binder;
use crate::error::{ErrorCode, Result, RwError};
use crate::expr::{ExprImpl, InputRef};

mod cte_ref;
mod join;
mod share;
mod subquery;
Expand Down Expand Up @@ -65,6 +68,7 @@ pub enum Relation {
},
Watermark(Box<BoundWatermark>),
Share(Box<BoundShare>),
BackCteRef(Box<BoundBackCteRef>),
}

impl RewriteExprsRecursive for Relation {
Expand All @@ -79,6 +83,7 @@ impl RewriteExprsRecursive for Relation {
Relation::TableFunction { expr: inner, .. } => {
*inner = rewriter.rewrite_expr(inner.take())
}
Relation::BackCteRef(inner) => inner.rewrite_exprs_recursive(rewriter),
_ => {}
}
}
Expand Down Expand Up @@ -336,7 +341,7 @@ impl Binder {
{
// Handles CTE

let (share_id, query, mut original_alias) = item.deref().clone();
let (share_id, cte_state, mut original_alias) = item.deref().borrow().clone();
debug_assert_eq!(original_alias.name.real_value(), table_name); // The original CTE alias ought to be its table name.

if let Some(from_alias) = alias {
Expand All @@ -349,27 +354,41 @@ impl Binder {
.collect();
}

self.bind_table_to_context(
query
.body
.schema()
.fields
.iter()
.map(|f| (false, f.clone())),
table_name.clone(),
Some(original_alias),
)?;

// Share the CTE.
let input_relation = Relation::Subquery(Box::new(BoundSubquery {
query,
lateral: false,
}));
let share_relation = Relation::Share(Box::new(BoundShare {
share_id,
input: input_relation,
}));
Ok(share_relation)
match cte_state {
BindingCteState::Init => {
Err(ErrorCode::BindError(format!("Base term of recursive CTE not found, consider write it to left side of the `UNION` operator")).into())
}
BindingCteState::BaseResolved { schema } => {
self.bind_table_to_context(
schema.fields.iter().map(|f| (false, f.clone())),
table_name.clone(),
Some(original_alias),
)?;
Ok(Relation::BackCteRef(Box::new(BoundBackCteRef { share_id })))
}
BindingCteState::Bound { query } => {
self.bind_table_to_context(
query
.body
.schema()
.fields
.iter()
.map(|f| (false, f.clone())),
table_name.clone(),
Some(original_alias),
)?;
// Share the CTE.
let input_relation = Relation::Subquery(Box::new(BoundSubquery {
query,
lateral: false,
}));
let share_relation = Relation::Share(Box::new(BoundShare {
share_id,
input: input_relation,
}));
Ok(share_relation)
}
}
} else {
self.bind_relation_by_name_inner(
schema_name.as_deref(),
Expand Down
24 changes: 24 additions & 0 deletions src/frontend/src/binder/set_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ pub enum BoundSetExpr {
left: Box<BoundSetExpr>,
right: Box<BoundSetExpr>,
},
/// UNION in recursive CTE definition
RecursiveUnion {
base: Box<BoundSetExpr>,
recursive: Box<BoundSetExpr>,
},
}

impl RewriteExprsRecursive for BoundSetExpr {
Expand All @@ -48,6 +53,10 @@ impl RewriteExprsRecursive for BoundSetExpr {
left.rewrite_exprs_recursive(rewriter);
right.rewrite_exprs_recursive(rewriter);
}
BoundSetExpr::RecursiveUnion { base, recursive } => {
base.rewrite_exprs_recursive(rewriter);
recursive.rewrite_exprs_recursive(rewriter);
}
}
}
}
Expand Down Expand Up @@ -78,6 +87,7 @@ impl BoundSetExpr {
BoundSetExpr::Values(v) => v.schema(),
BoundSetExpr::Query(q) => q.schema(),
BoundSetExpr::SetOperation { left, .. } => left.schema(),
BoundSetExpr::RecursiveUnion { base, .. } => base.schema(),
}
}

Expand All @@ -89,6 +99,9 @@ impl BoundSetExpr {
BoundSetExpr::SetOperation { left, right, .. } => {
left.is_correlated(depth) || right.is_correlated(depth)
}
BoundSetExpr::RecursiveUnion { base, recursive } => {
base.is_correlated(depth) || recursive.is_correlated(depth)
}
}
}

Expand Down Expand Up @@ -117,6 +130,17 @@ impl BoundSetExpr {
);
correlated_indices
}
BoundSetExpr::RecursiveUnion { base, recursive } => {
let mut correlated_indices = vec![];
correlated_indices.extend(
base.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
);
correlated_indices.extend(
recursive
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id),
);
correlated_indices
}
}
}
}
Expand Down
Loading

0 comments on commit 27b5439

Please sign in to comment.