From 27b5439636eb732b3566306add89685fa6d81129 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Thu, 7 Mar 2024 20:13:56 +0800 Subject: [PATCH] feat(binder): bind RCTE Signed-off-by: TennyZhuang --- src/frontend/src/binder/bind_context.rs | 29 ++++- src/frontend/src/binder/query.rs | 113 +++++++++++++++++--- src/frontend/src/binder/relation/cte_ref.rs | 28 +++++ src/frontend/src/binder/relation/mod.rs | 63 +++++++---- src/frontend/src/binder/set_expr.rs | 24 +++++ src/frontend/src/expr/mod.rs | 14 +++ src/frontend/src/lib.rs | 1 + src/frontend/src/planner/relation.rs | 3 + src/frontend/src/planner/set_expr.rs | 4 + 9 files changed, 242 insertions(+), 37 deletions(-) create mode 100644 src/frontend/src/binder/relation/cte_ref.rs diff --git a/src/frontend/src/binder/bind_context.rs b/src/frontend/src/binder/bind_context.rs index 0dc03464fbe62..e787d4ae24970 100644 --- a/src/frontend/src/binder/bind_context.rs +++ b/src/frontend/src/binder/bind_context.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cell::RefCell; use std::collections::hash_map::Entry; use std::collections::{BTreeMap, HashMap, HashSet}; use std::rc::Rc; use parse_display::Display; -use risingwave_common::catalog::Field; +use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::TableAlias; @@ -66,6 +67,30 @@ pub struct LateralBindContext { pub context: BindContext, } +/// If the CTE a recursive one, we may need store it in `cte_to_relation` first, and bind it step by step. +/// +/// ```sql +/// WITH RECURSIVE t(n) AS ( +/// ---------------^ Init +/// VALUES (1) +/// UNION ALL +/// SELECT n+1 FROM t WHERE n < 100 +/// # ------------------^BaseResolved +/// ) +/// SELECT sum(n) FROM t; +/// # -----------------^Bound +/// ``` +#[derive(Default, Debug, Clone)] +pub enum BindingCteState { + /// We know nothing about the CTE before resolve the body. + #[default] + Init, + /// We know the schema from after the base term resolved. + BaseResolved { schema: Schema }, + /// We get the whole bound result. + Bound { query: BoundQuery }, +} + #[derive(Default, Debug, Clone)] pub struct BindContext { // Columns of all tables. @@ -80,7 +105,7 @@ pub struct BindContext { pub column_group_context: ColumnGroupContext, /// Map the cte's name to its Relation::Subquery. /// The `ShareId` of the value is used to help the planner identify the share plan. - pub cte_to_relation: HashMap>, + pub cte_to_relation: HashMap>>, /// Current lambda functions's arguments pub lambda_args: Option>, } diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index fe2008f50f3eb..8a45550b056f5 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; -use risingwave_common::bail_not_implemented; use risingwave_common::catalog::Schema; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; -use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With}; +use risingwave_sqlparser::ast::{ + Cte, Expr, Fetch, OrderByExpr, Query, SetExpr, SetOperator, Value, With, +}; use thiserror_ext::AsReport; +use super::bind_context::BindingCteState; use super::statement::RewriteExprsRecursive; use super::BoundValues; use crate::binder::{Binder, BoundSetExpr}; @@ -279,20 +282,104 @@ impl Binder { } fn bind_with(&mut self, with: With) -> Result<()> { - if with.recursive { - bail_not_implemented!("recursive cte"); - } else { - for cte_table in with.cte_tables { - let Cte { alias, query, .. } = cte_table; - let table_name = alias.name.real_value(); - let bound_query = self.bind_query(query)?; - let share_id = self.next_share_id(); - self.context + for cte_table in with.cte_tables { + let share_id = self.next_share_id(); + let Cte { alias, query, .. } = cte_table; + let table_name = alias.name.real_value(); + if with.recursive { + let Query { + with, + body, + order_by, + limit, + offset, + fetch, + } = query; + fn should_be_empty(v: Option, clause: &str) -> Result<()> { + if !v.is_none() { + return Err(ErrorCode::BindError(format!( + "`{clause}` is not supported in recursive CTE" + )) + .into()); + } + Ok(()) + } + should_be_empty(order_by.first(), "ORDER BY")?; + should_be_empty(limit, "LIMIT")?; + should_be_empty(offset, "OFFSET")?; + should_be_empty(fetch, "FETCH")?; + + let SetExpr::SetOperation { + op: SetOperator::Union, + all, + left, + right, + } = body + else { + return Err(ErrorCode::BindError(format!( + "`UNION` is required in recursive CTE" + )) + .into()); + }; + + if !all { + return Err(ErrorCode::BindError(format!( + "only `UNION ALL` is supported in recursive CTE now" + )) + .into()); + } + + let entry = self + .context .cte_to_relation - .insert(table_name, Rc::new((share_id, bound_query, alias))); + .entry(table_name) + .insert_entry(Rc::new(RefCell::new(( + share_id, + BindingCteState::Init, + alias, + )))) + .get() + .clone(); + + if let Some(with) = with { + self.bind_with(with)?; + } + + // We assume `left` is base term, otherwise the implementation may be very hard. + let bound_base = self.bind_set_expr(*left)?; + + entry.borrow_mut().1 = BindingCteState::BaseResolved { + schema: bound_base.schema().clone(), + }; + + let bound_recursive = self.bind_set_expr(*right)?; + + let bound_query = BoundQuery { + body: BoundSetExpr::RecursiveUnion { + base: Box::new(bound_base), + recursive: Box::new(bound_recursive), + }, + order: vec![], + limit: None, + offset: None, + with_ties: false, + extra_order_exprs: vec![], + }; + + entry.borrow_mut().1 = BindingCteState::Bound { query: bound_query }; + } else { + let bound_query = self.bind_query(query)?; + self.context.cte_to_relation.insert( + table_name, + Rc::new(RefCell::new(( + share_id, + BindingCteState::Bound { query: bound_query }, + alias, + ))), + ); } - Ok(()) } + Ok(()) } } diff --git a/src/frontend/src/binder/relation/cte_ref.rs b/src/frontend/src/binder/relation/cte_ref.rs new file mode 100644 index 0000000000000..86c40cc096dd1 --- /dev/null +++ b/src/frontend/src/binder/relation/cte_ref.rs @@ -0,0 +1,28 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::binder::statement::RewriteExprsRecursive; +use crate::binder::ShareId; + +/// A CTE reference, currently only used in the back reference of recursive CTE. +/// For the non-recursive one, see [`BoundShare`](super::BoundShare). +#[derive(Debug, Clone)] +pub struct BoundBackCteRef { + #[expect(dead_code)] + pub(crate) share_id: ShareId, +} + +impl RewriteExprsRecursive for BoundBackCteRef { + fn rewrite_exprs_recursive(&mut self, _rewriter: &mut impl crate::expr::ExprRewriter) {} +} diff --git a/src/frontend/src/binder/relation/mod.rs b/src/frontend/src/binder/relation/mod.rs index 69eb6787d47a0..aabf5ba719dd7 100644 --- a/src/frontend/src/binder/relation/mod.rs +++ b/src/frontend/src/binder/relation/mod.rs @@ -24,12 +24,15 @@ use risingwave_sqlparser::ast::{ use thiserror::Error; use thiserror_ext::AsReport; +use self::cte_ref::BoundBackCteRef; use super::bind_context::ColumnBinding; use super::statement::RewriteExprsRecursive; +use crate::binder::bind_context::BindingCteState; use crate::binder::Binder; use crate::error::{ErrorCode, Result, RwError}; use crate::expr::{ExprImpl, InputRef}; +mod cte_ref; mod join; mod share; mod subquery; @@ -65,6 +68,7 @@ pub enum Relation { }, Watermark(Box), Share(Box), + BackCteRef(Box), } impl RewriteExprsRecursive for Relation { @@ -79,6 +83,7 @@ impl RewriteExprsRecursive for Relation { Relation::TableFunction { expr: inner, .. } => { *inner = rewriter.rewrite_expr(inner.take()) } + Relation::BackCteRef(inner) => inner.rewrite_exprs_recursive(rewriter), _ => {} } } @@ -336,7 +341,7 @@ impl Binder { { // Handles CTE - let (share_id, query, mut original_alias) = item.deref().clone(); + let (share_id, cte_state, mut original_alias) = item.deref().borrow().clone(); debug_assert_eq!(original_alias.name.real_value(), table_name); // The original CTE alias ought to be its table name. if let Some(from_alias) = alias { @@ -349,27 +354,41 @@ impl Binder { .collect(); } - self.bind_table_to_context( - query - .body - .schema() - .fields - .iter() - .map(|f| (false, f.clone())), - table_name.clone(), - Some(original_alias), - )?; - - // Share the CTE. - let input_relation = Relation::Subquery(Box::new(BoundSubquery { - query, - lateral: false, - })); - let share_relation = Relation::Share(Box::new(BoundShare { - share_id, - input: input_relation, - })); - Ok(share_relation) + match cte_state { + BindingCteState::Init => { + Err(ErrorCode::BindError(format!("Base term of recursive CTE not found, consider write it to left side of the `UNION` operator")).into()) + } + BindingCteState::BaseResolved { schema } => { + self.bind_table_to_context( + schema.fields.iter().map(|f| (false, f.clone())), + table_name.clone(), + Some(original_alias), + )?; + Ok(Relation::BackCteRef(Box::new(BoundBackCteRef { share_id }))) + } + BindingCteState::Bound { query } => { + self.bind_table_to_context( + query + .body + .schema() + .fields + .iter() + .map(|f| (false, f.clone())), + table_name.clone(), + Some(original_alias), + )?; + // Share the CTE. + let input_relation = Relation::Subquery(Box::new(BoundSubquery { + query, + lateral: false, + })); + let share_relation = Relation::Share(Box::new(BoundShare { + share_id, + input: input_relation, + })); + Ok(share_relation) + } + } } else { self.bind_relation_by_name_inner( schema_name.as_deref(), diff --git a/src/frontend/src/binder/set_expr.rs b/src/frontend/src/binder/set_expr.rs index 99ec66ac0b725..289043cd10029 100644 --- a/src/frontend/src/binder/set_expr.rs +++ b/src/frontend/src/binder/set_expr.rs @@ -36,6 +36,11 @@ pub enum BoundSetExpr { left: Box, right: Box, }, + /// UNION in recursive CTE definition + RecursiveUnion { + base: Box, + recursive: Box, + }, } impl RewriteExprsRecursive for BoundSetExpr { @@ -48,6 +53,10 @@ impl RewriteExprsRecursive for BoundSetExpr { left.rewrite_exprs_recursive(rewriter); right.rewrite_exprs_recursive(rewriter); } + BoundSetExpr::RecursiveUnion { base, recursive } => { + base.rewrite_exprs_recursive(rewriter); + recursive.rewrite_exprs_recursive(rewriter); + } } } } @@ -78,6 +87,7 @@ impl BoundSetExpr { BoundSetExpr::Values(v) => v.schema(), BoundSetExpr::Query(q) => q.schema(), BoundSetExpr::SetOperation { left, .. } => left.schema(), + BoundSetExpr::RecursiveUnion { base, .. } => base.schema(), } } @@ -89,6 +99,9 @@ impl BoundSetExpr { BoundSetExpr::SetOperation { left, right, .. } => { left.is_correlated(depth) || right.is_correlated(depth) } + BoundSetExpr::RecursiveUnion { base, recursive } => { + base.is_correlated(depth) || recursive.is_correlated(depth) + } } } @@ -117,6 +130,17 @@ impl BoundSetExpr { ); correlated_indices } + BoundSetExpr::RecursiveUnion { base, recursive } => { + let mut correlated_indices = vec![]; + correlated_indices.extend( + base.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id), + ); + correlated_indices.extend( + recursive + .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id), + ); + correlated_indices + } } } } diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 78ae2db726a39..db8a438aa1755 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -483,6 +483,12 @@ impl ExprImpl { self.visit_bound_set_expr(left); self.visit_bound_set_expr(right); } + BoundSetExpr::RecursiveUnion { + base, recursive, .. + } => { + self.visit_bound_set_expr(base); + self.visit_bound_set_expr(recursive); + } }; } } @@ -524,6 +530,10 @@ impl ExprImpl { self.visit_bound_set_expr(left); self.visit_bound_set_expr(right); } + BoundSetExpr::RecursiveUnion { base, recursive } => { + self.visit_bound_set_expr(base); + self.visit_bound_set_expr(recursive); + } } } } @@ -593,6 +603,10 @@ impl ExprImpl { self.visit_bound_set_expr(&mut *left); self.visit_bound_set_expr(&mut *right); } + BoundSetExpr::RecursiveUnion { base, recursive } => { + self.visit_bound_set_expr(&mut *base); + self.visit_bound_set_expr(&mut *recursive); + } } } } diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 9dc64983671d3..59271dfd18d4d 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -35,6 +35,7 @@ #![feature(round_ties_even)] #![feature(iterator_try_collect)] #![feature(used_with_arg)] +#![feature(entry_insert)] #![recursion_limit = "256"] #[cfg(test)] diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 3f64a8fde4405..5101e263b299b 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -54,6 +54,9 @@ impl Planner { } => self.plan_table_function(tf, with_ordinality), Relation::Watermark(tf) => self.plan_watermark(*tf), Relation::Share(share) => self.plan_share(*share), + Relation::BackCteRef(..) => { + bail_not_implemented!(issue = 15135, "recursive CTE is not supported") + } } } diff --git a/src/frontend/src/planner/set_expr.rs b/src/frontend/src/planner/set_expr.rs index e2ff43a2c211b..eeb789e9d9ded 100644 --- a/src/frontend/src/planner/set_expr.rs +++ b/src/frontend/src/planner/set_expr.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::bail_not_implemented; use risingwave_common::util::sort_util::ColumnOrder; use crate::binder::BoundSetExpr; @@ -37,6 +38,9 @@ impl Planner { left, right, } => self.plan_set_operation(op, all, *left, *right), + BoundSetExpr::RecursiveUnion { .. } => { + bail_not_implemented!(issue = 15135, "recursive CTE is not supported") + } } } }