Skip to content

Commit

Permalink
Support GroupConcat sql for aggregating multiple shards in opengauss …
Browse files Browse the repository at this point in the history
…and doris database(#33992) (#33991)

* Support GroupConcat sql for aggregating multiple shards(#33797)

* Check Style fix(#33797)

* Check Style fix(#33797)

* spotless fix (#33797)

* unit test fix  (#33797)

* spotless fix (#33797)

* group_concat distinct compatible  (#33797)

* group_concat distinct compatible  (#33797)

* unit test fix for distinct group_concat  (#33797)

* e2e test for group_concat  (#33797)

* e2e test for group_concat  (#33797)

* code format  (#33797)

* e2e test  (#33797)

* e2e test  (#33797)

* e2e test  (#33797)

* remove useless code(#33797)

* code optimization (#33797)

* sql parse unit test (#33797)

* RELEASE-NOTES.md updated(#33797)

* Code Optimization (#33797)

* Support GroupConcat sql for aggregating multiple shards in opengauss and doris database(#33797)

* doris parse unit test fix (#33797)

* spotless fix (#33797)

* Update RELEASE-NOTES.md

---------

Co-authored-by: yaofly <[email protected]>
Co-authored-by: Zhengqiang Duan <[email protected]>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent 7eabecf commit fac2bb6
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 23 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
1. Proxy Native: Change the Base Docker Image of ShardingSphere Proxy Native - [#33263](https://github.com/apache/shardingsphere/issues/33263)
1. Proxy Native: Support connecting to HiveServer2 with ZooKeeper Service Discovery enabled in GraalVM Native Image - [#33768](https://github.com/apache/shardingsphere/pull/33768)
1. Proxy Native: Support local transactions of ClickHouse under GraalVM Native Image - [#33801](https://github.com/apache/shardingsphere/pull/33801)
1. Sharding: Support MYSQL GroupConcat function for aggregating multiple shards - [#33808](https://github.com/apache/shardingsphere/pull/33808)
1. Sharding: Support GroupConcat function for aggregating multiple shards in MySQL, openGauss and Doris - [#33808](https://github.com/apache/shardingsphere/pull/33808)
1. Proxy Native: Support Seata AT integration under Proxy Native in GraalVM Native Image - [#33889](https://github.com/apache/shardingsphere/pull/33889)
1. Agent: Simplify the use of Agent's Docker Image - [#33356](https://github.com/apache/shardingsphere/pull/33356)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -965,8 +965,16 @@ udfFunction
: functionName LP_ (expr? | expr (COMMA_ expr)*) RP_
;

// SEPARATOR clause of an aggregation call: the SEPARATOR keyword followed by a string literal (string_ rule).
separatorName
: SEPARATOR string_
;

// Argument part of an aggregation call: either a comma-separated list of expressions or a single asterisk.
aggregationExpression
: expr (COMMA_ expr)* | ASTERISK_
;

// Aggregation function call: function name, then inside the parentheses an optional distinct,
// an optional aggregationExpression, an optional collateClause and an optional separatorName,
// followed by an optional overClause.
// NOTE(review): two ':' lines are shown because the diff markers were lost; the first one (inline
// expression list, no separatorName) appears to be the pre-change definition - confirm against the repository.
aggregationFunction
: aggregationFunctionName LP_ distinct? (expr (COMMA_ expr)* | ASTERISK_)? collateClause? RP_ overClause?
: aggregationFunctionName LP_ distinct? aggregationExpression? collateClause? separatorName? RP_ overClause?
;

// DORIS ADDED BEGIN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.shardingsphere.sql.parser.api.ASTNode;
import org.apache.shardingsphere.sql.parser.autogen.DorisStatementBaseVisitor;
Expand Down Expand Up @@ -958,14 +959,18 @@ public final ASTNode visitJsonTableFunction(final JsonTableFunctionContext ctx)

/**
 * Create the projection segment for an aggregation function call.
 *
 * <p>A call carrying DISTINCT yields an {@code AggregationDistinctProjectionSegment}; any other call yields an
 * {@code AggregationProjectionSegment}. In both cases the visited argument expressions are added as parameters.</p>
 *
 * <p>NOTE(review): this hunk is rendered without diff markers, so each changed statement appears twice. The variant
 * without {@code separator} / reading {@code ctx.expr()} appears to be the pre-change line and the variant after it
 * the post-change line. Confirm against the repository before copying.</p>
 *
 * @param ctx aggregation function parse context
 * @param aggregationType aggregation function name; upper-cased before being resolved with {@code AggregationType.valueOf}
 * @return distinct or plain aggregation projection segment
 */
private ASTNode createAggregationSegment(final AggregationFunctionContext ctx, final String aggregationType) {
AggregationType type = AggregationType.valueOf(aggregationType.toUpperCase());
// Separator stays null unless the call has a SEPARATOR clause. The raw literal text goes through
// StringLiteralValue#getValue (presumably to strip the surrounding quotes - confirm).
String separator = null;
if (null != ctx.separatorName()) {
separator = new StringLiteralValue(ctx.separatorName().string_().getText()).getValue();
}
if (null != ctx.distinct()) {
AggregationDistinctProjectionSegment result =
new AggregationDistinctProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), getDistinctExpression(ctx));
result.getParameters().addAll(getExpressions(ctx.expr()));
new AggregationDistinctProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), getDistinctExpression(ctx), separator);
// NOTE(review): the grammar declares aggregationExpression as optional, so ctx.aggregationExpression()
// may be null for a call without arguments - confirm whether a null guard is needed here.
result.getParameters().addAll(getExpressions(ctx.aggregationExpression().expr()));
return result;
}
AggregationProjectionSegment result = new AggregationProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx));
result.getParameters().addAll(getExpressions(ctx.expr()));
AggregationProjectionSegment result = new AggregationProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), separator);
// NOTE(review): same optional-rule concern as in the DISTINCT branch - confirm.
result.getParameters().addAll(getExpressions(ctx.aggregationExpression().expr()));
return result;
}

Expand All @@ -981,11 +986,7 @@ protected Collection<ExpressionSegment> getExpressions(final List<ExprContext> e
}

/**
 * Get the text of the expression that DISTINCT applies to.
 *
 * <p>NOTE(review): diff markers were lost in this hunk. The {@code StringBuilder} loop, which concatenates the text of
 * the children from index 3 up to (but excluding) the last child, appears to be the removed implementation; the single
 * {@code return ctx.aggregationExpression().getText()} appears to be its replacement. Confirm against the repository.</p>
 *
 * @param ctx aggregation function parse context
 * @return distinct expression text
 */
private String getDistinctExpression(final AggregationFunctionContext ctx) {
StringBuilder result = new StringBuilder();
for (int i = 3; i < ctx.getChildCount() - 1; i++) {
result.append(ctx.getChild(i).getText());
}
return result.toString();
// NOTE(review): aggregationExpression is optional in the grammar rule, so this may dereference null
// when DISTINCT is present without an argument - confirm.
return ctx.aggregationExpression().getText();
}

@Override
Expand Down Expand Up @@ -1046,12 +1047,25 @@ public final ASTNode visitSpecialFunction(final SpecialFunctionContext ctx) {
public final ASTNode visitGroupConcatFunction(final GroupConcatFunctionContext ctx) {
// Parameter counting is driven by ctx.expr().
calculateParameterCount(ctx.expr());
// The function segment spans the whole context and is named by the GROUP_CONCAT token text.
FunctionSegment result = new FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.GROUP_CONCAT().getText(), getOriginalText(ctx));
// NOTE(review): diff markers were lost - the two loop headers below appear to be the pre-change
// (ctx.expr()) and post-change (ExprContext nodes collected by getTargetRuleContextFromParseTree)
// versions of a single loop; confirm against the repository.
for (ExprContext each : ctx.expr()) {
for (ExprContext each : getTargetRuleContextFromParseTree(ctx, ExprContext.class)) {
// Each visited expression becomes one parameter of the function segment.
result.getParameters().add((ExpressionSegment) visit(each));
}
return result;
}

/**
 * Collect the nodes of the requested type found below the given parse tree node.
 *
 * <p>Children are inspected in order. A child that is an instance of {@code clazz} is collected as-is and its own
 * children are not searched; any other child is searched recursively.</p>
 *
 * @param parseTree node whose descendants are searched
 * @param clazz type of node to collect
 * @param <T> type of collected node
 * @return matching nodes in traversal order
 */
private <T extends ParseTree> Collection<T> getTargetRuleContextFromParseTree(final ParseTree parseTree, final Class<? extends T> clazz) {
    Collection<T> result = new LinkedList<>();
    int childCount = parseTree.getChildCount();
    for (int i = 0; i < childCount; i++) {
        ParseTree each = parseTree.getChild(i);
        if (!clazz.isInstance(each)) {
            // Not a match itself: descend and keep whatever matches underneath.
            result.addAll(getTargetRuleContextFromParseTree(each, clazz));
            continue;
        }
        result.add(clazz.cast(each));
    }
    return result;
}

// DORIS ADDED BEGIN
@Override
public final ASTNode visitBitwiseFunction(final BitwiseFunctionContext ctx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.tree.TerminalNode;
import org.apache.shardingsphere.sql.parser.api.ASTNode;
import org.apache.shardingsphere.sql.parser.autogen.MySQLStatementBaseVisitor;
Expand Down Expand Up @@ -1041,12 +1042,25 @@ public final ASTNode visitSpecialFunction(final SpecialFunctionContext ctx) {
public final ASTNode visitGroupConcatFunction(final GroupConcatFunctionContext ctx) {
// Parameter counting is driven by ctx.expr().
calculateParameterCount(ctx.expr());
// The function segment spans the whole context and is named by the GROUP_CONCAT token text.
FunctionSegment result = new FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.GROUP_CONCAT().getText(), getOriginalText(ctx));
// NOTE(review): diff markers were lost - the two loop headers below appear to be the pre-change
// (ctx.expr()) and post-change (ExprContext nodes collected by getTargetRuleContextFromParseTree)
// versions of a single loop; confirm against the repository.
for (ExprContext each : ctx.expr()) {
for (ExprContext each : getTargetRuleContextFromParseTree(ctx, ExprContext.class)) {
// Each visited expression becomes one parameter of the function segment.
result.getParameters().add((ExpressionSegment) visit(each));
}
return result;
}

/**
 * Collect the nodes of the requested type found below the given parse tree node.
 *
 * <p>Children are inspected in order. A child that is an instance of {@code clazz} is collected as-is and its own
 * children are not searched; any other child is searched recursively.</p>
 *
 * @param parseTree node whose descendants are searched
 * @param clazz type of node to collect
 * @param <T> type of collected node
 * @return matching nodes in traversal order
 */
private <T extends ParseTree> Collection<T> getTargetRuleContextFromParseTree(final ParseTree parseTree, final Class<? extends T> clazz) {
    Collection<T> result = new LinkedList<>();
    int childCount = parseTree.getChildCount();
    for (int i = 0; i < childCount; i++) {
        ParseTree each = parseTree.getChild(i);
        if (!clazz.isInstance(each)) {
            // Not a match itself: descend and keep whatever matches underneath.
            result.addAll(getTargetRuleContextFromParseTree(each, clazz));
            continue;
        }
        result.add(clazz.cast(each));
    }
    return result;
}

@Override
public final ASTNode visitCastFunction(final CastFunctionContext ctx) {
FunctionSegment result = new FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.CAST().getText(), getOriginalText(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,16 @@ typeList
: typeName (COMMA_ typeName)*
;

// SEPARATOR clause of a function call: the SEPARATOR keyword followed by a STRING_ token.
separatorName
: SEPARATOR STRING_
;

// Function call forms: empty argument list, plain / DISTINCT / ALL argument lists with an optional sortClause,
// VARIADIC forms, and the asterisk form.
// NOTE(review): this hunk is rendered without diff markers. The alternatives
// `funcName LP_ funcArgList sortClause? RP_` and `funcName LP_ DISTINCT funcArgList sortClause? RP_`
// appear to be pre-change lines that the `DISTINCT? funcArgList sortClause? separatorName?` alternative
// replaces - confirm against the repository.
funcApplication
: funcName LP_ RP_
| funcName LP_ funcArgList sortClause? RP_
| funcName LP_ DISTINCT? funcArgList sortClause? separatorName? RP_
| funcName LP_ VARIADIC funcArgExpr sortClause? RP_
| funcName LP_ funcArgList COMMA_ VARIADIC funcArgExpr sortClause? RP_
| funcName LP_ ALL funcArgList sortClause? RP_
| funcName LP_ DISTINCT funcArgList sortClause? RP_
| funcName LP_ ASTERISK_ RP_
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,6 @@ MAXVALUE
: M A X V A L U E
;

// Keyword token SEPARATOR, spelled out from the per-letter rules S, E, P, A, R, A, T, O, R
// (presumably case-insensitive letter fragments - confirm in the lexer's letter definitions).
SEPARATOR
: S E P A R A T O R
;
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ public ASTNode visitFuncExpr(final FuncExprContext ctx) {
Collection<ExpressionSegment> expressionSegments = getExpressionSegments(getTargetRuleContextFromParseTree(ctx, AExprContext.class));
// TODO replace aggregation segment
String aggregationType = ctx.funcApplication().funcName().getText();
if (AggregationType.isAggregationType(aggregationType)) {
if (AggregationType.isAggregationType(aggregationType) && null == ctx.funcApplication().sortClause()) {
return createAggregationSegment(ctx.funcApplication(), aggregationType, expressionSegments);
}
FunctionSegment result = new FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), ctx.funcApplication().funcName().getText(), getOriginalText(ctx));
Expand Down Expand Up @@ -594,13 +594,17 @@ public ASTNode visitBExpr(final BExprContext ctx) {

/**
 * Create the projection segment for a function application that names an aggregation function.
 *
 * <p>A call without DISTINCT yields an {@code AggregationProjectionSegment}; a call with DISTINCT yields an
 * {@code AggregationDistinctProjectionSegment}. The supplied expression segments are added as parameters in both cases.</p>
 *
 * <p>NOTE(review): this hunk is rendered without diff markers, so each constructor call appears twice. The variant
 * without the trailing {@code separator} argument appears to be the pre-change line - confirm against the repository.</p>
 *
 * @param ctx function application parse context
 * @param aggregationType aggregation function name; upper-cased before being resolved with {@code AggregationType.valueOf}
 * @param expressionSegments already visited argument expressions
 * @return plain or distinct aggregation projection segment
 */
private ProjectionSegment createAggregationSegment(final FuncApplicationContext ctx, final String aggregationType, final Collection<ExpressionSegment> expressionSegments) {
AggregationType type = AggregationType.valueOf(aggregationType.toUpperCase());
// Separator stays null unless the call has a SEPARATOR clause. The raw STRING_ token text goes through
// StringLiteralValue#getValue (presumably to strip the surrounding quotes - confirm).
String separator = null;
if (null != ctx.separatorName()) {
separator = new StringLiteralValue(ctx.separatorName().STRING_().getText()).getValue();
}
if (null == ctx.DISTINCT()) {
AggregationProjectionSegment result = new AggregationProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx));
AggregationProjectionSegment result = new AggregationProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), separator);
result.getParameters().addAll(expressionSegments);
return result;
}
AggregationDistinctProjectionSegment result =
new AggregationDistinctProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), getDistinctExpression(ctx));
new AggregationDistinctProjectionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), type, getOriginalText(ctx), getDistinctExpression(ctx), separator);
result.getParameters().addAll(expressionSegments);
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@
<assertion parameters="abc:String" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT GROUP_CONCAT(o.remark) as order_id_group_concat FROM t_order o where o.order_id > 1 - 1" db-types="MySQL" scenario-types="db,tbl,dbtbl_with_readwrite_splitting,readwrite_splitting,db_tbl_sql_federation">
<test-case sql="SELECT GROUP_CONCAT(o.remark) as order_id_group_concat FROM t_order o where o.order_id > 1 - 1" db-types="MySQL,openGauss" scenario-types="db,tbl,dbtbl_with_readwrite_splitting,readwrite_splitting,db_tbl_sql_federation">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT GROUP_CONCAT(distinct o.remark SEPARATOR ' ') as order_id_group_concat FROM t_order o where o.order_id > 1 - 1" db-types="MySQL" scenario-types="db,tbl,dbtbl_with_readwrite_splitting,readwrite_splitting">
<test-case sql="SELECT GROUP_CONCAT(distinct o.remark SEPARATOR ' ') as order_id_group_concat FROM t_order o where o.order_id > 1 - 1" db-types="MySQL,openGauss" scenario-types="db,tbl,dbtbl_with_readwrite_splitting,readwrite_splitting">
<assertion expected-data-source-name="read_dataset" />
</test-case>
</e2e-test-cases>
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
<parameter>
<column name="status" start-index="20" stop-index="25" />
</parameter>
<parameter>
<column name="status" start-index="36" stop-index="41" />
</parameter>
</function>
</expr>
</expression-projection>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@
<sql-case id="select_approx_count"
value="select owner, approx_count(*) , approx_rank(partition by owner order by approx_count(*) desc) from t group by owner having approx_rank(partition by owner order by approx_count(*) desc) &lt;= 1 order by 1"
db-types="Oracle"/>
<sql-case id="select_group_concat" value="SELECT GROUP_CONCAT(user_id) AS user_id_group_concat FROM t_order" db-types="MySQL"/>
<sql-case id="select_group_concat_with_distinct_with_separator" value="SELECT GROUP_CONCAT(distinct user_id SEPARATOR ' ') AS user_id_group_concat FROM t_order" db-types="MySQL"/>
<sql-case id="select_group_concat" value="SELECT GROUP_CONCAT(user_id) AS user_id_group_concat FROM t_order" db-types="MySQL,Doris,openGauss"/>
<sql-case id="select_group_concat_with_distinct_with_separator" value="SELECT GROUP_CONCAT(distinct user_id SEPARATOR ' ') AS user_id_group_concat FROM t_order" db-types="MySQL,Doris,openGauss"/>
</sql-cases>
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
-->

<sql-cases>
<sql-case id="select_group_concat_with_order_by" value="SELECT GROUP_CONCAT(status ORDER BY status) FROM t_order" db-types="MySQL" />
<sql-case id="select_group_concat_with_order_by" value="SELECT GROUP_CONCAT(status ORDER BY status) FROM t_order" db-types="MySQL,Doris,openGauss" />
<sql-case id="select_window_function" value="SELECT order_id, ROW_NUMBER() OVER() FROM t_order" db-types="MySQL" />
<sql-case id="select_cast_function" value="SELECT CAST('1' AS UNSIGNED)" db-types="MySQL" />
<sql-case id="select_cast" value="SELECT CAST(c AT TIME ZONE 'UTC' AS DATETIME)" db-types="MySQL" />
Expand Down

0 comments on commit fac2bb6

Please sign in to comment.