Skip to content

Commit

Permalink
Move execute to DriverExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed May 31, 2024
1 parent 1445d2f commit 86e8c59
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.driver.executor;

import lombok.Getter;
import org.apache.shardingsphere.driver.executor.callback.ExecuteCallback;
import org.apache.shardingsphere.driver.executor.callback.ExecuteQueryCallback;
import org.apache.shardingsphere.driver.executor.callback.ExecuteUpdateCallback;
import org.apache.shardingsphere.driver.executor.callback.impl.PreparedStatementExecuteQueryCallback;
Expand Down Expand Up @@ -361,24 +362,76 @@ private int accumulate(final Collection<ExecuteResult> results) {
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param isNeedImplicitCommitTransaction is need implicit commit transaction
* @param executeCallback execute callback
* @param statementReplayCallback statement replay callback
* @return execute result
* @throws SQLException SQL exception
*/
public Optional<Boolean> executeAdvance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
@SuppressWarnings("rawtypes")
public boolean executeAdvance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
final QueryContext queryContext, final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final TrafficExecutorCallback<Boolean> trafficCallback) throws SQLException {
final TrafficExecutorCallback<Boolean> trafficCallback, final boolean isNeedImplicitCommitTransaction,
final ExecuteCallback executeCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
executeType = ExecuteType.TRAFFIC;
return Optional.of(trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback));
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
}
if (sqlFederationEngine.decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, metaData.getGlobalRuleMetaData())) {
executeType = ExecuteType.FEDERATION;
ResultSet resultSet = sqlFederationEngine.executeQuery(
prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId()));
return Optional.of(null != resultSet);
return null != resultSet;
}
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
if (!database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) {
Collection<ExecuteResult> results = rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
return executeWithExecutionContext(database, executeCallback, executionContext, prepareEngine, isNeedImplicitCommitTransaction, statementReplayCallback);
}

@SuppressWarnings("rawtypes")
private boolean executeWithExecutionContext(final ShardingSphereDatabase database, final ExecuteCallback executeCallback, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final boolean isNeedImplicitCommitTransaction, final StatementReplayCallback statementReplayCallback) throws SQLException {
return isNeedImplicitCommitTransaction
? executeWithImplicitCommitTransaction(() -> useDriverToExecute(database, executeCallback, executionContext, prepareEngine, statementReplayCallback), connection,
database.getProtocolType())
: useDriverToExecute(database, executeCallback, executionContext, prepareEngine, statementReplayCallback);
}

@SuppressWarnings({"rawtypes", "unchecked"})
private boolean useDriverToExecute(final ShardingSphereDatabase database, final ExecuteCallback callback, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(database, executionContext, prepareEngine);
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
statements.addAll(getStatements(each));
if (JDBCDriverType.PREPARED_STATEMENT.equals(prepareEngine.getType())) {
parameterSets.addAll(getParameterSets(each));
}
}
return Optional.empty();
statementReplayCallback.replay(statements, parameterSets);
JDBCExecutorCallback<Boolean> jdbcExecutorCallback = createExecuteCallback(database, callback, executionContext.getSqlStatementContext().getSqlStatement(), prepareEngine.getType());
return regularExecutor.execute(executionGroupContext, executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), jdbcExecutorCallback);
}

private JDBCExecutorCallback<Boolean> createExecuteCallback(final ShardingSphereDatabase database, final ExecuteCallback executeCallback,
final SQLStatement sqlStatement, final String jdbcDriverType) {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Boolean>(database.getProtocolType(), database.getResourceMetaData(), sqlStatement, isExceptionThrown) {

@Override
protected Boolean executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return JDBCDriverType.STATEMENT.equals(jdbcDriverType) ? executeCallback.execute(sql, statement) : ((PreparedStatement) statement).execute();
}

@Override
protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement1, final SQLException ex) {
return Optional.empty();
}
};
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,21 @@
import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.infra.exception.kernel.syntax.EmptySQLException;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawSQLExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.callback.RawSQLExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.impl.driver.jdbc.type.stream.JDBCStreamQueryResult;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
import org.apache.shardingsphere.infra.executor.sql.prepare.raw.RawExecutionPrepareEngine;
import org.apache.shardingsphere.infra.hint.HintManager;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.hint.SQLHintUtils;
Expand All @@ -74,7 +64,6 @@
import org.apache.shardingsphere.infra.parser.SQLParserEngine;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
import org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute;
import org.apache.shardingsphere.infra.rule.attribute.resoure.StorageConnectorReusableRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
Expand Down Expand Up @@ -115,8 +104,6 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState

