diff --git a/src/frontend/planner_test/tests/testdata/input/nexmark_source.yaml b/src/frontend/planner_test/tests/testdata/input/nexmark_source.yaml
index 71c683c8ba14..a88a5d138a61 100644
--- a/src/frontend/planner_test/tests/testdata/input/nexmark_source.yaml
+++ b/src/frontend/planner_test/tests/testdata/input/nexmark_source.yaml
@@ -396,6 +396,32 @@
   - stream_dist_plan
   with_config_map:
     RW_FORCE_SPLIT_DISTINCT_AGG: 'true'
+- id: nexmark_q15_split_distinct_agg_and_force_two_phase
+  before:
+  - create_sources
+  sql: |
+    SELECT
+        TO_CHAR(date_time, 'yyyy-MM-dd') as day,
+        count(*) AS total_bids,
+        count(*) filter (where price < 10000) AS rank1_bids,
+        count(*) filter (where price >= 10000 and price < 1000000) AS rank2_bids,
+        count(*) filter (where price >= 1000000) AS rank3_bids,
+        count(distinct bidder) AS total_bidders,
+        count(distinct bidder) filter (where price < 10000) AS rank1_bidders,
+        count(distinct bidder) filter (where price >= 10000 and price < 1000000) AS rank2_bidders,
+        count(distinct bidder) filter (where price >= 1000000) AS rank3_bidders,
+        count(distinct auction) AS total_auctions,
+        count(distinct auction) filter (where price < 10000) AS rank1_auctions,
+        count(distinct auction) filter (where price >= 10000 and price < 1000000) AS rank2_auctions,
+        count(distinct auction) filter (where price >= 1000000) AS rank3_auctions
+    FROM bid
+    GROUP BY to_char(date_time, 'yyyy-MM-dd');
+  expected_outputs:
+  - stream_plan
+  - stream_dist_plan
+  with_config_map:
+    RW_FORCE_SPLIT_DISTINCT_AGG: 'true'
+    RW_FORCE_TWO_PHASE_AGG: 'true'
 - id: nexmark_q16
   before:
   - create_sources
diff --git a/src/frontend/planner_test/tests/testdata/output/nexmark_source.yaml b/src/frontend/planner_test/tests/testdata/output/nexmark_source.yaml
index f95405ca3d12..8a96d59de67b 100644
--- a/src/frontend/planner_test/tests/testdata/output/nexmark_source.yaml
+++ b/src/frontend/planner_test/tests/testdata/output/nexmark_source.yaml
@@ -1147,6 +1147,85 @@

   with_config_map:
     RW_FORCE_SPLIT_DISTINCT_AGG: 'true'
+- id: nexmark_q15_split_distinct_agg_and_force_two_phase
+  before:
+  - create_sources
+  sql: |
+    SELECT
+        TO_CHAR(date_time, 'yyyy-MM-dd') as day,
+        count(*) AS total_bids,
+        count(*) filter (where price < 10000) AS rank1_bids,
+        count(*) filter (where price >= 10000 and price < 1000000) AS rank2_bids,
+        count(*) filter (where price >= 1000000) AS rank3_bids,
+        count(distinct bidder) AS total_bidders,
+        count(distinct bidder) filter (where price < 10000) AS rank1_bidders,
+        count(distinct bidder) filter (where price >= 10000 and price < 1000000) AS rank2_bidders,
+        count(distinct bidder) filter (where price >= 1000000) AS rank3_bidders,
+        count(distinct auction) AS total_auctions,
+        count(distinct auction) filter (where price < 10000) AS rank1_auctions,
+        count(distinct auction) filter (where price >= 10000 and price < 1000000) AS rank2_auctions,
+        count(distinct auction) filter (where price >= 1000000) AS rank3_auctions
+    FROM bid
+    GROUP BY to_char(date_time, 'yyyy-MM-dd');
+  stream_plan: |-
+    StreamMaterialize { columns: [day, total_bids, rank1_bids, rank2_bids, rank3_bids, total_bidders, rank1_bidders, rank2_bidders, rank3_bidders, total_auctions, rank1_auctions, rank2_auctions, rank3_auctions], stream_key: [day], pk_columns: [day], pk_conflict: NoCheck }
+    └─StreamProject { exprs: [$expr1, sum0(sum0(count) filter((flag = 0:Int64))), sum0(sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64))), sum0(count(bidder) filter((flag = 1:Int64))), sum0(count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(auction) filter((flag = 2:Int64))), sum0(count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)))] }
+      └─StreamHashAgg { group_key: [$expr1], aggs: [sum0(sum0(count) filter((flag = 0:Int64))), sum0(sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64))), sum0(count(bidder) filter((flag = 1:Int64))), sum0(count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(auction) filter((flag = 2:Int64))), sum0(count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), count] }
+        └─StreamExchange { dist: HashShard($expr1) }
+          └─StreamHashAgg { group_key: [$expr1, $expr2], aggs: [sum0(count) filter((flag = 0:Int64)), sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64)), count(bidder) filter((flag = 1:Int64)), count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(auction) filter((flag = 2:Int64)), count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count] }
+            └─StreamProject { exprs: [$expr1, bidder, auction, flag, count, count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), Vnode($expr1, bidder, auction, flag) as $expr2] }
+              └─StreamHashAgg [append_only] { group_key: [$expr1, bidder, auction, flag], aggs: [count, count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32))] }
+                └─StreamExchange { dist: HashShard($expr1, bidder, auction, flag) }
+                  └─StreamExpand { column_subsets: [[$expr1], [$expr1, bidder], [$expr1, auction]] }
+                    └─StreamProject { exprs: [ToChar(date_time, 'yyyy-MM-dd':Varchar) as $expr1, price, bidder, auction, _row_id] }
+                      └─StreamRowIdGen { row_id_index: 7 }
+                        └─StreamSource { source: bid, columns: [auction, bidder, price, channel, url, date_time, extra, _row_id] }
+  stream_dist_plan: |+
+    Fragment 0
+    StreamMaterialize { columns: [day, total_bids, rank1_bids, rank2_bids, rank3_bids, total_bidders, rank1_bidders, rank2_bidders, rank3_bidders, total_auctions, rank1_auctions, rank2_auctions, rank3_auctions], stream_key: [day], pk_columns: [day], pk_conflict: NoCheck } { materialized table: 4294967294 }
+    └── StreamProject { exprs: [$expr1, sum0(sum0(count) filter((flag = 0:Int64))), sum0(sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64))), sum0(count(bidder) filter((flag = 1:Int64))), sum0(count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(auction) filter((flag = 2:Int64))), sum0(count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)))] }
+        └── StreamHashAgg { group_key: [$expr1], aggs: [sum0(sum0(count) filter((flag = 0:Int64))), sum0(sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64))), sum0(count(bidder) filter((flag = 1:Int64))), sum0(count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(auction) filter((flag = 2:Int64))), sum0(count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), count] }
+            ├── result table: 0
+            ├── state tables: []
+            ├── distinct tables: []
+            └── StreamExchange Hash([0]) from 1
+
+    Fragment 1
+    StreamHashAgg { group_key: [$expr1, $expr2], aggs: [sum0(count) filter((flag = 0:Int64)), sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64)), count(bidder) filter((flag = 1:Int64)), count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(auction) filter((flag = 2:Int64)), count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count] } { result table: 1, state tables: [], distinct tables: [] }
+    └── StreamProject { exprs: [$expr1, bidder, auction, flag, count, count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), Vnode($expr1, bidder, auction, flag) as $expr2] }
+        └── StreamHashAgg [append_only] { group_key: [$expr1, bidder, auction, flag], aggs: [count, count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32))] } { result table: 2, state tables: [], distinct tables: [] }
+            └── StreamExchange Hash([0, 2, 3, 10]) from 2
+
+    Fragment 2
+    StreamExpand { column_subsets: [[$expr1], [$expr1, bidder], [$expr1, auction]] }
+    └── StreamProject { exprs: [ToChar(date_time, 'yyyy-MM-dd':Varchar) as $expr1, price, bidder, auction, _row_id] }
+        └── StreamRowIdGen { row_id_index: 7 }
+            └── StreamSource { source: bid, columns: [auction, bidder, price, channel, url, date_time, extra, _row_id] } { source state table: 3 }
+
+    Table 0
+    ├── columns: [ $expr1, sum0(sum0(count) filter((flag = 0:Int64))), sum0(sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64))), sum0(sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64))), sum0(count(bidder) filter((flag = 1:Int64))), sum0(count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64))), sum0(count(auction) filter((flag = 2:Int64))), sum0(count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), sum0(count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64))), count ]
+    ├── primary key: [ $0 ASC ]
+    ├── value indices: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 ]
+    ├── distribution key: [ 0 ]
+    └── read pk prefix len hint: 1
+
+    Table 1
+    ├── columns: [ $expr1, $expr2, sum0(count) filter((flag = 0:Int64)), sum0(count filter((price < 10000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 10000:Int32) AND (price < 1000000:Int32))) filter((flag = 0:Int64)), sum0(count filter((price >= 1000000:Int32))) filter((flag = 0:Int64)), count(bidder) filter((flag = 1:Int64)), count(bidder) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(bidder) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 1:Int64)), count(auction) filter((flag = 2:Int64)), count(auction) filter((count filter((price < 10000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 10000:Int32) AND (price < 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count(auction) filter((count filter((price >= 1000000:Int32)) > 0:Int64) AND (flag = 2:Int64)), count ]
+    ├── primary key: [ $0 ASC, $1 ASC ]
+    ├── value indices: [ 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ]
+    ├── distribution key: []
+    ├── read pk prefix len hint: 2
+    └── vnode column idx: 1
+
+    Table 2 { columns: [ $expr1, bidder, auction, flag, count, count filter((price < 10000:Int32)), count filter((price >= 10000:Int32) AND (price < 1000000:Int32)), count filter((price >= 1000000:Int32)), count filter((price < 10000:Int32))_0, count filter((price >= 10000:Int32) AND (price < 1000000:Int32))_0, count filter((price >= 1000000:Int32))_0, count filter((price < 10000:Int32))_1, count filter((price >= 10000:Int32) AND (price < 1000000:Int32))_1, count filter((price >= 1000000:Int32))_1 ], primary key: [ $0 ASC, $1 ASC, $2 ASC, $3 ASC ], value indices: [ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 ], distribution key: [ 0, 1, 2, 3 ], read pk prefix len hint: 4 }
+
+    Table 3 { columns: [ partition_id, offset_info ], primary key: [ $0 ASC ], value indices: [ 0, 1 ], distribution key: [], read pk prefix len hint: 1 }
+
+    Table 4294967294 { columns: [ day, total_bids, rank1_bids, rank2_bids, rank3_bids, total_bidders, rank1_bidders, rank2_bidders, rank3_bidders, total_auctions, rank1_auctions, rank2_auctions, rank3_auctions ], primary key: [ $0 ASC ], value indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ], distribution key: [ 0 ], read pk prefix len hint: 1 }
+
+  with_config_map:
+    RW_FORCE_SPLIT_DISTINCT_AGG: 'true'
+    RW_FORCE_TWO_PHASE_AGG: 'true'
 - id: nexmark_q16
   before:
   - create_sources
diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs
index 63bde692af61..48f36e2c1dd3 100644
--- a/src/frontend/src/optimizer/plan_node/generic/agg.rs
+++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs
@@ -52,6 +52,7 @@ pub struct Agg<PlanRef> {
     pub group_key: IndexSet,
     pub grouping_sets: Vec<IndexSet>,
     pub input: PlanRef,
+    pub enable_two_phase: bool,
 }

 impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
@@ -89,7 +90,7 @@
     }

     fn two_phase_agg_enabled(&self) -> bool {
-        self.ctx().session_ctx().config().get_enable_two_phase_agg()
+        self.enable_two_phase
     }

     pub(crate) fn can_two_phase_agg(&self) -> bool {
@@ -136,26 +137,28 @@
     }

     pub fn new(agg_calls: Vec<PlanAggCall>, group_key: IndexSet, input: PlanRef) -> Self {
+        let enable_two_phase = input
+            .ctx()
+            .session_ctx()
+            .config()
+            .get_enable_two_phase_agg();
         Self {
             agg_calls,
             group_key,
             input,
             grouping_sets: vec![],
+            enable_two_phase,
         }
     }

-    pub fn new_with_grouping_sets(
-        agg_calls: Vec<PlanAggCall>,
-        group_key: IndexSet,
-        grouping_sets: Vec<IndexSet>,
-        input: PlanRef,
-    ) -> Self {
-        Self {
-            agg_calls,
-            group_key,
-            grouping_sets,
-            input,
-        }
+    pub fn with_grouping_sets(mut self, grouping_sets: Vec<IndexSet>) -> Self {
+        self.grouping_sets = grouping_sets;
+        self
+    }
+
+    pub fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
+        self.enable_two_phase = enable_two_phase;
+        self
     }
 }
@@ -551,12 +554,13 @@ impl<PlanRef> Agg<PlanRef> {
         .collect()
     }

-    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef) {
+    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
         (
             self.agg_calls,
             self.group_key,
             self.grouping_sets,
             self.input,
+            self.enable_two_phase,
         )
     }
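For readers skimming the call sites below: `Agg::new` now captures the session's two-phase default once at construction, and the consuming `with_*` builders refine the node afterwards. The following is a minimal standalone sketch of that pattern, not RisingWave's real types: the `String`/`Vec<usize>` fields stand in for `PlanAggCall`/`IndexSet`/`PlanRef`, and the `session_default` parameter stands in for `input.ctx().session_ctx().config().get_enable_two_phase_agg()`.

    // Stand-in for generic::Agg with the same builder shape but simplified fields.
    #[derive(Debug, Clone)]
    struct Agg {
        agg_calls: Vec<String>,          // stand-in for Vec<PlanAggCall>
        group_key: Vec<usize>,           // stand-in for IndexSet
        grouping_sets: Vec<Vec<usize>>,  // stand-in for Vec<IndexSet>
        enable_two_phase: bool,
    }

    impl Agg {
        // Mirrors the new `Agg::new`: the session default is read exactly once, here.
        fn new(agg_calls: Vec<String>, group_key: Vec<usize>, session_default: bool) -> Self {
            Self { agg_calls, group_key, grouping_sets: vec![], enable_two_phase: session_default }
        }

        // Consuming builders, as in the diff: take `mut self`, return `Self`.
        fn with_grouping_sets(mut self, grouping_sets: Vec<Vec<usize>>) -> Self {
            self.grouping_sets = grouping_sets;
            self
        }

        fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
            self.enable_two_phase = enable_two_phase;
            self
        }
    }

    fn main() {
        // The session default seeds the flag; an explicit override wins.
        let agg = Agg::new(vec!["count".into()], vec![0], true)
            .with_grouping_sets(vec![vec![0]])
            .with_enable_two_phase(false);
        assert!(!agg.enable_two_phase);
        println!("{agg:?}");
    }

The builder style also explains why `new_with_grouping_sets` can be dropped: grouping sets become just another optional refinement instead of a second constructor.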
diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs
index 4387c8f4f89f..a2099b7d33f8 100644
--- a/src/frontend/src/optimizer/plan_node/logical_agg.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs
@@ -361,13 +361,9 @@ impl LogicalAggBuilder {
         let logical_project = LogicalProject::with_core(self.input_proj_builder.build(input));

         // This LogicalAgg focuses on calculating the aggregates and grouping.
-        Agg::new_with_grouping_sets(
-            self.agg_calls,
-            self.group_key,
-            self.grouping_sets,
-            logical_project.into(),
-        )
-        .into()
+        Agg::new(self.agg_calls, self.group_key, logical_project.into())
+            .with_grouping_sets(self.grouping_sets)
+            .into()
     }

     fn rewrite_with_error(&mut self, expr: ExprImpl) -> Result<ExprImpl> {
@@ -831,7 +827,7 @@ impl LogicalAgg {
         &self.core.grouping_sets
     }

-    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef) {
+    pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) {
         self.core.decompose()
     }

@@ -870,8 +866,9 @@
             .map(|set| set.indices().map(|key| input_col_change.map(key)).collect())
             .collect();

-        let new_agg =
-            Agg::new_with_grouping_sets(agg_calls, group_key.clone(), grouping_sets, input);
+        let new_agg = Agg::new(agg_calls, group_key.clone(), input)
+            .with_grouping_sets(grouping_sets)
+            .with_enable_two_phase(self.core().enable_two_phase);

         // group_key remapping might cause an output column change, since group key actually is a
         // `FixedBitSet`.
@@ -896,13 +893,10 @@ impl PlanTreeNodeUnary for LogicalAgg {
     }

     fn clone_with_input(&self, input: PlanRef) -> Self {
-        Agg::new_with_grouping_sets(
-            self.agg_calls().to_vec(),
-            self.group_key().clone(),
-            self.grouping_sets().clone(),
-            input,
-        )
-        .into()
+        Agg::new(self.agg_calls().to_vec(), self.group_key().clone(), input)
+            .with_grouping_sets(self.grouping_sets().clone())
+            .with_enable_two_phase(self.core().enable_two_phase)
+            .into()
     }

     #[must_use]
diff --git a/src/frontend/src/optimizer/plan_node/logical_union.rs b/src/frontend/src/optimizer/plan_node/logical_union.rs
index e21b39088315..38ef55405693 100644
--- a/src/frontend/src/optimizer/plan_node/logical_union.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_union.rs
@@ -125,11 +125,10 @@ impl ToBatch for LogicalUnion {
         // Convert union to union all + agg
         if !self.all() {
             let batch_union = BatchUnion::new(new_logical).into();
-            Ok(BatchHashAgg::new(generic::Agg::new(
-                vec![],
-                (0..self.base.schema.len()).collect(),
-                batch_union,
-            ))
+            Ok(BatchHashAgg::new(
+                generic::Agg::new(vec![], (0..self.base.schema.len()).collect(), batch_union)
+                    .with_enable_two_phase(false),
+            )
             .into())
         } else {
             Ok(BatchUnion::new(new_logical).into())
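The `clone_with_input` and column-pruning changes above are the heart of the fix: once a node's `enable_two_phase` has been pinned (for example to `false` for the deduplicating agg that `LogicalUnion` builds), later rewrites must copy that flag instead of consulting the session config again, or a forced setting would silently be lost. A small self-contained sketch of that invariant, with a `&'static str` standing in for `PlanRef` (illustration only, not the real plan-node types):

    // Stand-in showing why `clone_with_input` copies the per-node flag.
    #[derive(Debug, Clone)]
    struct Agg {
        input: &'static str,   // stand-in for PlanRef
        enable_two_phase: bool,
    }

    impl Agg {
        fn clone_with_input(&self, input: &'static str) -> Self {
            // Copy the pinned flag; do NOT re-read the session default here.
            Self { input, enable_two_phase: self.enable_two_phase }
        }
    }

    fn main() {
        // e.g. the dedup agg built over a union-all, with two-phase disabled on purpose.
        let pinned = Agg { input: "union_all", enable_two_phase: false };
        let rebuilt = pinned.clone_with_input("new_union_all");
        assert!(!rebuilt.enable_two_phase); // the override survives the rewrite
        println!("{rebuilt:?}");
    }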
diff --git a/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs b/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs
index 3c6056e2db27..3f58b1af7c6d 100644
--- a/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs
+++ b/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs
@@ -16,7 +16,6 @@
 use itertools::Itertools;

 use super::super::plan_node::*;
 use super::{BoxedRule, Rule};
-use crate::optimizer::plan_node::generic::Agg;
 use crate::utils::IndexSet;

 /// Merge [`LogicalAgg`] <- [`LogicalProject`] to [`LogicalAgg`].
@@ -24,31 +23,33 @@ pub struct AggProjectMergeRule {}
 impl Rule for AggProjectMergeRule {
     fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
         let agg = plan.as_logical_agg()?;
-        let (mut agg_calls, agg_group_keys, grouping_sets, input) = agg.clone().decompose();
-        assert!(grouping_sets.is_empty());
-        let proj = input.as_logical_project()?;
-
+        let agg = agg.core().clone();
+        assert!(agg.grouping_sets.is_empty());
+        let old_input = agg.input.clone();
+        let proj = old_input.as_logical_project()?;
         // only apply when the input proj is all input-ref
         if !proj.is_all_inputref() {
             return None;
         }

         let proj_o2i = proj.o2i_col_mapping();
-        let new_input = proj.input();
-
-        // modify agg calls according to projection
-        agg_calls
-            .iter_mut()
-            .for_each(|x| x.rewrite_input_index(proj_o2i.clone()));

         // modify group key according to projection
-        let new_agg_group_keys_in_vec = agg_group_keys
+        let new_agg_group_keys_in_vec = agg
+            .group_key
             .indices()
             .map(|x| proj_o2i.map(x))
             .collect_vec();
         let new_agg_group_keys = IndexSet::from_iter(new_agg_group_keys_in_vec.clone());

+        let mut agg = agg;
+        agg.input = proj.input();
+        // modify agg calls according to projection
+        agg.agg_calls
+            .iter_mut()
+            .for_each(|x| x.rewrite_input_index(proj_o2i.clone()));
+        agg.group_key = new_agg_group_keys.clone();
+
         if new_agg_group_keys.to_vec() != new_agg_group_keys_in_vec {
             // Need a project
             let new_agg_group_keys_cardinality = new_agg_group_keys.len();
@@ -57,17 +58,11 @@
                 .map(|x| new_agg_group_keys.indices().position(|y| y == x).unwrap())
                 .chain(
                     new_agg_group_keys_cardinality
-                        ..new_agg_group_keys_cardinality + agg_calls.len(),
+                        ..new_agg_group_keys_cardinality + agg.agg_calls.len(),
                 );
-            Some(
-                LogicalProject::with_out_col_idx(
-                    Agg::new(agg_calls, new_agg_group_keys.clone(), new_input).into(),
-                    out_col_idx,
-                )
-                .into(),
-            )
+            Some(LogicalProject::with_out_col_idx(agg.into(), out_col_idx).into())
         } else {
-            Some(Agg::new(agg_calls, new_agg_group_keys, new_input).into())
+            Some(agg.into())
         }
     }
 }
diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs
index 8781ca58b5ae..78aa5affd509 100644
--- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs
+++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs
@@ -52,7 +52,8 @@ impl Rule for ApplyAggTransposeRule {
             apply.clone().decompose();
         assert_eq!(join_type, JoinType::Inner);
         let agg: &LogicalAgg = right.as_logical_agg()?;
-        let (mut agg_calls, agg_group_key, grouping_sets, agg_input) = agg.clone().decompose();
+        let (mut agg_calls, agg_group_key, grouping_sets, agg_input, enable_two_phase) =
+            agg.clone().decompose();
         assert!(grouping_sets.is_empty());
         let is_scalar_agg = agg_group_key.is_empty();
         let apply_left_len = left.schema().len();
@@ -147,7 +148,9 @@
             }
             let mut group_keys: IndexSet = (0..apply_left_len).collect();
             group_keys.extend(agg_group_key.indices().map(|key| key + apply_left_len));
-            Agg::new(agg_calls, group_keys, node).into()
+            Agg::new(agg_calls, group_keys, node)
+                .with_enable_two_phase(enable_two_phase)
+                .into()
         };

         let filter = LogicalFilter::create(group_agg, on);
diff --git a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs
index b7f9a5f90210..cc3273e726ee 100644
--- a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs
+++ b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs
@@ -35,7 +35,8 @@ pub struct DistinctAggRule {
 impl Rule for DistinctAggRule {
     fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
         let agg: &LogicalAgg = plan.as_logical_agg()?;
-        let (mut agg_calls, mut agg_group_keys, grouping_sets, input) = agg.clone().decompose();
+        let (mut agg_calls, mut agg_group_keys, grouping_sets, input, enable_two_phase) =
+            agg.clone().decompose();
         assert!(grouping_sets.is_empty());

         if agg_calls.iter().all(|c| !c.distinct) {
@@ -84,6 +85,7 @@
             agg_calls,
             flag_values,
             has_expand,
+            enable_two_phase,
         ))
     }
 }
@@ -237,7 +239,7 @@
             // append `flag`.
             group_keys.insert(project.schema().len() - 1);
         }
-        Agg::new(agg_calls, group_keys, project)
+        Agg::new(agg_calls, group_keys, project).with_enable_two_phase(false)
     }

     fn build_final_agg(
@@ -246,6 +248,7 @@
         mut agg_calls: Vec<PlanAggCall>,
         flag_values: Vec<usize>,
         has_expand: bool,
+        enable_two_phase: bool,
     ) -> PlanRef {
         // the index of `flag` in schema of the middle `LogicalAgg`, if has `Expand`.
         let pos_of_flag = mid_agg.group_key.len() - 1;
@@ -322,6 +325,8 @@
             }
         });

-        Agg::new(agg_calls, final_agg_group_keys, mid_agg.into()).into()
+        Agg::new(agg_calls, final_agg_group_keys, mid_agg.into())
+            .with_enable_two_phase(enable_two_phase)
+            .into()
     }
 }
diff --git a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs
index 2073743c90c1..a15ccc19ffb7 100644
--- a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs
+++ b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs
@@ -75,7 +75,7 @@ impl Rule for GroupingSetsToExpandRule {
             return None;
         }
         let agg = Self::prune_column_for_agg(agg);
-        let (agg_calls, mut group_keys, grouping_sets, input) = agg.decompose();
+        let (agg_calls, mut group_keys, grouping_sets, input, enable_two_phase) = agg.decompose();

         let flag_col_idx = group_keys.len();
         let input_schema_len = input.schema().len();
@@ -159,7 +159,8 @@
             }
         }

-        let new_agg = Agg::new(new_agg_calls, group_keys, expand);
+        let new_agg =
+            Agg::new(new_agg_calls, group_keys, expand).with_enable_two_phase(enable_two_phase);

         let project_exprs = (0..flag_col_idx)
             .map(|i| {
                 ExprImpl::InputRef(
diff --git a/src/frontend/src/optimizer/rule/union_to_distinct_rule.rs b/src/frontend/src/optimizer/rule/union_to_distinct_rule.rs
index bd4764fe04f1..2a12f6b712e0 100644
--- a/src/frontend/src/optimizer/rule/union_to_distinct_rule.rs
+++ b/src/frontend/src/optimizer/rule/union_to_distinct_rule.rs
@@ -24,7 +24,8 @@ impl Rule for UnionToDistinctRule {
         let union: &LogicalUnion = plan.as_logical_union()?;
         if !union.all() {
             let union_all = LogicalUnion::create(true, union.inputs().into_iter().collect());
-            let distinct = Agg::new(vec![], (0..union.base.schema.len()).collect(), union_all);
+            let distinct = Agg::new(vec![], (0..union.base.schema.len()).collect(), union_all)
+                .with_enable_two_phase(false);
             Some(distinct.into())
         } else {
             None
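The rule changes above all follow one pattern: `decompose()` now hands out the node's `enable_two_phase` flag as the last tuple element, and whoever rebuilds the agg must hand it back through `with_enable_two_phase` (or pin it explicitly, as the middle agg in `DistinctAggRule` and the union-dedup agg do with `false`). A self-contained sketch of that threading, with `String`/`Vec<usize>` stand-ins for the real plan types and a hard-coded session default assumed purely for illustration:

    // Stand-in agg node with the widened decompose tuple.
    #[derive(Debug)]
    struct Agg {
        agg_calls: Vec<String>,
        group_key: Vec<usize>,
        input: String,
        enable_two_phase: bool,
    }

    impl Agg {
        fn new(agg_calls: Vec<String>, group_key: Vec<usize>, input: String) -> Self {
            // Assumed session default for the sketch; the real code reads the session config.
            Self { agg_calls, group_key, input, enable_two_phase: true }
        }

        fn with_enable_two_phase(mut self, enable_two_phase: bool) -> Self {
            self.enable_two_phase = enable_two_phase;
            self
        }

        // Mirrors the widened tuple shape: the flag rides along with the other parts.
        fn decompose(self) -> (Vec<String>, Vec<usize>, String, bool) {
            (self.agg_calls, self.group_key, self.input, self.enable_two_phase)
        }
    }

    // A toy "rule": rewrite the input but keep every other property, including the flag.
    fn rewrite(agg: Agg) -> Agg {
        let (agg_calls, group_key, input, enable_two_phase) = agg.decompose();
        Agg::new(agg_calls, group_key, format!("rewritten({input})"))
            .with_enable_two_phase(enable_two_phase)
    }

    fn main() {
        let original = Agg::new(vec!["count".into()], vec![0], "scan".into())
            .with_enable_two_phase(false);
        let rewritten = rewrite(original);
        assert!(!rewritten.enable_two_phase); // the flag is threaded through, not reset
        println!("{rewritten:?}");
    }

This is exactly the behavior the new nexmark_q15_split_distinct_agg_and_force_two_phase planner test pins down: with RW_FORCE_SPLIT_DISTINCT_AGG and RW_FORCE_TWO_PHASE_AGG both set, the rewritten distinct aggregation still produces the expected three-level plan instead of losing the forced settings along the way.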