diff --git a/src/frontend/src/binder/bind_context.rs b/src/frontend/src/binder/bind_context.rs index 3229f4d08be02..fbe2789179c79 100644 --- a/src/frontend/src/binder/bind_context.rs +++ b/src/frontend/src/binder/bind_context.rs @@ -17,6 +17,7 @@ use std::collections::hash_map::Entry; use std::collections::{BTreeMap, HashMap, HashSet}; use std::rc::Rc; +use either::Either; use parse_display::Display; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; @@ -26,6 +27,7 @@ use crate::error::{ErrorCode, Result}; type LiteResult = std::result::Result; +use super::BoundSetExpr; use crate::binder::{BoundQuery, ShareId, COLUMN_GROUP_PREFIX}; #[derive(Debug, Clone)] @@ -78,12 +80,12 @@ pub struct LateralBindContext { /// WITH RECURSIVE t(n) AS ( /// # -------------^ => Init /// VALUES (1) +/// # ----------^ => BaseResolved (after binding the base term) /// UNION ALL /// SELECT n+1 FROM t WHERE n < 100 -/// # ------------------^ => BaseResolved +/// # ------------------^ => Bound (we know exactly what the entire cte looks like) /// ) /// SELECT sum(n) FROM t; -/// # -----------------^ => Bound /// ``` #[derive(Default, Debug, Clone)] pub enum BindingCteState { @@ -93,7 +95,26 @@ pub enum BindingCteState { /// We know the schema form after the base term resolved. BaseResolved { schema: Schema }, /// We get the whole bound result of the (recursive) CTE. - Bound { query: BoundQuery }, + Bound { + query: Either, + }, +} + +/// the entire `RecursiveUnion` represents a *bound* recursive cte. +/// reference: +#[derive(Debug, Clone)] +pub struct RecursiveUnion { + /// currently this *must* be true, + /// otherwise binding will fail. + pub all: bool, + /// lhs part of the `UNION ALL` operator + pub base: Box, + /// rhs part of the `UNION ALL` operator + pub recursive: Box, + /// the aligned schema for this union + /// will be the *same* schema as recursive's + /// this is just for a better readability + pub schema: Schema, } #[derive(Clone, Debug)] @@ -116,7 +137,7 @@ pub struct BindContext { // The `BindContext`'s data on its column groups pub column_group_context: ColumnGroupContext, /// Map the cte's name to its binding state. - /// The `ShareId` of the value is used to help the planner identify the share plan. + /// The `ShareId` in `BindingCte` of the value is used to help the planner identify the share plan. pub cte_to_relation: HashMap>>, /// Current lambda functions's arguments pub lambda_args: Option>, @@ -341,13 +362,19 @@ impl BindContext { entry.extend(v.into_iter().map(|x| x + begin)); } for (k, (x, y)) in other.range_of { - match self.range_of.entry(k) { + match self.range_of.entry(k.clone()) { Entry::Occupied(e) => { - return Err(ErrorCode::InternalError(format!( - "Duplicated table name while merging adjacent contexts: {}", - e.key() - )) - .into()); + if let BindingCteState::Bound { .. } = + self.cte_to_relation.get(&k).unwrap().borrow().state.clone() + { + // do nothing + } else { + return Err(ErrorCode::InternalError(format!( + "Duplicated table name while merging adjacent contexts: {}", + e.key() + )) + .into()); + } } Entry::Vacant(entry) => { entry.insert((begin + x, begin + y)); diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index ef957225c7b47..37a97e8bf365e 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -27,7 +27,7 @@ use thiserror_ext::AsReport; use super::bind_context::BindingCteState; use super::statement::RewriteExprsRecursive; use super::BoundValues; -use crate::binder::bind_context::BindingCte; +use crate::binder::bind_context::{BindingCte, RecursiveUnion}; use crate::binder::{Binder, BoundSetExpr}; use crate::error::{ErrorCode, Result}; use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter}; @@ -287,6 +287,7 @@ impl Binder { let share_id = self.next_share_id(); let Cte { alias, query, .. } = cte_table; let table_name = alias.name.real_value(); + if with.recursive { let Query { with, @@ -296,6 +297,8 @@ impl Binder { offset, fetch, } = query; + + /// the input clause should not be supported. fn should_be_empty(v: Option, clause: &str) -> Result<()> { if v.is_some() { return Err(ErrorCode::BindError(format!( @@ -305,6 +308,7 @@ impl Binder { } Ok(()) } + should_be_empty(order_by.first(), "ORDER BY")?; should_be_empty(limit, "LIMIT")?; should_be_empty(offset, "OFFSET")?; @@ -315,7 +319,7 @@ impl Binder { all, left, right, - } = body + } = body.clone() else { return Err(ErrorCode::BindError( "`UNION` is required in recursive CTE".to_string(), @@ -346,37 +350,52 @@ impl Binder { self.bind_with(with)?; } - // We assume `left` is base term, otherwise the implementation may be very hard. - // The behavior is same as PostgreSQL. - // https://www.postgresql.org/docs/16/sql-select.html#:~:text=the%20recursive%20self%2Dreference%20must%20appear%20on%20the%20right%2Dhand%20side%20of%20the%20UNION - let bound_base = self.bind_set_expr(*left)?; + // We assume `left` is the base term, otherwise the implementation may be very hard. + // The behavior is the same as PostgreSQL's. + // reference: + let mut base = self.bind_set_expr(*left)?; entry.borrow_mut().state = BindingCteState::BaseResolved { - schema: bound_base.schema().clone(), + schema: base.schema().clone(), }; - let bound_recursive = self.bind_set_expr(*right)?; - - let bound_query = BoundQuery { - body: BoundSetExpr::RecursiveUnion { - base: Box::new(bound_base), - recursive: Box::new(bound_recursive), - }, - order: vec![], - limit: None, - offset: None, - with_ties: false, - extra_order_exprs: vec![], + // bind the rest of the recursive cte + let mut recursive = self.bind_set_expr(*right)?; + + // todo: add validate check here for *bound* `base` and `recursive` + Self::align_schema(&mut base, &mut recursive, SetOperator::Union)?; + + // please note that even after aligning, the schema of `left` + // may not be the same as `right`; this is because there may + // be case(s) where the `base` term is just a value, and the + // `recursive` term is a select expression / statement. + let schema = recursive.schema().clone(); + // yet another sanity check + assert_eq!( + schema, + recursive.schema().clone(), + "expect `schema` to be the same as recursive's" + ); + + let recursive_union = RecursiveUnion { + all, + base: Box::new(base), + recursive: Box::new(recursive), + schema, }; - entry.borrow_mut().state = BindingCteState::Bound { query: bound_query }; + entry.borrow_mut().state = BindingCteState::Bound { + query: either::Either::Right(recursive_union), + }; } else { let bound_query = self.bind_query(query)?; self.context.cte_to_relation.insert( table_name, Rc::new(RefCell::new(BindingCte { share_id, - state: BindingCteState::Bound { query: bound_query }, + state: BindingCteState::Bound { + query: either::Either::Left(bound_query), + }, alias, })), ); diff --git a/src/frontend/src/binder/relation/mod.rs b/src/frontend/src/binder/relation/mod.rs index d8d6ad7df08bc..17ea96b49cce9 100644 --- a/src/frontend/src/binder/relation/mod.rs +++ b/src/frontend/src/binder/relation/mod.rs @@ -15,6 +15,7 @@ use std::collections::hash_map::Entry; use std::ops::Deref; +use either::Either::{Left, Right}; use itertools::{EitherOrBoth, Itertools}; use risingwave_common::bail; use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME}; @@ -26,6 +27,7 @@ use thiserror::Error; use thiserror_ext::AsReport; use self::cte_ref::BoundBackCteRef; +use self::recursive_union::BoundRecursiveUnion; use super::bind_context::ColumnBinding; use super::statement::RewriteExprsRecursive; use crate::binder::bind_context::{BindingCte, BindingCteState}; @@ -35,6 +37,7 @@ use crate::expr::{ExprImpl, InputRef}; mod cte_ref; mod join; +mod recursive_union; mod share; mod subquery; mod table_function; @@ -70,6 +73,7 @@ pub enum Relation { Watermark(Box), Share(Box), BackCteRef(Box), + RecursiveUnion(Box), } impl RewriteExprsRecursive for Relation { @@ -85,6 +89,7 @@ impl RewriteExprsRecursive for Relation { *inner = rewriter.rewrite_expr(inner.take()) } Relation::BackCteRef(inner) => inner.rewrite_exprs_recursive(rewriter), + Relation::RecursiveUnion(inner) => inner.rewrite_exprs_recursive(rewriter), _ => {} } } @@ -342,7 +347,9 @@ impl Binder { as_of: Option, ) -> Result { let (schema_name, table_name) = Self::resolve_schema_qualified_name(&self.db_name, name)?; + if schema_name.is_none() + // the `table_name` here is the name of the currently binding cte. && let Some(item) = self.context.cte_to_relation.get(&table_name) { // Handles CTE @@ -352,7 +359,9 @@ impl Binder { state: cte_state, alias: mut original_alias, } = item.deref().borrow().clone(); - debug_assert_eq!(original_alias.name.real_value(), table_name); // The original CTE alias ought to be its table name. + + // The original CTE alias ought to be its table name. + debug_assert_eq!(original_alias.name.real_value(), table_name); if let Some(from_alias) = alias { original_alias.name = from_alias.name; @@ -366,7 +375,7 @@ impl Binder { match cte_state { BindingCteState::Init => { - Err(ErrorCode::BindError("Base term of recursive CTE not found, consider write it to left side of the `UNION` operator".to_string()).into()) + Err(ErrorCode::BindError("Base term of recursive CTE not found, consider writing it to left side of the `UNION ALL` operator".to_string()).into()) } BindingCteState::BaseResolved { schema } => { self.bind_table_to_context( @@ -377,23 +386,40 @@ impl Binder { Ok(Relation::BackCteRef(Box::new(BoundBackCteRef { share_id }))) } BindingCteState::Bound { query } => { + let schema = match query.clone() { + Left(normal) => normal.body.schema().clone(), + Right(recursive) => recursive.schema.clone(), + }; self.bind_table_to_context( - query - .body - .schema() + schema .fields .iter() .map(|f| (false, f.clone())), table_name.clone(), Some(original_alias), )?; - // Share the CTE. - let input_relation = Relation::Subquery(Box::new(BoundSubquery { - query, - lateral: false, - })); + // todo: to be further reviewed + let input_relation = match query { + // normal cte with union + Left(query) => { + Relation::Subquery(Box::new(BoundSubquery { + query, + lateral: false, + })) + } + // recursive cte + Right(recursive) => { + Relation::RecursiveUnion(Box::new(BoundRecursiveUnion { + base: *recursive.base, + recursive: *recursive.recursive, + })) + } + }; + // we could always share the cte, + // no matter it's recursive or not. let share_relation = Relation::Share(Box::new(BoundShare { share_id, + // should either be a *bound* `subquery` or `recursive union` input: input_relation, })); Ok(share_relation) diff --git a/src/frontend/src/binder/relation/recursive_union.rs b/src/frontend/src/binder/relation/recursive_union.rs new file mode 100644 index 0000000000000..2181403f49dbf --- /dev/null +++ b/src/frontend/src/binder/relation/recursive_union.rs @@ -0,0 +1,35 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::binder::statement::RewriteExprsRecursive; +use crate::binder::BoundSetExpr; + +/// a *bound* recursive union representation. +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub struct BoundRecursiveUnion { + /// the *bound* base case + pub(crate) base: BoundSetExpr, + /// the *bound* recursive case + pub(crate) recursive: BoundSetExpr, +} + +impl RewriteExprsRecursive for BoundRecursiveUnion { + fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) { + // rewrite base case + self.base.rewrite_exprs_recursive(rewriter); + // rewrite recursive case + self.recursive.rewrite_exprs_recursive(rewriter); + } +} diff --git a/src/frontend/src/binder/set_expr.rs b/src/frontend/src/binder/set_expr.rs index 64c38657b3b4b..de6e3f28468f1 100644 --- a/src/frontend/src/binder/set_expr.rs +++ b/src/frontend/src/binder/set_expr.rs @@ -36,11 +36,6 @@ pub enum BoundSetExpr { left: Box, right: Box, }, - /// UNION in recursive CTE definition - RecursiveUnion { - base: Box, - recursive: Box, - }, } impl RewriteExprsRecursive for BoundSetExpr { @@ -53,10 +48,6 @@ impl RewriteExprsRecursive for BoundSetExpr { left.rewrite_exprs_recursive(rewriter); right.rewrite_exprs_recursive(rewriter); } - BoundSetExpr::RecursiveUnion { base, recursive } => { - base.rewrite_exprs_recursive(rewriter); - recursive.rewrite_exprs_recursive(rewriter); - } } } } @@ -87,7 +78,6 @@ impl BoundSetExpr { BoundSetExpr::Values(v) => v.schema(), BoundSetExpr::Query(q) => q.schema(), BoundSetExpr::SetOperation { left, .. } => left.schema(), - BoundSetExpr::RecursiveUnion { base, .. } => base.schema(), } } @@ -99,9 +89,6 @@ impl BoundSetExpr { BoundSetExpr::SetOperation { left, right, .. } => { left.is_correlated(depth) || right.is_correlated(depth) } - BoundSetExpr::RecursiveUnion { base, recursive } => { - base.is_correlated(depth) || recursive.is_correlated(depth) - } } } @@ -130,22 +117,83 @@ impl BoundSetExpr { ); correlated_indices } - BoundSetExpr::RecursiveUnion { base, recursive } => { - let mut correlated_indices = vec![]; - correlated_indices.extend( - base.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id), - ); - correlated_indices.extend( - recursive - .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id), - ); - correlated_indices - } } } } impl Binder { + /// note: align_schema only works when the `left` and `right` + /// are both select expression(s). + pub(crate) fn align_schema( + mut left: &mut BoundSetExpr, + mut right: &mut BoundSetExpr, + op: SetOperator, + ) -> Result<()> { + if left.schema().fields.len() != right.schema().fields.len() { + return Err(ErrorCode::InvalidInputSyntax(format!( + "each {} query must have the same number of columns", + op + )) + .into()); + } + + // handle type alignment for select union select + // e.g., select 1 UNION ALL select NULL + if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) = + (&mut left, &mut right) + { + for (i, (l, r)) in l_select + .select_items + .iter_mut() + .zip_eq_fast(r_select.select_items.iter_mut()) + .enumerate() + { + let Ok(column_type) = align_types(vec![l, r].into_iter()) else { + return Err(ErrorCode::InvalidInputSyntax(format!( + "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.", + op, + l_select.schema.fields[i].data_type, + r_select.schema.fields[i].data_type, + l_select.schema.fields[i].name, + r_select.schema.fields[i].name, + )) + .into()); + }; + l_select.schema.fields[i].data_type = column_type.clone(); + r_select.schema.fields[i].data_type = column_type; + } + } + + Self::validate(left, right, op) + } + + /// validate the schema, should be called after aligning. + pub(crate) fn validate( + left: &BoundSetExpr, + right: &BoundSetExpr, + op: SetOperator, + ) -> Result<()> { + for (a, b) in left + .schema() + .fields + .iter() + .zip_eq_fast(right.schema().fields.iter()) + { + if a.data_type != b.data_type { + return Err(ErrorCode::InvalidInputSyntax(format!( + "{} types {} and {} cannot be matched. Columns' name are {} and {}.", + op, + a.data_type.prost_type_name().as_str_name(), + b.data_type.prost_type_name().as_str_name(), + a.name, + b.name, + )) + .into()); + } + } + Ok(()) + } + pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result { match set_expr { SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))), @@ -157,7 +205,7 @@ impl Binder { left, right, } => { - match op { + match op.clone() { SetOperator::Union | SetOperator::Intersect | SetOperator::Except => { let mut left = self.bind_set_expr(*left)?; // Reset context for right side, but keep `cte_to_relation`. @@ -175,51 +223,7 @@ impl Binder { .into()); } - // Handle type alignment for select union select - // E.g. Select 1 UNION ALL Select NULL - if let (BoundSetExpr::Select(l_select), BoundSetExpr::Select(r_select)) = - (&mut left, &mut right) - { - for (i, (l, r)) in l_select - .select_items - .iter_mut() - .zip_eq_fast(r_select.select_items.iter_mut()) - .enumerate() - { - let Ok(column_type) = align_types(vec![l, r].into_iter()) else { - return Err(ErrorCode::InvalidInputSyntax(format!( - "{} types {} and {} cannot be matched. Columns' name are `{}` and `{}`.", - op, - l_select.schema.fields[i].data_type, - r_select.schema.fields[i].data_type, - l_select.schema.fields[i].name, - r_select.schema.fields[i].name, - )) - .into()); - }; - l_select.schema.fields[i].data_type = column_type.clone(); - r_select.schema.fields[i].data_type = column_type; - } - } - - for (a, b) in left - .schema() - .fields - .iter() - .zip_eq_fast(right.schema().fields.iter()) - { - if a.data_type != b.data_type { - return Err(ErrorCode::InvalidInputSyntax(format!( - "{} types {} and {} cannot be matched. Columns' name are {} and {}.", - op, - a.data_type.prost_type_name().as_str_name(), - b.data_type.prost_type_name().as_str_name(), - a.name, - b.name, - )) - .into()); - } - } + Self::align_schema(&mut left, &mut right, op.clone())?; if all { match op { diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index e6b2c0d382b90..893ae425b8513 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -483,10 +483,6 @@ impl ExprImpl { self.visit_bound_set_expr(left); self.visit_bound_set_expr(right); } - BoundSetExpr::RecursiveUnion { base, recursive } => { - self.visit_bound_set_expr(base); - self.visit_bound_set_expr(recursive); - } }; } } @@ -528,10 +524,6 @@ impl ExprImpl { self.visit_bound_set_expr(left); self.visit_bound_set_expr(right); } - BoundSetExpr::RecursiveUnion { base, recursive } => { - self.visit_bound_set_expr(base); - self.visit_bound_set_expr(recursive); - } } } } @@ -601,10 +593,6 @@ impl ExprImpl { self.visit_bound_set_expr(&mut *left); self.visit_bound_set_expr(&mut *right); } - BoundSetExpr::RecursiveUnion { base, recursive } => { - self.visit_bound_set_expr(&mut *base); - self.visit_bound_set_expr(&mut *recursive); - } } } } diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 4d83bee2fba2e..5393341f13f4f 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -58,6 +58,11 @@ impl Planner { Relation::BackCteRef(..) => { bail_not_implemented!(issue = 15135, "recursive CTE is not supported") } + // todo: ensure this will always be wrapped in a `Relation::Share` + // so that it will not be explicitly planned here + Relation::RecursiveUnion(..) => { + bail_not_implemented!(issue = 15135, "recursive CTE is not supported") + } } } diff --git a/src/frontend/src/planner/set_expr.rs b/src/frontend/src/planner/set_expr.rs index eeb789e9d9ded..e2ff43a2c211b 100644 --- a/src/frontend/src/planner/set_expr.rs +++ b/src/frontend/src/planner/set_expr.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use risingwave_common::bail_not_implemented; use risingwave_common::util::sort_util::ColumnOrder; use crate::binder::BoundSetExpr; @@ -38,9 +37,6 @@ impl Planner { left, right, } => self.plan_set_operation(op, all, *left, *right), - BoundSetExpr::RecursiveUnion { .. } => { - bail_not_implemented!(issue = 15135, "recursive CTE is not supported") - } } } }