Skip to content

Commit

Permalink
feat(binder): correctly bind rcte in bind_with & `bind_relation_by_…
Browse files Browse the repository at this point in the history
…name` (#16023)
  • Loading branch information
xzhseh authored Apr 3, 2024
1 parent e2292a3 commit 78a2422
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 127 deletions.
47 changes: 37 additions & 10 deletions src/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::rc::Rc;

use either::Either;
use parse_display::Display;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
Expand All @@ -26,6 +27,7 @@ use crate::error::{ErrorCode, Result};

type LiteResult<T> = std::result::Result<T, ErrorCode>;

use super::BoundSetExpr;
use crate::binder::{BoundQuery, ShareId, COLUMN_GROUP_PREFIX};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -78,12 +80,12 @@ pub struct LateralBindContext {
/// WITH RECURSIVE t(n) AS (
/// # -------------^ => Init
/// VALUES (1)
/// # ----------^ => BaseResolved (after binding the base term)
/// UNION ALL
/// SELECT n+1 FROM t WHERE n < 100
/// # ------------------^ => BaseResolved
/// # ------------------^ => Bound (we know exactly what the entire cte looks like)
/// )
/// SELECT sum(n) FROM t;
/// # -----------------^ => Bound
/// ```
#[derive(Default, Debug, Clone)]
pub enum BindingCteState {
Expand All @@ -93,7 +95,26 @@ pub enum BindingCteState {
/// We know the schema form after the base term resolved.
BaseResolved { schema: Schema },
/// We get the whole bound result of the (recursive) CTE.
Bound { query: BoundQuery },
Bound {
query: Either<BoundQuery, RecursiveUnion>,
},
}

/// the entire `RecursiveUnion` represents a *bound* recursive cte.
/// reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
#[derive(Debug, Clone)]
pub struct RecursiveUnion {
/// currently this *must* be true,
/// otherwise binding will fail.
pub all: bool,
/// lhs part of the `UNION ALL` operator
pub base: Box<BoundSetExpr>,
/// rhs part of the `UNION ALL` operator
pub recursive: Box<BoundSetExpr>,
/// the aligned schema for this union
/// will be the *same* schema as recursive's
/// this is just for a better readability
pub schema: Schema,
}

#[derive(Clone, Debug)]
Expand All @@ -116,7 +137,7 @@ pub struct BindContext {
// The `BindContext`'s data on its column groups
pub column_group_context: ColumnGroupContext,
/// Map the cte's name to its binding state.
/// The `ShareId` of the value is used to help the planner identify the share plan.
/// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
/// Current lambda functions's arguments
pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
Expand Down Expand Up @@ -341,13 +362,19 @@ impl BindContext {
entry.extend(v.into_iter().map(|x| x + begin));
}
for (k, (x, y)) in other.range_of {
match self.range_of.entry(k) {
match self.range_of.entry(k.clone()) {
Entry::Occupied(e) => {
return Err(ErrorCode::InternalError(format!(
"Duplicated table name while merging adjacent contexts: {}",
e.key()
))
.into());
if let BindingCteState::Bound { .. } =
self.cte_to_relation.get(&k).unwrap().borrow().state.clone()
{
// do nothing
} else {
return Err(ErrorCode::InternalError(format!(
"Duplicated table name while merging adjacent contexts: {}",
e.key()
))
.into());
}
}
Entry::Vacant(entry) => {
entry.insert((begin + x, begin + y));
Expand Down
61 changes: 40 additions & 21 deletions src/frontend/src/binder/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use thiserror_ext::AsReport;
use super::bind_context::BindingCteState;
use super::statement::RewriteExprsRecursive;
use super::BoundValues;
use crate::binder::bind_context::BindingCte;
use crate::binder::bind_context::{BindingCte, RecursiveUnion};
use crate::binder::{Binder, BoundSetExpr};
use crate::error::{ErrorCode, Result};
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};
Expand Down Expand Up @@ -287,6 +287,7 @@ impl Binder {
let share_id = self.next_share_id();
let Cte { alias, query, .. } = cte_table;
let table_name = alias.name.real_value();

if with.recursive {
let Query {
with,
Expand All @@ -296,6 +297,8 @@ impl Binder {
offset,
fetch,
} = query;

/// the input clause should not be supported.
fn should_be_empty<T>(v: Option<T>, clause: &str) -> Result<()> {
if v.is_some() {
return Err(ErrorCode::BindError(format!(
Expand All @@ -305,6 +308,7 @@ impl Binder {
}
Ok(())
}

should_be_empty(order_by.first(), "ORDER BY")?;
should_be_empty(limit, "LIMIT")?;
should_be_empty(offset, "OFFSET")?;
Expand All @@ -315,7 +319,7 @@ impl Binder {
all,
left,
right,
} = body
} = body.clone()
else {
return Err(ErrorCode::BindError(
"`UNION` is required in recursive CTE".to_string(),
Expand Down Expand Up @@ -346,37 +350,52 @@ impl Binder {
self.bind_with(with)?;
}

// We assume `left` is base term, otherwise the implementation may be very hard.
// The behavior is same as PostgreSQL.
// https://www.postgresql.org/docs/16/sql-select.html#:~:text=the%20recursive%20self%2Dreference%20must%20appear%20on%20the%20right%2Dhand%20side%20of%20the%20UNION
let bound_base = self.bind_set_expr(*left)?;
// We assume `left` is the base term, otherwise the implementation may be very hard.
// The behavior is the same as PostgreSQL's.
// reference: <https://www.postgresql.org/docs/16/sql-select.html#:~:text=the%20recursive%20self%2Dreference%20must%20appear%20on%20the%20right%2Dhand%20side%20of%20the%20UNION>
let mut base = self.bind_set_expr(*left)?;

entry.borrow_mut().state = BindingCteState::BaseResolved {
schema: bound_base.schema().clone(),
schema: base.schema().clone(),
};

let bound_recursive = self.bind_set_expr(*right)?;

let bound_query = BoundQuery {
body: BoundSetExpr::RecursiveUnion {
base: Box::new(bound_base),
recursive: Box::new(bound_recursive),
},
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
// bind the rest of the recursive cte
let mut recursive = self.bind_set_expr(*right)?;

// todo: add validate check here for *bound* `base` and `recursive`
Self::align_schema(&mut base, &mut recursive, SetOperator::Union)?;

// please note that even after aligning, the schema of `left`
// may not be the same as `right`; this is because there may
// be case(s) where the `base` term is just a value, and the
// `recursive` term is a select expression / statement.
let schema = recursive.schema().clone();
// yet another sanity check
assert_eq!(
schema,
recursive.schema().clone(),
"expect `schema` to be the same as recursive's"
);

let recursive_union = RecursiveUnion {
all,
base: Box::new(base),
recursive: Box::new(recursive),
schema,
};

entry.borrow_mut().state = BindingCteState::Bound { query: bound_query };
entry.borrow_mut().state = BindingCteState::Bound {
query: either::Either::Right(recursive_union),
};
} else {
let bound_query = self.bind_query(query)?;
self.context.cte_to_relation.insert(
table_name,
Rc::new(RefCell::new(BindingCte {
share_id,
state: BindingCteState::Bound { query: bound_query },
state: BindingCteState::Bound {
query: either::Either::Left(bound_query),
},
alias,
})),
);
Expand Down
46 changes: 36 additions & 10 deletions src/frontend/src/binder/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::collections::hash_map::Entry;
use std::ops::Deref;

use either::Either::{Left, Right};
use itertools::{EitherOrBoth, Itertools};
use risingwave_common::bail;
use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME};
Expand All @@ -26,6 +27,7 @@ use thiserror::Error;
use thiserror_ext::AsReport;

use self::cte_ref::BoundBackCteRef;
use self::recursive_union::BoundRecursiveUnion;
use super::bind_context::ColumnBinding;
use super::statement::RewriteExprsRecursive;
use crate::binder::bind_context::{BindingCte, BindingCteState};
Expand All @@ -35,6 +37,7 @@ use crate::expr::{ExprImpl, InputRef};

mod cte_ref;
mod join;
mod recursive_union;
mod share;
mod subquery;
mod table_function;
Expand Down Expand Up @@ -70,6 +73,7 @@ pub enum Relation {
Watermark(Box<BoundWatermark>),
Share(Box<BoundShare>),
BackCteRef(Box<BoundBackCteRef>),
RecursiveUnion(Box<BoundRecursiveUnion>),
}

impl RewriteExprsRecursive for Relation {
Expand All @@ -85,6 +89,7 @@ impl RewriteExprsRecursive for Relation {
*inner = rewriter.rewrite_expr(inner.take())
}
Relation::BackCteRef(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::RecursiveUnion(inner) => inner.rewrite_exprs_recursive(rewriter),
_ => {}
}
}
Expand Down Expand Up @@ -342,7 +347,9 @@ impl Binder {
as_of: Option<AsOf>,
) -> Result<Relation> {
let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?;

if schema_name.is_none()
// the `table_name` here is the name of the currently binding cte.
&& let Some(item) = self.context.cte_to_relation.get(&table_name)
{
// Handles CTE
Expand All @@ -352,7 +359,9 @@ impl Binder {
state: cte_state,
alias: mut original_alias,
} = item.deref().borrow().clone();
debug_assert_eq!(original_alias.name.real_value(), table_name); // The original CTE alias ought to be its table name.

// The original CTE alias ought to be its table name.
debug_assert_eq!(original_alias.name.real_value(), table_name);

if let Some(from_alias) = alias {
original_alias.name = from_alias.name;
Expand All @@ -366,7 +375,7 @@ impl Binder {

match cte_state {
BindingCteState::Init => {
Err(ErrorCode::BindError("Base term of recursive CTE not found, consider write it to left side of the `UNION` operator".to_string()).into())
Err(ErrorCode::BindError("Base term of recursive CTE not found, consider writing it to left side of the `UNION ALL` operator".to_string()).into())
}
BindingCteState::BaseResolved { schema } => {
self.bind_table_to_context(
Expand All @@ -377,23 +386,40 @@ impl Binder {
Ok(Relation::BackCteRef(Box::new(BoundBackCteRef { share_id })))
}
BindingCteState::Bound { query } => {
let schema = match query.clone() {
Left(normal) => normal.body.schema().clone(),
Right(recursive) => recursive.schema.clone(),
};
self.bind_table_to_context(
query
.body
.schema()
schema
.fields
.iter()
.map(|f| (false, f.clone())),
table_name.clone(),
Some(original_alias),
)?;
// Share the CTE.
let input_relation = Relation::Subquery(Box::new(BoundSubquery {
query,
lateral: false,
}));
// todo: to be further reviewed
let input_relation = match query {
// normal cte with union
Left(query) => {
Relation::Subquery(Box::new(BoundSubquery {
query,
lateral: false,
}))
}
// recursive cte
Right(recursive) => {
Relation::RecursiveUnion(Box::new(BoundRecursiveUnion {
base: *recursive.base,
recursive: *recursive.recursive,
}))
}
};
// we could always share the cte,
// no matter it's recursive or not.
let share_relation = Relation::Share(Box::new(BoundShare {
share_id,
// should either be a *bound* `subquery` or `recursive union`
input: input_relation,
}));
Ok(share_relation)
Expand Down
35 changes: 35 additions & 0 deletions src/frontend/src/binder/relation/recursive_union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::binder::statement::RewriteExprsRecursive;
use crate::binder::BoundSetExpr;

/// a *bound* recursive union representation.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BoundRecursiveUnion {
/// the *bound* base case
pub(crate) base: BoundSetExpr,
/// the *bound* recursive case
pub(crate) recursive: BoundSetExpr,
}

impl RewriteExprsRecursive for BoundRecursiveUnion {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
// rewrite base case
self.base.rewrite_exprs_recursive(rewriter);
// rewrite recursive case
self.recursive.rewrite_exprs_recursive(rewriter);
}
}
Loading

0 comments on commit 78a2422

Please sign in to comment.