Skip to content

Commit

Permalink
feat(binder): bind RCTE (#15522)
Browse files Browse the repository at this point in the history
Signed-off-by: TennyZhuang <[email protected]>
Co-authored-by: xiangjinwu <[email protected]>
Co-authored-by: Michael Xu <[email protected]>
  • Loading branch information
3 people authored Apr 9, 2024
1 parent 9c5310b commit 0c8371a
Show file tree
Hide file tree
Showing 12 changed files with 570 additions and 109 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ workspace-hack = { path = "../workspace-hack" }

[dev-dependencies]
assert_matches = "1"
expect-test = "1"
risingwave_expr_impl = { workspace = true }
tempfile = "3"

Expand Down
84 changes: 79 additions & 5 deletions src/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::rc::Rc;

use either::Either;
use parse_display::Display;
use risingwave_common::catalog::Field;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::TableAlias;

use crate::error::{ErrorCode, Result};

type LiteResult<T> = std::result::Result<T, ErrorCode>;

use super::statement::RewriteExprsRecursive;
use super::BoundSetExpr;
use crate::binder::{BoundQuery, ShareId, COLUMN_GROUP_PREFIX};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -67,6 +71,68 @@ pub struct LateralBindContext {
pub context: BindContext,
}

/// Binding progress of a CTE that is stored in `cte_to_relation`.
///
/// For a recursive CTE, we may need to store it in `cte_to_relation` first,
/// and then bind it *step by step*; this enum records how far that binding
/// has progressed.
///
/// Note: the SQL example below illustrates at which point each binding state
/// is reached while handling a recursive CTE like this one.
///
/// ```sql
/// WITH RECURSIVE t(n) AS (
/// # -------------^ => Init
///     VALUES (1)
/// # ----------^ => BaseResolved (after binding the base term)
///     UNION ALL
///     SELECT n+1 FROM t WHERE n < 100
/// # ------------------^ => Bound (we know exactly what the entire cte looks like)
/// )
/// SELECT sum(n) FROM t;
/// ```
#[derive(Default, Debug, Clone)]
pub enum BindingCteState {
    /// Nothing is known about the CTE yet: its body has not been resolved.
    /// This is the variant produced by `Default`.
    #[default]
    Init,
    /// The base term has been resolved, so the schema of the CTE is known.
    BaseResolved { schema: Schema },
    /// The whole bound result of the (possibly recursive) CTE is available.
    ///
    /// NOTE(review): presumably `Either::Left` carries an ordinary CTE's
    /// `BoundQuery` and `Either::Right` a recursive CTE's `RecursiveUnion`
    /// -- confirm against the code that constructs this variant.
    Bound {
        query: Either<BoundQuery, RecursiveUnion>,
    },
}

/// A *bound* recursive CTE: the entire `RecursiveUnion` represents one.
/// Reference: <https://github.com/risingwavelabs/risingwave/pull/15522/files#r1524367781>
#[derive(Debug, Clone)]
pub struct RecursiveUnion {
    /// Currently this *must* be `true`;
    /// otherwise binding will fail.
    pub all: bool,
    /// Left-hand side of the `UNION ALL` operator.
    pub base: Box<BoundSetExpr>,
    /// Right-hand side of the `UNION ALL` operator.
    pub recursive: Box<BoundSetExpr>,
    /// The aligned schema for this union.
    ///
    /// Stated to be the *same* schema as `recursive`'s and kept as a separate
    /// field only for better readability.
    /// NOTE(review): confirm whether "same" covers field names or only the
    /// column types.
    pub schema: Schema,
}

impl RewriteExprsRecursive for RecursiveUnion {
    /// Applies `rewriter` to both sides of the union.
    ///
    /// The `base` term is rewritten first and the `recursive` term second;
    /// each side is handed to its own `rewrite_exprs_recursive` independently.
    /// The `all` flag and the `schema` field are left untouched.
    fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
        let Self {
            base, recursive, ..
        } = self;
        // Order matters for stateful rewriters: base term, then recursive term.
        for term in [base, recursive] {
            term.rewrite_exprs_recursive(rewriter);
        }
    }
}

/// A CTE known to the binder, together with its current binding state.
///
/// Values of this type are what `cte_to_relation` maps each CTE name to.
#[derive(Clone, Debug)]
pub struct BindingCte {
    /// Used to help the planner identify the share plan.
    pub share_id: ShareId,
    /// How far the binding of this CTE has progressed.
    pub state: BindingCteState,
    /// NOTE(review): presumably the table alias the CTE is exposed under --
    /// confirm against the code that builds `BindingCte`.
    pub alias: TableAlias,
}

