diff --git a/src/frontend/planner_test/tests/testdata/input/subquery.yaml b/src/frontend/planner_test/tests/testdata/input/subquery.yaml index 47785f0234271..5cb1dec5f8af2 100644 --- a/src/frontend/planner_test/tests/testdata/input/subquery.yaml +++ b/src/frontend/planner_test/tests/testdata/input/subquery.yaml @@ -299,5 +299,10 @@ - name: While this one is allowed. sql: | SELECT generate_series(1, (select 1)); + expected_outputs: + - batch_plan +- name: array subquery + sql: | + select Array(select 1 union select 2); expected_outputs: - batch_plan \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml b/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml index 6f16abeb813a3..8692bedf26ffb 100644 --- a/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml +++ b/src/frontend/planner_test/tests/testdata/input/subquery_expr_correlated.yaml @@ -448,3 +448,27 @@ select (select 1 from t, di where t.a = dl.c1 and t.b = di.d1 limit 1) name, (select 1 from t, di where t.a = dl.c2 and t.c = di.d2 limit 1) name2 from dl; expected_outputs: - optimized_logical_plan_for_stream +- name: correlated array subquery + sql: | + create table t1 (a int, b int); + create table t2 (c int, d int); + select Array(select c from t2 where b = d) arr from t1; + expected_outputs: + - batch_plan + - stream_plan +- name: correlated array subquery \du + sql: | + SELECT r.rolname, r.rolsuper, r.rolinherit, + r.rolcreaterole, r.rolcreatedb, r.rolcanlogin, + r.rolconnlimit, r.rolvaliduntil, + ARRAY(SELECT b.rolname + FROM pg_catalog.pg_auth_members m + JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid) + WHERE m.member = r.oid) as memberof + , r.rolreplication + , r.rolbypassrls + FROM pg_catalog.pg_roles r + WHERE r.rolname !~ '^pg_' + ORDER BY 1; + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/subquery.yaml b/src/frontend/planner_test/tests/testdata/output/subquery.yaml index d79ace454a704..fce9b37df5f20 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery.yaml @@ -875,3 +875,14 @@ └─BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all } ├─BatchValues { rows: [[]] } └─BatchValues { rows: [[1:Int32]] } +- name: array subquery + sql: | + select Array(select 1 union select 2); + batch_plan: |- + BatchNestedLoopJoin { type: LeftOuter, predicate: true, output: all } + ├─BatchValues { rows: [[]] } + └─BatchSimpleAgg { aggs: [array_agg(1:Int32)] } + └─BatchExchange { order: [], dist: Single } + └─BatchHashAgg { group_key: [1:Int32], aggs: [] } + └─BatchExchange { order: [], dist: HashShard(1:Int32) } + └─BatchValues { rows: [[1:Int32], [2:Int32]] } diff --git a/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml b/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml index f21af89d1c6db..7bebd532f349a 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery_expr_correlated.yaml @@ -971,3 +971,82 @@ └─LogicalJoin { type: Inner, on: ((t.b = di.d1) OR (t.c = di.d2)), output: all } ├─LogicalScan { table: t, columns: [t.a, t.b, t.c], predicate: IsNotNull(t.a) } └─LogicalScan { table: di, columns: [di.d1, di.d2] } +- name: correlated array subquery + sql: | + create table t1 (a int, b int); + create table t2 (c int, d int); + select Array(select c from t2 where b = d) arr from t1; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [array_agg(t2.c)] } + ├─BatchExchange { order: [], dist: HashShard(t1.b) } + │ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard } + └─BatchHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c)] } + └─BatchHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c] } + ├─BatchHashAgg { group_key: [t1.b], aggs: [] } + │ └─BatchExchange { order: [], dist: HashShard(t1.b) } + │ └─BatchScan { table: t1, columns: [t1.b], distribution: SomeShard } + └─BatchExchange { order: [], dist: HashShard(t2.d) } + └─BatchProject { exprs: [t2.d, t2.c] } + └─BatchFilter { predicate: IsNotNull(t2.d) } + └─BatchScan { table: t2, columns: [t2.c, t2.d], distribution: SomeShard } + stream_plan: |- + StreamMaterialize { columns: [arr, t1._row_id(hidden), t1.b(hidden), t1.b#1(hidden)], stream_key: [t1._row_id, t1.b], pk_columns: [t1._row_id, t1.b], pk_conflict: NoCheck } + └─StreamExchange { dist: HashShard(t1._row_id, t1.b) } + └─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t1.b, output: [array_agg(t2.c), t1._row_id, t1.b, t1.b] } + ├─StreamExchange { dist: HashShard(t1.b) } + │ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamProject { exprs: [t1.b, array_agg(t2.c)] } + └─StreamHashAgg { group_key: [t1.b], aggs: [array_agg(t2.c), count] } + └─StreamHashJoin { type: LeftOuter, predicate: t1.b IS NOT DISTINCT FROM t2.d, output: [t1.b, t2.c, t2._row_id] } + ├─StreamProject { exprs: [t1.b] } + │ └─StreamHashAgg { group_key: [t1.b], aggs: [count] } + │ └─StreamExchange { dist: HashShard(t1.b) } + │ └─StreamTableScan { table: t1, columns: [t1.b, t1._row_id], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } + └─StreamExchange { dist: HashShard(t2.d) } + └─StreamProject { exprs: [t2.d, t2.c, t2._row_id] } + └─StreamFilter { predicate: IsNotNull(t2.d) } + └─StreamTableScan { table: t2, columns: [t2.c, t2.d, t2._row_id], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } +- name: correlated array subquery \du + sql: | + SELECT r.rolname, r.rolsuper, r.rolinherit, + r.rolcreaterole, r.rolcreatedb, r.rolcanlogin, + r.rolconnlimit, r.rolvaliduntil, + ARRAY(SELECT b.rolname + FROM pg_catalog.pg_auth_members m + JOIN pg_catalog.pg_roles b ON (m.roleid = b.oid) + WHERE m.member = r.oid) as memberof + , r.rolreplication + , r.rolbypassrls + FROM pg_catalog.pg_roles r + WHERE r.rolname !~ '^pg_' + ORDER BY 1; + batch_plan: |- + BatchExchange { order: [rw_users.name ASC], dist: Single } + └─BatchProject { exprs: [rw_users.name, rw_users.is_super, true:Boolean, rw_users.create_user, rw_users.create_db, rw_users.can_login, -1:Int32, null:Timestamptz, array_agg(rw_users.name), true:Boolean, true:Boolean] } + └─BatchSort { order: [rw_users.name ASC] } + └─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: all } + ├─BatchExchange { order: [], dist: HashShard(rw_users.id) } + │ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) } + │ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name, rw_users.is_super, rw_users.create_db, rw_users.create_user, rw_users.can_login], distribution: Single } + └─BatchHashAgg { group_key: [rw_users.id], aggs: [array_agg(rw_users.name)] } + └─BatchHashJoin { type: LeftOuter, predicate: rw_users.id IS NOT DISTINCT FROM rw_users.id, output: [rw_users.id, rw_users.name] } + ├─BatchHashAgg { group_key: [rw_users.id], aggs: [] } + │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) } + │ └─BatchProject { exprs: [rw_users.id] } + │ └─BatchFilter { predicate: Not(RegexpEq(rw_users.name, '^pg_':Varchar)) } + │ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single } + └─BatchExchange { order: [], dist: HashShard(rw_users.id) } + └─BatchHashJoin { type: Inner, predicate: null:Int32 = rw_users.id, output: [rw_users.id, rw_users.name] } + ├─BatchExchange { order: [], dist: HashShard(null:Int32) } + │ └─BatchProject { exprs: [rw_users.id, null:Int32] } + │ └─BatchNestedLoopJoin { type: Inner, predicate: true, output: all } + │ ├─BatchExchange { order: [], dist: Single } + │ │ └─BatchHashAgg { group_key: [rw_users.id], aggs: [] } + │ │ └─BatchExchange { order: [], dist: HashShard(rw_users.id) } + │ │ └─BatchProject { exprs: [rw_users.id] } + │ │ └─BatchFilter { predicate: (null:Int32 = rw_users.id) AND Not(RegexpEq(rw_users.name, '^pg_':Varchar)) } + │ │ └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single } + │ └─BatchValues { rows: [] } + └─BatchExchange { order: [], dist: HashShard(rw_users.id) } + └─BatchScan { table: rw_users, columns: [rw_users.id, rw_users.name], distribution: Single } diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index c52c42dbc973d..93df1a8a14ea4 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -169,6 +169,7 @@ impl Binder { } => self.bind_overlay(*expr, *new_substring, *start, count), Expr::Parameter { index } => self.bind_parameter(index), Expr::Collate { expr, collation } => self.bind_collate(*expr, collation), + Expr::ArraySubquery(q) => self.bind_subquery_expr(*q, SubqueryKind::Array), _ => bail_not_implemented!(issue = 112, "unsupported expression {:?}", expr), } } diff --git a/src/frontend/src/expr/subquery.rs b/src/frontend/src/expr/subquery.rs index 84fbd7d55c979..9208775ace583 100644 --- a/src/frontend/src/expr/subquery.rs +++ b/src/frontend/src/expr/subquery.rs @@ -32,6 +32,8 @@ pub enum SubqueryKind { Some(ExprImpl, ExprType), /// Expression operator `ALL` subquery. All(ExprImpl, ExprType), + /// Expression operator `ARRAY` subquery. + Array, } /// Subquery expression. @@ -86,6 +88,11 @@ impl Expr for Subquery { assert_eq!(types.len(), 1, "Subquery with more than one column"); types[0].clone() } + SubqueryKind::Array => { + let types = self.query.data_types(); + assert_eq!(types.len(), 1, "Subquery with more than one column"); + DataType::List(types[0].clone().into()) + } _ => DataType::Boolean, } } diff --git a/src/frontend/src/planner/select.rs b/src/frontend/src/planner/select.rs index fa0e08d4f0217..40409617d1bd3 100644 --- a/src/frontend/src/planner/select.rs +++ b/src/frontend/src/planner/select.rs @@ -15,12 +15,13 @@ use std::collections::HashMap; use itertools::Itertools; -use risingwave_common::bail_not_implemented; use risingwave_common::catalog::Schema; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::sort_util::ColumnOrder; +use risingwave_common::{bail, bail_not_implemented}; +use risingwave_expr::aggregate::AggKind; use risingwave_expr::ExprError; use risingwave_pb::plan_common::JoinType; @@ -29,7 +30,7 @@ use crate::expr::{ CorrelatedId, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, InputRef, Subquery, SubqueryKind, }; -use crate::optimizer::plan_node::generic::{Agg, Project, ProjectBuilder}; +use crate::optimizer::plan_node::generic::{Agg, GenericPlanRef, Project, ProjectBuilder}; pub use crate::optimizer::plan_node::LogicalFilter; use crate::optimizer::plan_node::{ LogicalAgg, LogicalApply, LogicalDedup, LogicalOverWindow, LogicalProject, LogicalProjectSet, @@ -218,6 +219,30 @@ impl Planner { Ok(LogicalProject::create(count_star.into(), vec![ge.into()])) } + /// Helper to create an `ARRAY_AGG` operator with the given `input`. + /// It is represented by `ARRAY_AGG($0) -> input` + fn create_array_agg(&self, input: PlanRef) -> Result { + let fields = input.schema().fields(); + if fields.len() != 1 { + bail!("subquery must return only one column"); + } + let input_column_type = fields[0].data_type(); + Ok(Agg::new( + vec![PlanAggCall { + agg_kind: AggKind::ArrayAgg, + return_type: DataType::List(input.schema().fields()[0].data_type().into()), + inputs: vec![InputRef::new(0, input_column_type)], + distinct: false, + order_by: vec![], + filter: Condition::true_cond(), + direct_args: vec![], + }], + IndexSet::empty(), + input, + ) + .into()) + } + /// For `(NOT) EXISTS subquery` or `(NOT) IN subquery`, we can plan it as /// `LeftSemi/LeftAnti` [`LogicalApply`] /// For other subqueries, we plan it as `LeftOuter` [`LogicalApply`] using @@ -373,6 +398,9 @@ impl Planner { SubqueryKind::Existential => { right = self.create_exists(right)?; } + SubqueryKind::Array => { + right = self.create_array_agg(right)?; + } _ => bail_not_implemented!(issue = 1343, "{:?}", subquery.kind), }