diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 81f35e3579471..48f36e2c1dd36 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -554,12 +554,13 @@ impl<PlanRef: stream::StreamPlanRef> Agg<PlanRef> { .collect() } - pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef) { + pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) { ( self.agg_calls, self.group_key, self.grouping_sets, self.input, + self.enable_two_phase, ) } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 1ca6784e0e037..a2099b7d33f81 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -827,7 +827,7 @@ impl LogicalAgg { &self.core.grouping_sets } - pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef) { + pub fn decompose(self) -> (Vec<PlanAggCall>, IndexSet, Vec<IndexSet>, PlanRef, bool) { self.core.decompose() } diff --git a/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs b/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs index d295b7f7bd5b9..3f58b1af7c6d5 100644 --- a/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs +++ b/src/frontend/src/optimizer/rule/agg_project_merge_rule.rs @@ -16,7 +16,6 @@ use itertools::Itertools; use super::super::plan_node::*; use super::{BoxedRule, Rule}; -use crate::optimizer::plan_node::generic::Agg; use crate::utils::IndexSet; /// Merge [`LogicalAgg`] <- [`LogicalProject`] to [`LogicalAgg`]. diff --git a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs index 8781ca58b5ae8..78aa5affd509c 100644 --- a/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs +++ b/src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs @@ -52,7 +52,8 @@ impl Rule for ApplyAggTransposeRule { apply.clone().decompose(); assert_eq!(join_type, JoinType::Inner); let agg: &LogicalAgg = right.as_logical_agg()?; - let (mut agg_calls, agg_group_key, grouping_sets, agg_input) = agg.clone().decompose(); + let (mut agg_calls, agg_group_key, grouping_sets, agg_input, enable_two_phase) = + agg.clone().decompose(); assert!(grouping_sets.is_empty()); let is_scalar_agg = agg_group_key.is_empty(); let apply_left_len = left.schema().len(); @@ -147,7 +148,9 @@ impl Rule for ApplyAggTransposeRule { } let mut group_keys: IndexSet = (0..apply_left_len).collect(); group_keys.extend(agg_group_key.indices().map(|key| key + apply_left_len)); - Agg::new(agg_calls, group_keys, node).into() + Agg::new(agg_calls, group_keys, node) + .with_enable_two_phase(enable_two_phase) + .into() }; let filter = LogicalFilter::create(group_agg, on); diff --git a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs index 7c31d8938ca89..7c772a11967bc 100644 --- a/src/frontend/src/optimizer/rule/distinct_agg_rule.rs +++ b/src/frontend/src/optimizer/rule/distinct_agg_rule.rs @@ -35,7 +35,8 @@ pub struct DistinctAggRule { impl Rule for DistinctAggRule { fn apply(&self, plan: PlanRef) -> Option<PlanRef> { let agg: &LogicalAgg = plan.as_logical_agg()?; - let (mut agg_calls, mut agg_group_keys, grouping_sets, input) = agg.clone().decompose(); + let (mut agg_calls, mut agg_group_keys, grouping_sets, input, enable_two_phase) = + agg.clone().decompose(); assert!(grouping_sets.is_empty()); if agg_calls.iter().all(|c| !c.distinct) { @@ -84,6 +85,7 @@ impl Rule for DistinctAggRule { agg_calls, flag_values, has_expand, + enable_two_phase, )) } } @@ -246,6 +248,7 @@ impl DistinctAggRule { mut agg_calls: Vec<PlanAggCall>, flag_values: Vec<usize>, has_expand: bool, + enable_two_phase: bool, ) -> PlanRef { // the index of `flag` in schema of the middle `LogicalAgg`, if has `Expand`. let pos_of_flag = mid_agg.group_key.len() - 1; @@ -322,6 +325,8 @@ impl DistinctAggRule { } }); - Agg::new(agg_calls, final_agg_group_keys, mid_agg.into()).into() + Agg::new(agg_calls, final_agg_group_keys, mid_agg.into()) + .with_enable_two_phase(enable_two_phase) + .into() } } diff --git a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs index 2073743c90c17..a15ccc19ffb71 100644 --- a/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs +++ b/src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs @@ -75,7 +75,7 @@ impl Rule for GroupingSetsToExpandRule { return None; } let agg = Self::prune_column_for_agg(agg); - let (agg_calls, mut group_keys, grouping_sets, input) = agg.decompose(); + let (agg_calls, mut group_keys, grouping_sets, input, enable_two_phase) = agg.decompose(); let flag_col_idx = group_keys.len(); let input_schema_len = input.schema().len(); @@ -159,7 +159,8 @@ impl Rule for GroupingSetsToExpandRule { } } - let new_agg = Agg::new(new_agg_calls, group_keys, expand); + let new_agg = + Agg::new(new_agg_calls, group_keys, expand).with_enable_two_phase(enable_two_phase); let project_exprs = (0..flag_col_idx) .map(|i| { ExprImpl::InputRef(