Skip to content

Commit

Permalink
Support mysql cross join statement for sql federation (#28679)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Oct 8, 2023
1 parent ea57363 commit f7b4e22
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,17 @@ public Optional<SqlNode> convert(final JoinTableSegment segment) {
SqlNode right = new TableConverter().convert(segment.getRight()).orElseThrow(IllegalStateException::new);
Optional<SqlNode> condition = convertJoinCondition(segment);
SqlLiteral conditionType = convertConditionType(segment);
SqlLiteral joinType = JoinType.valueOf(segment.getJoinType()).symbol(SqlParserPos.ZERO);
SqlLiteral joinType = convertJoinType(segment);
return Optional.of(new SqlJoin(SqlParserPos.ZERO, left, SqlLiteral.createBoolean(segment.isNatural(), SqlParserPos.ZERO), joinType, right, conditionType, condition.orElse(null)));
}

private static SqlLiteral convertJoinType(final JoinTableSegment segment) {
if (JoinType.INNER.name().equals(segment.getJoinType()) && !segment.isNatural() && null == segment.getCondition() && segment.getUsing().isEmpty()) {
return JoinType.COMMA.symbol(SqlParserPos.ZERO);
}
return JoinType.valueOf(segment.getJoinType()).symbol(SqlParserPos.ZERO);
}

private SqlLiteral convertConditionType(final JoinTableSegment segment) {
if (!segment.getUsing().isEmpty()) {
return JoinConditionType.USING.symbol(SqlParserPos.ZERO);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.shardingsphere.infra.database.core.metadata.database.system.SystemDatabase;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.opengauss.type.OpenGaussDatabaseType;
import org.apache.shardingsphere.infra.exception.core.external.sql.type.wrapper.SQLWrapperException;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
Expand All @@ -40,7 +39,6 @@
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResultMetaData;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessEngine;
import org.apache.shardingsphere.infra.hint.HintValueContext;
Expand Down Expand Up @@ -143,28 +141,33 @@ public Enumerable<Object> execute(final ShardingSphereTable table, final Enumera
federationContext.getExecutionUnits().addAll(context.getExecutionUnits());
return createEmptyEnumerable();
}
return execute(queryContext, database, context);
}

private AbstractEnumerable<Object> execute(final QueryContext queryContext, final ShardingSphereDatabase database, final ExecutionContext context) {
try {
computeConnectionOffsets(context);
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext =
prepareEngine.prepare(context.getRouteContext(), executorContext.getConnectionOffsets(), context.getExecutionUnits(), new ExecutionGroupReportContext(database.getName()));
setParameters(executionGroupContext.getInputGroups());
processEngine.executeSQL(executionGroupContext, context.getQueryContext());
List<QueryResult> queryResults = jdbcExecutor.execute(executionGroupContext, callback).stream().map(QueryResult.class::cast).collect(Collectors.toList());
MergeEngine mergeEngine = new MergeEngine(database, executorContext.getProps(), new ConnectionContext());
MergedResult mergedResult = mergeEngine.merge(queryResults, queryContext.getSqlStatementContext());
Collection<Statement> statements = getStatements(executionGroupContext.getInputGroups());
return createEnumerable(mergedResult, queryResults.get(0).getMetaData(), statements);
} catch (final SQLException ex) {
throw new SQLWrapperException(ex);
return createEnumerable(queryContext, database, context);
} finally {
processEngine.completeSQLExecution();
}
}

private AbstractEnumerable<Object> createEnumerable(final QueryContext queryContext, final ShardingSphereDatabase database, final ExecutionContext context) {
return new AbstractEnumerable<Object>() {

@SneakyThrows
@Override
public Enumerator<Object> enumerator() {
computeConnectionOffsets(context);
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext =
prepareEngine.prepare(context.getRouteContext(), executorContext.getConnectionOffsets(), context.getExecutionUnits(), new ExecutionGroupReportContext(database.getName()));
setParameters(executionGroupContext.getInputGroups());
processEngine.executeSQL(executionGroupContext, context.getQueryContext());
List<QueryResult> queryResults = jdbcExecutor.execute(executionGroupContext, callback).stream().map(QueryResult.class::cast).collect(Collectors.toList());
MergeEngine mergeEngine = new MergeEngine(database, executorContext.getProps(), new ConnectionContext());
MergedResult mergedResult = mergeEngine.merge(queryResults, queryContext.getSqlStatementContext());
Collection<Statement> statements = getStatements(executionGroupContext.getInputGroups());
return new SQLFederationRowEnumerator(mergedResult, queryResults.get(0).getMetaData(), statements);
}
};
}

private void computeConnectionOffsets(final ExecutionContext context) {
for (ExecutionUnit each : context.getExecutionUnits()) {
if (executorContext.getConnectionOffsets().containsKey(each.getDataSourceName())) {
Expand Down Expand Up @@ -267,16 +270,6 @@ private void setParameters(final PreparedStatement preparedStatement, final List
}
}

private AbstractEnumerable<Object> createEnumerable(final MergedResult mergedResult, final QueryResultMetaData metaData, final Collection<Statement> statements) {
return new AbstractEnumerable<Object>() {

@Override
public Enumerator<Object> enumerator() {
return new SQLFederationRowEnumerator(mergedResult, metaData, statements);
}
};
}

private QueryContext createQueryContext(final ShardingSphereMetaData metaData, final EnumerableScanExecutorContext sqlString, final DatabaseType databaseType, final boolean useCache) {
String sql = sqlString.getSql().replace("\n", " ");
SQLStatement sqlStatement = new SQLStatementParserEngine(databaseType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

<!-- FIXME: In MySQL, JOIN, CROSS JOIN, and INNER JOIN are syntactic equivalents -->
<test-case sql="SELECT * FROM t_order o CROSS JOIN t_order_item i WHERE o.user_id = ? ORDER BY o.order_id, 7 LIMIT 10, 10" db-types="openGauss" scenario-types="db">
<test-case sql="SELECT * FROM t_order o CROSS JOIN t_order_item i WHERE o.user_id = ? ORDER BY o.order_id, 7 LIMIT 10, 10" db-types="MySQL,openGauss" scenario-types="db">
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

Expand All @@ -46,7 +45,7 @@
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT * FROM t_order o CROSS JOIN t_merchant m WHERE o.user_id = ? ORDER BY o.order_id, 7 LIMIT 10, 10" db-types="openGauss" scenario-types="db">
<test-case sql="SELECT * FROM t_order o CROSS JOIN t_merchant m WHERE o.user_id = ? ORDER BY o.order_id, 7 LIMIT 10, 10" db-types="MySQL,openGauss" scenario-types="db">
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

Expand All @@ -62,7 +61,7 @@
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT * FROM t_product p CROSS JOIN t_product_detail d WHERE p.product_id = ? ORDER BY d.product_id, 7 LIMIT 10, 10" db-types="openGauss" scenario-types="db">
<test-case sql="SELECT * FROM t_product p CROSS JOIN t_product_detail d WHERE p.product_id = ? ORDER BY d.product_id, 7 LIMIT 10, 10" db-types="MySQL,openGauss" scenario-types="db">
<assertion parameters="10:int" expected-data-source-name="read_dataset" />
</test-case>

Expand Down

0 comments on commit f7b4e22

Please sign in to comment.