Skip to content

Commit

Permalink
Move executeUpdate to DriverExecutor (#31452)
Browse files Browse the repository at this point in the history
* Move executeUpdate to DriverExecutor

* Move executeUpdate to DriverExecutor
  • Loading branch information
terrymanu authored May 31, 2024
1 parent aa5ea90 commit 1118bcd
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.Getter;
import org.apache.shardingsphere.driver.executor.callback.ExecuteQueryCallback;
import org.apache.shardingsphere.driver.executor.callback.ExecuteUpdateCallback;
import org.apache.shardingsphere.driver.executor.callback.impl.PreparedStatementExecuteQueryCallback;
import org.apache.shardingsphere.driver.executor.callback.impl.StatementExecuteQueryCallback;
import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
Expand All @@ -29,20 +30,26 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawSQLExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.callback.RawSQLExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.update.UpdateResult;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.raw.RawExecutionPrepareEngine;
Expand All @@ -55,11 +62,13 @@
import org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sqlfederation.engine.SQLFederationEngine;
import org.apache.shardingsphere.sqlfederation.executor.context.SQLFederationContext;
import org.apache.shardingsphere.traffic.executor.TrafficExecutor;
import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback;

import java.sql.Connection;
import java.sql.PreparedStatement;
Expand Down Expand Up @@ -249,17 +258,99 @@ private List<ResultSet> getResultSets() throws SQLException {
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param updateCallback update callback
* @param isNeedImplicitCommitTransaction is need implicit commit transaction
* @param statementReplayCallback statement replay callback
* @param executionContext execution context
* @return updated row count
* @throws SQLException SQL exception
*/
public Optional<Integer> executeAdvanceUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final TrafficExecutorCallback<Integer> trafficCallback) throws SQLException {
@SuppressWarnings("rawtypes")
public int executeAdvanceUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Integer> trafficCallback,
final ExecuteUpdateCallback updateCallback, final boolean isNeedImplicitCommitTransaction,
final StatementReplayCallback statementReplayCallback, final ExecutionContext executionContext) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return Optional.of(trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback));
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
}
return Optional.empty();
return database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()
? executeUpdate(database, updateCallback, queryContext.getSqlStatementContext(), executionContext, prepareEngine, isNeedImplicitCommitTransaction, statementReplayCallback)
: accumulate(rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback()));
}

@SuppressWarnings("rawtypes")
private int executeUpdate(final ShardingSphereDatabase database, final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final boolean isNeedImplicitCommitTransaction,
final StatementReplayCallback statementReplayCallback) throws SQLException {
return isNeedImplicitCommitTransaction
? executeWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(
database, updateCallback, sqlStatementContext, executionContext, prepareEngine, statementReplayCallback), connection, database.getProtocolType())
: useDriverToExecuteUpdate(database, updateCallback, sqlStatementContext, executionContext, prepareEngine, statementReplayCallback);
}

private <T> T executeWithImplicitCommitTransaction(final ImplicitTransactionCallback<T> callback, final Connection connection, final DatabaseType databaseType) throws SQLException {
T result;
try {
connection.setAutoCommit(false);
result = callback.execute();
connection.commit();
// CHECKSTYLE:OFF
} catch (final Exception ex) {
// CHECKSTYLE:ON
connection.rollback();
throw SQLExceptionTransformEngine.toSQLException(ex, databaseType);
} finally {
connection.setAutoCommit(true);
}
return result;
}

@SuppressWarnings({"rawtypes", "unchecked"})
private int useDriverToExecuteUpdate(final ShardingSphereDatabase database, final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext,
final ExecutionContext executionContext, final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(database, executionContext, prepareEngine);
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
statements.addAll(getStatements(each));
if (JDBCDriverType.PREPARED_STATEMENT.equals(prepareEngine.getType())) {
parameterSets.addAll(getParameterSets(each));
}
}
statementReplayCallback.replay(statements, parameterSets);
JDBCExecutorCallback<Integer> callback = createExecuteUpdateCallback(database, updateCallback, sqlStatementContext, prepareEngine.getType());
return regularExecutor.executeUpdate(executionGroupContext, executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), callback);
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ShardingSphereDatabase database, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) throws SQLException {
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(),
new ExecutionGroupReportContext(connection.getProcessId(), database.getName(), new Grantee("", "")));
}

private JDBCExecutorCallback<Integer> createExecuteUpdateCallback(final ShardingSphereDatabase database,
final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final String jdbcDriverType) {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Integer>(database.getProtocolType(), database.getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {

@Override
protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return JDBCDriverType.STATEMENT.equals(jdbcDriverType) ? updateCallback.executeUpdate(sql, statement) : ((PreparedStatement) statement).executeUpdate();
}

@Override
protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}

private int accumulate(final Collection<ExecuteResult> results) {
int result = 0;
for (ExecuteResult each : results) {
result += ((UpdateResult) each).getUpdateCount();
}
return result;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.impl.driver.jdbc.type.stream.JDBCStreamQueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.update.UpdateResult;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
Expand Down Expand Up @@ -277,18 +276,16 @@ public int executeUpdate() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
Optional<Integer> updatedCount = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate());
if (updatedCount.isPresent()) {
return updatedCount.get();
}
ExecutionContext executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(results);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
final int result = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, isNeedImplicitCommitTransaction, (StatementReplayCallback<PreparedStatement>) this::replay,
executionContext);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
return executeUpdateWithExecutionContext(executionContext);
parameterSets.addAll(executor.getParameterSets());
return result;
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -299,38 +296,6 @@ public int executeUpdate() throws SQLException {
}
}

private int useDriverToExecuteUpdate(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}

private int accumulate(final Collection<ExecuteResult> results) {
int result = 0;
for (ExecuteResult each : results) {
result += ((UpdateResult) each).getUpdateCount();
}
return result;
}

private JDBCExecutorCallback<Integer> createExecuteUpdateCallback() {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Integer>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement, isExceptionThrown) {

@Override
protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return ((PreparedStatement) statement).executeUpdate();
}

@Override
protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}

@Override
public boolean execute() throws SQLException {
try {
Expand Down Expand Up @@ -370,12 +335,6 @@ private boolean executeWithExecutionContext(final ExecutionContext executionCont
: useDriverToExecute(executionContext);
}

private int executeUpdateWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
? executeWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(executionContext), connection, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType())
: useDriverToExecuteUpdate(executionContext);
}

private boolean useDriverToExecute(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
Expand Down
Loading

0 comments on commit 1118bcd

Please sign in to comment.