diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java index 9be3c0d99100e..c0951382eb6ec 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java @@ -665,7 +665,9 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp if (canApplyDictDecodeOpt) { CallOperator oldCall = kv.getValue(); int columnId = kv.getValue().getUsedColumns().getFirstId(); - if (context.needRewriteMultiCountDistinctColumns.contains(columnId)) { + final String fnName = kv.getValue().getFnName(); + if (context.needRewriteMultiCountDistinctColumns.contains(columnId) + && fnName.equals(FunctionSet.MULTI_DISTINCT_COUNT)) { // we only need rewrite TFunction Type[] newTypes = new Type[] {ID_TYPE}; AggregateFunction newFunction = @@ -682,7 +684,6 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp List newArguments = Collections.singletonList(dictColumn); Type[] newTypes = newArguments.stream().map(ScalarOperator::getType).toArray(Type[]::new); - String fnName = kv.getValue().getFnName(); AggregateFunction newFunction = (AggregateFunction) Expr.getBuiltinFunction(kv.getValue().getFnName(), newTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java index 2f6afaf791ab8..afe6f77e169f9 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java @@ -395,6 +395,11 @@ public void testDecodeNodeRewriteMultiCountDistinct() throws Exception { plan = getFragmentPlan(sql); Assert.assertTrue(plan.contains(" multi_distinct_count(11: S_ADDRESS), " + "multi_distinct_count(12: S_COMMENT)")); + + sql = "select max(a) from (select count(distinct S_ADDRESS) a from supplier)t"; + plan = getFragmentPlan(sql); + assertContains(plan, "multi_distinct_count(9: count)"); + connectContext.getSessionVariable().setNewPlanerAggStage(3); sql = "select max(S_ADDRESS), count(distinct S_ADDRESS) from supplier group by S_ADDRESS;"; plan = getFragmentPlan(sql);