private final List<List<Object>> parameterSets;

private final SQLStatement sqlStatement;

private final SQLStatementContext sqlStatementContext;

private final String databaseName;
Expand Down Expand Up @@ -190,7 +177,7 @@ private ShardingSpherePreparedStatement(final ShardingSphereConnection connectio
parameterSets = new ArrayList<>();
SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType());
sqlStatement = sqlParserEngine.parse(this.sql, true);
SQLStatement sqlStatement = sqlParserEngine.parse(this.sql, true);
sqlStatementContext = new SQLBindEngine(metaDataContexts.getMetaData(), connection.getDatabaseName(), hintValueContext).bind(sqlStatement, Collections.emptyList());
databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
Expand Down Expand Up @@ -243,10 +230,6 @@ public ResultSet executeQuery() throws SQLException {
}
}

private boolean hasRawExecutionRule() {
return !metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty();
}

private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLException {
if (AutoCommitUtils.needOpenTransaction(sqlStatement)) {
connection.handleAutoCommit();
Expand Down Expand Up @@ -307,18 +290,16 @@ public boolean execute() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
Optional<Boolean> advancedResult = executor.executeAdvance(
metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute());
if (advancedResult.isPresent()) {
return advancedResult.get();
}
ExecutionContext executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
boolean result = executor.executeAdvance(
metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
isNeedImplicitCommitTransaction, null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
return executeWithExecutionContext(executionContext);
parameterSets.addAll(executor.getParameterSets());
return result;
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -329,43 +310,6 @@ public boolean execute() throws SQLException {
}
}

private boolean executeWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
? executeWithImplicitCommitTransaction(() -> useDriverToExecute(executionContext), connection, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType())
: useDriverToExecute(executionContext);
}

private boolean useDriverToExecute(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().execute(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteCallback());
}

private JDBCExecutorCallback<Boolean> createExecuteCallback() {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Boolean>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement, isExceptionThrown) {

@Override
protected Boolean executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return ((PreparedStatement) statement).execute();
}

@Override
protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine(database);
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(),
new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", "")));
}

@Override
public ResultSet getResultSet() throws SQLException {
if (null != currentResultSet) {
Expand Down Expand Up @@ -424,12 +368,6 @@ private ExecutionContext createExecutionContext(final QueryContext queryContext,
return new ExecutionContext(queryContext, Collections.singletonList(executionUnit), new RouteContext());
}

private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules())
.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", "")));
}

private QueryContext createQueryContext() {
List<Object> params = new ArrayList<>(getParameters());
if (sqlStatementContext instanceof ParameterAware) {
Expand All @@ -444,16 +382,6 @@ private MergedResult mergeQuery(final List<QueryResult> queryResults, final SQLS
return mergeEngine.merge(queryResults, sqlStatementContext);
}

private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>> executionGroups) throws SQLException {
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroups) {
each.getInputs().forEach(eachInput -> {
statements.add((PreparedStatement) eachInput.getStorageResource());
parameterSets.add(eachInput.getExecutionUnit().getSqlUnit().getParameters());
});
}
replay(statements, parameterSets);
}

private void replay(final List<PreparedStatement> statements, final List<List<Object>> parameterSets) throws SQLException {
replaySetParameter(statements, parameterSets);
for (Statement each : statements) {
Expand Down
Loading

0 comments on commit 86e8c59

Please sign in to comment.