diff --git a/e2e_test/batch/basic/union.slt.part b/e2e_test/batch/basic/union.slt.part index 7dad5cf0ffe69..1db2f2ef4a502 100644 --- a/e2e_test/batch/basic/union.slt.part +++ b/e2e_test/batch/basic/union.slt.part @@ -2,33 +2,33 @@ statement ok SET RW_IMPLICIT_FLUSH TO true; statement ok -create table t1 (v1 int, v2 bigint); +create table t1 (v1 int, v2 bigint, v4 int); statement ok -create table t2 (v1 int, v3 int); +create table t2 (v1 int, v3 int, v4 int); statement ok -insert into t1 values(1, 2); +insert into t1 values(1, 2, 3); statement ok -insert into t2 values(1, 2); +insert into t2 values(1, 2, 3); -query II +query III select * from t1 union select * from t2 ---- -1 2 +1 2 3 -query II +query III select * from t1 union all select * from t2 ---- -1 2 -1 2 +1 2 3 +1 2 3 -query II +query III select * from t1 union all select * from t2 order by v1 ---- -1 2 -1 2 +1 2 3 +1 2 3 statement error select * from t1 union all select * from t2 order by v1 + 1 @@ -69,9 +69,55 @@ NULL statement error select null union all select null select union 1 +query II +select * from t1 union all corresponding select * from t2 order by v1 +---- +1 3 +1 3 + +query II +select * from t1 union corresponding select v4, v3 as v1 from t2 order by v1 +---- +1 3 +2 3 + +query II +select * from t1 union all corresponding by (v4, v1) select * from t2 +---- +3 1 +3 1 + +query II +select * from t1 union corresponding by (v4) select * from t2 +---- +3 + +statement error +select * from t1 union corresponding by (vxx) select * from t2 +---- +db error: ERROR: Failed to run the query + +Caused by: + Invalid input syntax: Column name `vxx` in CORRESPONDING BY is not found in a side of the UNION operation. It shall be included in both sides. 
+ + +statement ok +create table txx (vxx int); + +statement error +select * from t1 union corresponding select * from txx +---- +db error: ERROR: Failed to run the query + +Caused by: + Invalid input syntax: When CORRESPONDING is specified, at least one column of the left side shall have a column name that is the column name of some column of the right side in a UNION operation. Left side query column list: ("v1", "v2", "v4"). Right side query column list: ("vxx"). + statement ok drop table t1; statement ok drop table t2; + +statement ok +drop table txx; diff --git a/e2e_test/streaming/union.slt b/e2e_test/streaming/union.slt index 2e99097a84bc2..8bb5245bee722 100644 --- a/e2e_test/streaming/union.slt +++ b/e2e_test/streaming/union.slt @@ -2,10 +2,10 @@ statement ok SET RW_IMPLICIT_FLUSH TO true; statement ok -create table t1 (v1 int, v2 int); +create table t1 (v1 int, v2 int, v4 int); statement ok -create table t2 (v1 int, v3 int); +create table t2 (v1 int, v3 int, v4 int); statement ok create materialized view v as select * from t1 union all select * from t2; @@ -13,6 +13,12 @@ create materialized view v as select * from t1 union all select * from t2; statement ok create materialized view v2 as select * from t1 union select * from t2; +statement ok +create materialized view v3 as select * from t1 union all corresponding select * from t2; + +statement ok +create materialized view v4 as select * from t1 union corresponding by (v4, v1) select * from t2; + query II select * from v; ---- @@ -22,57 +28,96 @@ select * from v2; ---- statement ok -insert into t1 values(1, 2); +insert into t1 values(1, 2, 3); -query II +query III select * from v; ---- -1 2 +1 2 3 -query II +query III select * from v2; ---- -1 2 +1 2 3 + +query II +select * from v3; +---- +1 3 + +query II +select * from v4; +---- +3 1 statement ok -insert into t2 values(1, 2); +insert into t2 values(1, 2, 3); -query II +query III select * from v; ---- -1 2 -1 2 +1 2 3 +1 2 3 -query II +query III select * 
from v2; ---- -1 2 +1 2 3 + +query II +select * from v3; +---- +1 3 +1 3 + +query II +select * from v4; +---- +3 1 statement ok delete from t1 where v1 = 1; -query II +query III select * from v; ---- -1 2 +1 2 3 -query II +query III select * from v2; ---- -1 2 +1 2 3 + +query II +select * from v3; +---- +1 3 + +query II +select * from v4; +---- +3 1 statement ok delete from t2 where v1 = 1; -query II +query III select * from v; ---- -query II +query III select * from v2; ---- +query II +select * from v3; +---- + +query II +select * from v4; +---- + statement ok drop materialized view v; @@ -80,6 +125,33 @@ drop materialized view v; statement ok drop materialized view v2; +statement ok +drop materialized view v3; + +statement ok +drop materialized view v4; + +statement error +create materialized view v5 as select * from t1 union corresponding by (vxx, v1) select * from t2 +---- +db error: ERROR: Failed to run the query + +Caused by: + Invalid input syntax: Column name `vxx` in CORRESPONDING BY is not found in a side of the UNION operation. It shall be included in both sides. + + +statement ok +create table txx (vxx int); + +statement error +create materialized view v5 as select * from t1 union corresponding select * from txx +---- +db error: ERROR: Failed to run the query + +Caused by: + Invalid input syntax: When CORRESPONDING is specified, at least one column of the left side shall have a column name that is the column name of some column of the right side in a UNION operation. Left side query column list: ("v1", "v2", "v4"). Right side query column list: ("vxx"). 
+
+
 statement ok
 drop table t1;
 