#[derive(Default, Debug, Clone)]
pub struct BindContext {
// Columns of all tables.
Expand All @@ -79,9 +145,9 @@ pub struct BindContext {
pub clause: Option<Clause>,
// The `BindContext`'s data on its column groups
pub column_group_context: ColumnGroupContext,
/// Map the cte's name to its `Relation::Subquery`.
/// The `ShareId` of the value is used to help the planner identify the share plan.
pub cte_to_relation: HashMap<String, Rc<(ShareId, BoundQuery, TableAlias)>>,
/// Map the cte's name to its binding state.
/// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan.
pub cte_to_relation: HashMap<String, Rc<RefCell<BindingCte>>>,
/// Current lambda function's arguments
pub lambda_args: Option<HashMap<String, (usize, DataType)>>,
}
Expand Down Expand Up @@ -305,8 +371,16 @@ impl BindContext {
entry.extend(v.into_iter().map(|x| x + begin));
}
for (k, (x, y)) in other.range_of {
match self.range_of.entry(k) {
match self.range_of.entry(k.clone()) {
Entry::Occupied(e) => {
// check if this is a merge with recursive cte
if let Some(r) = self.cte_to_relation.get(&k) {
if let BindingCteState::Bound { .. } = r.borrow().state {
// no-op
continue;
}
}
// otherwise this merge is invalid
return Err(ErrorCode::InternalError(format!(
"Duplicated table name while merging adjacent contexts: {}",
e.key()
Expand Down
199 changes: 199 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,202 @@ pub mod test_utils {
Binder::new_with_param_types(&SessionImpl::mock(), param_types)
}
}

// Snapshot tests for the binder.
#[cfg(test)]
mod tests {
use expect_test::expect;

use super::test_utils::*;

// Binds a single-column recursive CTE and compares the pretty-printed
// (`{:#?}`) bound statement against the inline snapshot below.
//
// The statement under test is a `WITH RECURSIVE` query whose base term is
// `SELECT 1 AS a` and whose recursive term reads from the CTE itself.
#[tokio::test]
async fn test_rcte() {
// Parse the SQL text and take the first (only) statement.
let stmt = risingwave_sqlparser::parser::Parser::parse_sql(
"WITH RECURSIVE t1 AS (SELECT 1 AS a UNION ALL SELECT a + 1 FROM t1 WHERE a < 10) SELECT * FROM t1",
).unwrap().into_iter().next().unwrap();
let mut binder = mock_binder();
let bound = binder.bind(stmt).unwrap();

// Expected `Debug` output. In the snapshot the CTE appears as a `Share`
// whose input is a `RecursiveUnion`, and the self-reference inside the
// recursive term appears as a `BackCteRef` with the same `share_id`.
//
// NOTE(review): the outer `SELECT *` is recorded with two `a` columns
// (input refs 0 and 1) although the CTE schema has a single field --
// confirm this is the intended binding result and not a duplication.
//
// With `expect-test`, the snapshot can be regenerated by running the test
// with `UPDATE_EXPECT=1`.
let expected = expect![[r#"
Query(
BoundQuery {
body: Select(
BoundSelect {
distinct: All,
select_items: [
InputRef(
InputRef {
index: 0,
data_type: Int32,
},
),
InputRef(
InputRef {
index: 1,
data_type: Int32,
},
),
],
aliases: [
Some(
"a",
),
Some(
"a",
),
],
from: Some(
Share(
BoundShare {
share_id: 0,
input: Right(
RecursiveUnion {
all: true,
base: Select(
BoundSelect {
distinct: All,
select_items: [
Literal(
Literal {
data: Some(
Int32(
1,
),
),
data_type: Some(
Int32,
),
},
),
],
aliases: [
Some(
"a",
),
],
from: None,
where_clause: None,
group_by: GroupKey(
[],
),
having: None,
schema: Schema {
fields: [
a:Int32,
],
},
},
),
recursive: Select(
BoundSelect {
distinct: All,
select_items: [
FunctionCall(
FunctionCall {
func_type: Add,
return_type: Int32,
inputs: [
InputRef(
InputRef {
index: 0,
data_type: Int32,
},
),
Literal(
Literal {
data: Some(
Int32(
1,
),
),
data_type: Some(
Int32,
),
},
),
],
},
),
],
aliases: [
None,
],
from: Some(
BackCteRef(
BoundBackCteRef {
share_id: 0,
},
),
),
where_clause: Some(
FunctionCall(
FunctionCall {
func_type: LessThan,
return_type: Boolean,
inputs: [
InputRef(
InputRef {
index: 0,
data_type: Int32,
},
),
Literal(
Literal {
data: Some(
Int32(
10,
),
),
data_type: Some(
Int32,
),
},
),
],
},
),
),
group_by: GroupKey(
[],
),
having: None,
schema: Schema {
fields: [
?column?:Int32,
],
},
},
),
schema: Schema {
fields: [
a:Int32,
],
},
},
),
},
),
),
where_clause: None,
group_by: GroupKey(
[],
),
having: None,
schema: Schema {
fields: [
a:Int32,
a:Int32,
],
},
},
),
order: [],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: [],
},
)"#]];

// Compare against the pretty `Debug` rendering of the bound statement.
expected.assert_eq(&format!("{:#?}", bound));
}
}
Loading

0 comments on commit 0c8371a

Please sign in to comment.