diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java index 30405be9ad71d..d4fa8df2855c9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java @@ -636,7 +636,9 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp if (canApplyDictDecodeOpt) { CallOperator oldCall = kv.getValue(); int columnId = kv.getValue().getUsedColumns().getFirstId(); - if (context.needRewriteMultiCountDistinctColumns.contains(columnId)) { + final String fnName = kv.getValue().getFnName(); + if (context.needRewriteMultiCountDistinctColumns.contains(columnId) + && fnName.equals(FunctionSet.MULTI_DISTINCT_COUNT)) { // we only need rewrite TFunction Type[] newTypes = new Type[] {ID_TYPE}; AggregateFunction newFunction = @@ -653,7 +655,6 @@ private PhysicalHashAggregateOperator rewriteAggOperator(PhysicalHashAggregateOp List newArguments = Collections.singletonList(dictColumn); Type[] newTypes = newArguments.stream().map(ScalarOperator::getType).toArray(Type[]::new); - String fnName = kv.getValue().getFnName(); AggregateFunction newFunction = (AggregateFunction) Expr.getBuiltinFunction(kv.getValue().getFnName(), newTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java index c50926370e7dd..1e3aaf4a64d79 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java @@ -384,6 +384,11 @@ public void testDecodeNodeRewriteMultiCountDistinct() throws Exception { plan = getFragmentPlan(sql); Assert.assertTrue(plan.contains(" multi_distinct_count(11: S_ADDRESS), " + "multi_distinct_count(12: S_COMMENT)")); + + sql = "select max(a) from (select count(distinct S_ADDRESS) a from supplier)t"; + plan = getFragmentPlan(sql); + assertContains(plan, "multi_distinct_count(9: count)"); + connectContext.getSessionVariable().setNewPlanerAggStage(3); sql = "select max(S_ADDRESS), count(distinct S_ADDRESS) from supplier group by S_ADDRESS;"; plan = getFragmentPlan(sql);