diff --git a/src/common/src/catalog/schema.rs b/src/common/src/catalog/schema.rs
index 113d9f804b3d4..9eccecfc2fc00 100644
--- a/src/common/src/catalog/schema.rs
+++ b/src/common/src/catalog/schema.rs
@@ -197,6 +197,15 @@ impl Schema {
             true
         }
     }
+
+    /// Returns every field name wrapped in double quotes, joined by `", "` (e.g. `"v1", "v2"`).
+    pub fn formatted_col_names(&self) -> String {
+        self.fields
+            .iter()
+            .map(|f| format!("\"{}\"", &f.name))
+            .collect::<Vec<_>>()
+            .join(", ")
+    }
 }
 
 impl Field {
diff --git a/src/frontend/planner_test/tests/testdata/input/union.yaml b/src/frontend/planner_test/tests/testdata/input/union.yaml
index 8775d4f9d36f2..93e6b00089066 100644
--- a/src/frontend/planner_test/tests/testdata/input/union.yaml
+++ b/src/frontend/planner_test/tests/testdata/input/union.yaml
@@ -95,3 +95,31 @@
     select * from t1 union all select * from t2 union all select * from t3 union all select * from t4 union all select * from t5;
   expected_outputs:
   - stream_dist_plan
+
+- name: test corresponding union
+  sql: |
+    create table t1 (a int, b numeric, c bigint);
+    create table t2 (a int, b numeric, y bigint);
+    create table t3 (x int, b numeric, c bigint);
+    select * from t1 union corresponding select * from t2 union all corresponding by (b) select * from t3;
+  expected_outputs:
+  - batch_plan
+  - stream_plan
+  - stream_dist_plan
+
+- name: test corresponding union error - corresponding list
+  sql: |
+    create table t1 (a int, b numeric, c bigint);
+    create table t2 (a int, b numeric, y bigint);
+    create table t3 (x int, b numeric, c bigint);
+    select * from t1 union corresponding select * from t2 union all corresponding by (c) select * from t3;
+  expected_outputs:
+  - binder_error
+
+- name: test corresponding union error - duplicate names
+  sql: |
+    create table t1 (a int, b numeric, c bigint);
+    create table t2 (a int, b numeric, y bigint);
+    select a, b as a from t1 union corresponding select * from t2;
+  expected_outputs:
+  - binder_error
\ No newline at end of file
diff --git 
a/src/frontend/planner_test/tests/testdata/output/union.yaml b/src/frontend/planner_test/tests/testdata/output/union.yaml index ffd31dec73da5..57c65de3cd1f2 100644 --- a/src/frontend/planner_test/tests/testdata/output/union.yaml +++ b/src/frontend/planner_test/tests/testdata/output/union.yaml @@ -639,3 +639,106 @@ ├── distribution key: [ 0, 1, 3 ] └── read pk prefix len hint: 3 +- name: test corresponding union + sql: | + create table t1 (a int, b numeric, c bigint); + create table t2 (a int, b numeric, y bigint); + create table t3 (x int, b numeric, c bigint); + select * from t1 union corresponding select * from t2 union all corresponding by (b) select * from t3; + batch_plan: |- + BatchUnion { all: true } + ├─BatchExchange { order: [], dist: Single } + │ └─BatchProject { exprs: [t1.b] } + │ └─BatchHashAgg { group_key: [t1.a, t1.b], aggs: [] } + │ └─BatchExchange { order: [], dist: HashShard(t1.a, t1.b) } + │ └─BatchUnion { all: true } + │ ├─BatchExchange { order: [], dist: Single } + │ │ └─BatchScan { table: t1, columns: [t1.a, t1.b], distribution: SomeShard } + │ └─BatchExchange { order: [], dist: Single } + │ └─BatchScan { table: t2, columns: [t2.a, t2.b], distribution: SomeShard } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: t3, columns: [t3.b], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [b, t1.a(hidden), t1.b(hidden), null:Serial(hidden), $src(hidden)], stream_key: [t1.a, t1.b, null:Serial, $src], pk_columns: [t1.a, t1.b, null:Serial, $src], pk_conflict: NoCheck } + └─StreamUnion { all: true } + ├─StreamExchange { dist: HashShard(t1.a, t1.b, null:Serial, 0:Int32) } + │ └─StreamProject { exprs: [t1.b, t1.a, t1.b, null:Serial, 0:Int32], noop_update_hint: true } + │ └─StreamHashAgg { group_key: [t1.a, t1.b], aggs: [count] } + │ └─StreamExchange { dist: HashShard(t1.a, t1.b) } + │ └─StreamUnion { all: true } + │ ├─StreamExchange { dist: HashShard(t1._row_id, 0:Int32) } + │ │ └─StreamProject { exprs: [t1.a, 
t1.b, t1._row_id, 0:Int32] } + │ │ └─StreamTableScan { table: t1, columns: [t1.a, t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + │ └─StreamExchange { dist: HashShard(t2._row_id, 1:Int32) } + │ └─StreamProject { exprs: [t2.a, t2.b, t2._row_id, 1:Int32] } + │ └─StreamTableScan { table: t2, columns: [t2.a, t2.b, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } + └─StreamExchange { dist: HashShard(null:Int32, null:Decimal, t3._row_id, 1:Int32) } + └─StreamProject { exprs: [t3.b, null:Int32, null:Decimal, t3._row_id, 1:Int32] } + └─StreamTableScan { table: t3, columns: [t3.b, t3._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t3._row_id], pk: [_row_id], dist: UpstreamHashShard(t3._row_id) } + stream_dist_plan: |+ + Fragment 0 + StreamMaterialize { columns: [b, t1.a(hidden), t1.b(hidden), null:Serial(hidden), $src(hidden)], stream_key: [t1.a, t1.b, null:Serial, $src], pk_columns: [t1.a, t1.b, null:Serial, $src], pk_conflict: NoCheck } + ├── tables: [ Materialize: 4294967294 ] + └── StreamUnion { all: true } + ├── StreamExchange Hash([1, 2, 3, 4]) from 1 + └── StreamExchange Hash([1, 2, 3, 4]) from 5 + + Fragment 1 + StreamProject { exprs: [t1.b, t1.a, t1.b, null:Serial, 0:Int32], noop_update_hint: true } + └── StreamHashAgg { group_key: [t1.a, t1.b], aggs: [count] } { tables: [ HashAggState: 0 ] } + └── StreamExchange Hash([0, 1]) from 2 + + Fragment 2 + StreamUnion { all: true } + ├── StreamExchange Hash([2, 3]) from 3 + └── StreamExchange Hash([2, 3]) from 4 + + Fragment 3 + StreamProject { exprs: [t1.a, t1.b, t1._row_id, 0:Int32] } + └── StreamTableScan { table: t1, columns: [t1.a, t1.b, t1._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t1._row_id], pk: [_row_id], dist: UpstreamHashShard(t1._row_id) } + ├── tables: [ StreamScan: 1 ] + ├── Upstream + └── BatchPlanNode + 
+ Fragment 4 + StreamProject { exprs: [t2.a, t2.b, t2._row_id, 1:Int32] } + └── StreamTableScan { table: t2, columns: [t2.a, t2.b, t2._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t2._row_id], pk: [_row_id], dist: UpstreamHashShard(t2._row_id) } + ├── tables: [ StreamScan: 2 ] + ├── Upstream + └── BatchPlanNode + + Fragment 5 + StreamProject { exprs: [t3.b, null:Int32, null:Decimal, t3._row_id, 1:Int32] } + └── StreamTableScan { table: t3, columns: [t3.b, t3._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t3._row_id], pk: [_row_id], dist: UpstreamHashShard(t3._row_id) } + ├── tables: [ StreamScan: 3 ] + ├── Upstream + └── BatchPlanNode + + Table 0 { columns: [ t1_a, t1_b, count ], primary key: [ $0 ASC, $1 ASC ], value indices: [ 2 ], distribution key: [ 0, 1 ], read pk prefix len hint: 2 } + + Table 1 { columns: [ vnode, _row_id, backfill_finished, row_count ], primary key: [ $0 ASC ], value indices: [ 1, 2, 3 ], distribution key: [ 0 ], read pk prefix len hint: 1, vnode column idx: 0 } + + Table 2 { columns: [ vnode, _row_id, backfill_finished, row_count ], primary key: [ $0 ASC ], value indices: [ 1, 2, 3 ], distribution key: [ 0 ], read pk prefix len hint: 1, vnode column idx: 0 } + + Table 3 { columns: [ vnode, _row_id, backfill_finished, row_count ], primary key: [ $0 ASC ], value indices: [ 1, 2, 3 ], distribution key: [ 0 ], read pk prefix len hint: 1, vnode column idx: 0 } + + Table 4294967294 + ├── columns: [ b, t1.a, t1.b, null:Serial, $src ] + ├── primary key: [ $1 ASC, $2 ASC, $3 ASC, $4 ASC ] + ├── value indices: [ 0, 1, 2, 3, 4 ] + ├── distribution key: [ 1, 2, 3, 4 ] + └── read pk prefix len hint: 4 + +- name: test corresponding union error - corresponding list + sql: | + create table t1 (a int, b numeric, c bigint); + create table t2 (a int, b numeric, y bigint); + create table t3 (x int, b numeric, c bigint); + select * from t1 union corresponding select * from t2 union all corresponding by (c) select * from t3; + 
binder_error: 'Invalid input syntax: Column name `c` in CORRESPONDING BY is not found in a side of the UNION operation. It shall be included in both sides.' +- name: test corresponding union error - duplicate names + sql: | + create table t1 (a int, b numeric, c bigint); + create table t2 (a int, b numeric, y bigint); + select a, b as a from t1 union corresponding select * from t2; + binder_error: 'Invalid input syntax: Duplicated column name `a` in a column list of the query in a UNION operation. Column list of the query: ("a", "a").' diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index 459e1b7921e94..7ad2091e6fb87 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -46,7 +46,7 @@ pub struct BoundQuery { impl BoundQuery { /// The schema returned by this [`BoundQuery`]. - pub fn schema(&self) -> &Schema { + pub fn schema(&self) -> std::borrow::Cow<'_, Schema> { self.body.schema() } @@ -295,6 +295,7 @@ impl Binder { SetExpr::SetOperation { op: SetOperator::Union, all, + corresponding, left, right, }, @@ -307,6 +308,12 @@ impl Binder { .into()); }; + // validated in `validate_rcte` + assert!( + !corresponding.is_corresponding(), + "`CORRESPONDING` is not supported in recursive CTE" + ); + let entry = self .context .cte_to_relation @@ -396,6 +403,7 @@ impl Binder { let SetExpr::SetOperation { op: SetOperator::Union, all, + corresponding, left, right, } = body @@ -412,10 +420,18 @@ impl Binder { .into()); } + if corresponding.is_corresponding() { + return Err(ErrorCode::BindError( + "`CORRESPONDING` is not supported in recursive CTE".to_string(), + ) + .into()); + } + Ok(( SetExpr::SetOperation { op: SetOperator::Union, all, + corresponding, left, right, }, @@ -468,7 +484,7 @@ impl Binder { self.context.cte_to_relation = new_context.cte_to_relation; Self::align_schema(&mut base, &mut recursive, SetOperator::Union)?; - let schema = base.schema().clone(); + let schema = base.schema().into_owned(); let 
recursive_union = RecursiveUnion { all, diff --git a/src/frontend/src/binder/set_expr.rs b/src/frontend/src/binder/set_expr.rs index be4943d59defd..68af5845bf7a4 100644 --- a/src/frontend/src/binder/set_expr.rs +++ b/src/frontend/src/binder/set_expr.rs @@ -12,12 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Cow; +use std::collections::HashMap; + use risingwave_common::bail_not_implemented; use risingwave_common::catalog::Schema; +use risingwave_common::util::column_index_mapping::ColIndexMapping; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_sqlparser::ast::{SetExpr, SetOperator}; +use risingwave_sqlparser::ast::{Corresponding, SetExpr, SetOperator}; use super::statement::RewriteExprsRecursive; +use super::UNNAMED_COLUMN; use crate::binder::{BindContext, Binder, BoundQuery, BoundSelect, BoundValues}; use crate::error::{ErrorCode, Result}; use crate::expr::{align_types, CorrelatedId, Depth}; @@ -33,6 +38,8 @@ pub enum BoundSetExpr { SetOperation { op: BoundSetOperation, all: bool, + // Corresponding columns of the left and right side. + corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>, left: Box, right: Box, }, @@ -72,12 +79,29 @@ impl From for BoundSetOperation { impl BoundSetExpr { /// The schema returned by this [`BoundSetExpr`]. - pub fn schema(&self) -> &Schema { + pub fn schema(&self) -> Cow<'_, Schema> { match self { - BoundSetExpr::Select(s) => s.schema(), - BoundSetExpr::Values(v) => v.schema(), + BoundSetExpr::Select(s) => Cow::Borrowed(s.schema()), + BoundSetExpr::Values(v) => Cow::Borrowed(v.schema()), BoundSetExpr::Query(q) => q.schema(), - BoundSetExpr::SetOperation { left, .. } => left.schema(), + BoundSetExpr::SetOperation { + left, + corresponding_col_indices, + .. 
+            } => {
+                if let Some((mapping_l, _)) = corresponding_col_indices {
+                    let mut schema = vec![None; mapping_l.target_size()];
+                    for (src, tar) in mapping_l.mapping_pairs() {
+                        assert_eq!(schema[tar], None);
+                        schema[tar] = Some(left.schema().fields[src].clone());
+                    }
+                    Cow::Owned(Schema::new(
+                        schema.into_iter().map(|x| x.unwrap()).collect(),
+                    ))
+                } else {
+                    left.schema()
+                }
+            }
         }
     }
 
@@ -194,6 +218,92 @@ impl Binder {
         Ok(())
     }
 
+    /// Validates `CORRESPONDING [BY (..)]` (no duplicate names per side, every used name on both sides)
+    /// and returns one `ColIndexMapping` per side, built from the matched column indices.
+    fn corresponding(
+        &self,
+        left: &BoundSetExpr,
+        right: &BoundSetExpr,
+        corresponding: Corresponding,
+        op: &SetOperator,
+    ) -> Result<(ColIndexMapping, ColIndexMapping)> {
+        let check_duplicate_name = |set_expr: &BoundSetExpr| {
+            let mut name2idx = HashMap::new();
+            for (idx, field) in set_expr.schema().fields.iter().enumerate() {
+                if name2idx.insert(field.name.clone(), idx).is_some() {
+                    return Err(ErrorCode::InvalidInputSyntax(format!(
+                        "Duplicated column name `{}` in a column list of the query in a {} operation. Column list of the query: ({}).",
+                        field.name, op, set_expr.schema().formatted_col_names(),
+                    )));
+                }
+            }
+            Ok(name2idx)
+        };
+
+        // Within the column names of each side, the same column name shall not
+        // be specified more than once.
+        let name2idx_l = check_duplicate_name(left)?;
+        let name2idx_r = check_duplicate_name(right)?;
+
+        let mut corresponding_col_idx_l = vec![];
+        let mut corresponding_col_idx_r = vec![];
+
+        if let Some(column_list) = corresponding.column_list() {
+            // The output columns follow the order of the `CORRESPONDING BY` column list.
+            for column in column_list {
+                let col_name = column.real_value();
+                if let Some(idx_l) = name2idx_l.get(&col_name)
+                    && let Some(idx_r) = name2idx_r.get(&col_name)
+                {
+                    corresponding_col_idx_l.push(*idx_l);
+                    corresponding_col_idx_r.push(*idx_r);
+                } else {
+                    return Err(ErrorCode::InvalidInputSyntax(format!(
+                        "Column name `{}` in CORRESPONDING BY is not found in a side of the {} operation. \
+                        It shall be included in both sides.",
+                        col_name,
+                        op,
+                    )).into());
+                }
+            }
+        } else {
+            // The select list of the corresponding set operation should be
+            // in the order that appears in the select list of the left side.
+            for field in &left.schema().fields {
+                let col_name = &field.name;
+                if col_name != UNNAMED_COLUMN
+                    && let Some(idx_l) = name2idx_l.get(col_name)
+                    && let Some(idx_r) = name2idx_r.get(col_name)
+                {
+                    corresponding_col_idx_l.push(*idx_l);
+                    corresponding_col_idx_r.push(*idx_r);
+                }
+            }
+
+            if corresponding_col_idx_l.is_empty() {
+                return Err(ErrorCode::InvalidInputSyntax(
+                    format!(
+                        "When CORRESPONDING is specified, at least one column of the left side \
+                        shall have a column name that is the column name of some column of the right side in a {} operation. \
+                        Left side query column list: ({}). \
+                        Right side query column list: ({}).",
+                        op,
+                        left.schema().formatted_col_names(),
+                        right.schema().formatted_col_names(),
+                    )
+                )
+                .into());
+            }
+        }
+
+        let corresponding_mapping_l =
+            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_l, left.schema().len());
+        let corresponding_mapping_r =
+            ColIndexMapping::with_remaining_columns(&corresponding_col_idx_r, right.schema().len());
+
+        Ok((corresponding_mapping_l, corresponding_mapping_r))
+    }
+
     pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result<BoundSetExpr> {
         match set_expr {
             SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))),
@@ -202,6 +312,7 @@ impl Binder {
             SetExpr::SetOperation {
                 op,
                 all,
+                corresponding,
                 left,
                 right,
             } => {
@@ -215,15 +326,19 @@ impl Binder {
                     .clone_from(&new_context.cte_to_relation);
                 let mut right = self.bind_set_expr(*right)?;
 
-                if left.schema().fields.len() != right.schema().fields.len() {
-                    return Err(ErrorCode::InvalidInputSyntax(format!(
-                        "each {} query must have the same number of columns",
-                        op
-                    ))
-                    .into());
-                }
-
-                Self::align_schema(&mut left, &mut right, op.clone())?;
+                let corresponding_col_indices = if corresponding.is_corresponding() {
+                    Some(Self::corresponding(
+                        self,
+                        &left,
+                        &right,
+                        corresponding,
+                        &op,
+                    )?)
+ // TODO: Align schema + } else { + Self::align_schema(&mut left, &mut right, op.clone())?; + None + }; if all { match op { @@ -243,6 +358,7 @@ impl Binder { Ok(BoundSetExpr::SetOperation { op: op.into(), all, + corresponding_col_indices, left: Box::new(left), right: Box::new(right), }) diff --git a/src/frontend/src/planner/set_expr.rs b/src/frontend/src/planner/set_expr.rs index e2ff43a2c211b..8c8405b039b91 100644 --- a/src/frontend/src/planner/set_expr.rs +++ b/src/frontend/src/planner/set_expr.rs @@ -34,9 +34,10 @@ impl Planner { BoundSetExpr::SetOperation { op, all, + corresponding_col_indices, left, right, - } => self.plan_set_operation(op, all, *left, *right), + } => self.plan_set_operation(op, all, corresponding_col_indices, *left, *right), } } } diff --git a/src/frontend/src/planner/set_operation.rs b/src/frontend/src/planner/set_operation.rs index 1050c28bd11fb..8abdac8a7e2b1 100644 --- a/src/frontend/src/planner/set_operation.rs +++ b/src/frontend/src/planner/set_operation.rs @@ -1,3 +1,5 @@ +use risingwave_common::util::column_index_mapping::ColIndexMapping; + // Copyright 2024 RisingWave Labs // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +15,7 @@ // limitations under the License. 
use crate::binder::{BoundSetExpr, BoundSetOperation}; use crate::error::Result; -use crate::optimizer::plan_node::{LogicalExcept, LogicalIntersect, LogicalUnion}; +use crate::optimizer::plan_node::{LogicalExcept, LogicalIntersect, LogicalProject, LogicalUnion}; use crate::planner::Planner; use crate::PlanRef; @@ -22,25 +24,27 @@ impl Planner { &mut self, op: BoundSetOperation, all: bool, + corresponding_col_indices: Option<(ColIndexMapping, ColIndexMapping)>, left: BoundSetExpr, right: BoundSetExpr, ) -> Result { + let left = self.plan_set_expr(left, vec![], &[])?; + let right = self.plan_set_expr(right, vec![], &[])?; + + // Map the corresponding columns + let (left, right) = if let Some((mapping_l, mapping_r)) = corresponding_col_indices { + ( + LogicalProject::with_mapping(left, mapping_l).into(), + LogicalProject::with_mapping(right, mapping_r).into(), + ) + } else { + (left, right) + }; + match op { - BoundSetOperation::Union => { - let left = self.plan_set_expr(left, vec![], &[])?; - let right = self.plan_set_expr(right, vec![], &[])?; - Ok(LogicalUnion::create(all, vec![left, right])) - } - BoundSetOperation::Intersect => { - let left = self.plan_set_expr(left, vec![], &[])?; - let right = self.plan_set_expr(right, vec![], &[])?; - Ok(LogicalIntersect::create(all, vec![left, right])) - } - BoundSetOperation::Except => { - let left = self.plan_set_expr(left, vec![], &[])?; - let right = self.plan_set_expr(right, vec![], &[])?; - Ok(LogicalExcept::create(all, vec![left, right])) - } + BoundSetOperation::Union => Ok(LogicalUnion::create(all, vec![left, right])), + BoundSetOperation::Intersect => Ok(LogicalIntersect::create(all, vec![left, right])), + BoundSetOperation::Except => Ok(LogicalExcept::create(all, vec![left, right])), } } } diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index dad2177dde0e4..d5cca61b6a186 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -46,9 +46,9 @@ pub use 
self::legacy_source::{
 };
 pub use self::operator::{BinaryOperator, QualifiedOperator, UnaryOperator};
 pub use self::query::{
-    Cte, CteInner, Distinct, Fetch, Join, JoinConstraint, JoinOperator, LateralView, OrderByExpr,
-    Query, Select, SelectItem, SetExpr, SetOperator, TableAlias, TableFactor, TableWithJoins, Top,
-    Values, With,
+    Corresponding, Cte, CteInner, Distinct, Fetch, Join, JoinConstraint, JoinOperator, LateralView,
+    OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, TableAlias, TableFactor,
+    TableWithJoins, Top, Values, With,
 };
 pub use self::statement::*;
 pub use self::value::{
diff --git a/src/sqlparser/src/ast/query.rs b/src/sqlparser/src/ast/query.rs
index 83e84907a1091..b16a3075f90d9 100644
--- a/src/sqlparser/src/ast/query.rs
+++ b/src/sqlparser/src/ast/query.rs
@@ -97,6 +97,7 @@ pub enum SetExpr {
     SetOperation {
         op: SetOperator,
         all: bool,
+        corresponding: Corresponding,
         left: Box<SetExpr>,
         right: Box<SetExpr>,
     },
@@ -114,9 +115,10 @@ impl fmt::Display for SetExpr {
                 right,
                 op,
                 all,
+                corresponding,
             } => {
                 let all_str = if *all { " ALL" } else { "" };
-                write!(f, "{} {}{} {}", left, op, all_str, right)
+                write!(f, "{} {}{}{} {}", left, op, all_str, corresponding, right)
             }
         }
     }
@@ -140,6 +142,54 @@ impl fmt::Display for SetOperator {
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+/// `CORRESPONDING [ BY ( <column list> ) ]`
+pub struct Corresponding {
+    pub corresponding: bool,
+    pub column_list: Option<Vec<Ident>>,
+}
+
+impl Corresponding {
+    /// Creates a spec with `CORRESPONDING` set, optionally restricted by a `BY (..)` column list.
+    pub fn with_column_list(column_list: Option<Vec<Ident>>) -> Self {
+        Self {
+            corresponding: true,
+            column_list,
+        }
+    }
+
+    /// Creates a spec for a set operation written without `CORRESPONDING`.
+    pub fn none() -> Self {
+        Self {
+            corresponding: false,
+            column_list: None,
+        }
+    }
+
+    /// Returns whether `CORRESPONDING` was specified.
+    pub fn is_corresponding(&self) -> bool {
+        self.corresponding
+    }
+
+    /// Returns the `BY` column list, if one was given.
+    pub fn column_list(&self) -> Option<&[Ident]> {
+        self.column_list.as_deref()
+    }
+}
+
+impl fmt::Display for Corresponding {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        if self.corresponding {
+            write!(f, " CORRESPONDING")?;
+            if let Some(column_list) = &self.column_list {
+                write!(f, " BY ({})", display_comma_separated(column_list))?;
+            }
+        }
+        Ok(())
+    }
+}
+
 /// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may
 /// appear either as the only body item of an `SQLQuery`, or as an operand
 /// to a set operation like `UNION`.
diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs
index d683bf39f5d45..996fd9ebe8490 100644
--- a/src/sqlparser/src/parser.rs
+++ b/src/sqlparser/src/parser.rs
@@ -4114,10 +4114,15 @@ impl Parser<'_> {
                 break;
             }
             self.next_token(); // skip past the set operator
+
+            let all = self.parse_keyword(Keyword::ALL);
+            let corresponding = self.parse_corresponding()?;
+
             expr = SetExpr::SetOperation {
                 left: Box::new(expr),
                 op: op.unwrap(),
-                all: self.parse_keyword(Keyword::ALL),
+                corresponding,
+                all,
                 right: Box::new(self.parse_query_body(next_precedence)?),
             };
         }
@@ -4134,6 +4139,21 @@ impl Parser<'_> {
         }
     }
 
+    /// Parses an optional `CORRESPONDING [ BY ( <column list> ) ]` clause after a set operator.
+    fn parse_corresponding(&mut self) -> PResult<Corresponding> {
+        let corresponding = if self.parse_keyword(Keyword::CORRESPONDING) {
+            let column_list = if self.parse_keyword(Keyword::BY) {
+                Some(self.parse_parenthesized_column_list(IsOptional::Mandatory)?)
+            } else {
+                None
+            };
+            Corresponding::with_column_list(column_list)
+        } else {
+            Corresponding::none()
+        };
+        Ok(corresponding)
+    }
+
     /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`),
     /// assuming the initial `SELECT` was already consumed
     pub fn parse_select(&mut self) -> PResult