Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move executeUpdate to DriverExecutor #31452

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.Getter;
import org.apache.shardingsphere.driver.executor.callback.ExecuteQueryCallback;
import org.apache.shardingsphere.driver.executor.callback.ExecuteUpdateCallback;
import org.apache.shardingsphere.driver.executor.callback.impl.PreparedStatementExecuteQueryCallback;
import org.apache.shardingsphere.driver.executor.callback.impl.StatementExecuteQueryCallback;
import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
Expand All @@ -29,20 +30,26 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.RawSQLExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.raw.callback.RawSQLExecutorCallback;
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.update.UpdateResult;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.raw.RawExecutionPrepareEngine;
Expand All @@ -55,11 +62,13 @@
import org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sqlfederation.engine.SQLFederationEngine;
import org.apache.shardingsphere.sqlfederation.executor.context.SQLFederationContext;
import org.apache.shardingsphere.traffic.executor.TrafficExecutor;
import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback;

import java.sql.Connection;
import java.sql.PreparedStatement;
Expand Down Expand Up @@ -249,17 +258,99 @@ private List<ResultSet> getResultSets() throws SQLException {
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param updateCallback update callback
* @param isNeedImplicitCommitTransaction is need implicit commit transaction
* @param statementReplayCallback statement replay callback
* @param executionContext execution context
* @return updated row count
* @throws SQLException SQL exception
*/
public Optional<Integer> executeAdvanceUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final TrafficExecutorCallback<Integer> trafficCallback) throws SQLException {
@SuppressWarnings("rawtypes")
public int executeAdvanceUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Integer> trafficCallback,
final ExecuteUpdateCallback updateCallback, final boolean isNeedImplicitCommitTransaction,
final StatementReplayCallback statementReplayCallback, final ExecutionContext executionContext) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return Optional.of(trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback));
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
}
return Optional.empty();
return database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()
? executeUpdate(database, updateCallback, queryContext.getSqlStatementContext(), executionContext, prepareEngine, isNeedImplicitCommitTransaction, statementReplayCallback)
: accumulate(rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback()));
}

@SuppressWarnings("rawtypes")
private int executeUpdate(final ShardingSphereDatabase database, final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final boolean isNeedImplicitCommitTransaction,
final StatementReplayCallback statementReplayCallback) throws SQLException {
return isNeedImplicitCommitTransaction
? executeWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(
database, updateCallback, sqlStatementContext, executionContext, prepareEngine, statementReplayCallback), connection, database.getProtocolType())
: useDriverToExecuteUpdate(database, updateCallback, sqlStatementContext, executionContext, prepareEngine, statementReplayCallback);
}

private <T> T executeWithImplicitCommitTransaction(final ImplicitTransactionCallback<T> callback, final Connection connection, final DatabaseType databaseType) throws SQLException {
T result;
try {
connection.setAutoCommit(false);
result = callback.execute();
connection.commit();
// CHECKSTYLE:OFF
} catch (final Exception ex) {
// CHECKSTYLE:ON
connection.rollback();
throw SQLExceptionTransformEngine.toSQLException(ex, databaseType);
} finally {
connection.setAutoCommit(true);
}
return result;
}

@SuppressWarnings({"rawtypes", "unchecked"})
private int useDriverToExecuteUpdate(final ShardingSphereDatabase database, final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext,
final ExecutionContext executionContext, final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(database, executionContext, prepareEngine);
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
statements.addAll(getStatements(each));
if (JDBCDriverType.PREPARED_STATEMENT.equals(prepareEngine.getType())) {
parameterSets.addAll(getParameterSets(each));
}
}
statementReplayCallback.replay(statements, parameterSets);
JDBCExecutorCallback<Integer> callback = createExecuteUpdateCallback(database, updateCallback, sqlStatementContext, prepareEngine.getType());
return regularExecutor.executeUpdate(executionGroupContext, executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), callback);
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ShardingSphereDatabase database, final ExecutionContext executionContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) throws SQLException {
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(),
new ExecutionGroupReportContext(connection.getProcessId(), database.getName(), new Grantee("", "")));
}

private JDBCExecutorCallback<Integer> createExecuteUpdateCallback(final ShardingSphereDatabase database,
final ExecuteUpdateCallback updateCallback, final SQLStatementContext sqlStatementContext, final String jdbcDriverType) {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Integer>(database.getProtocolType(), database.getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {

@Override
protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return JDBCDriverType.STATEMENT.equals(jdbcDriverType) ? updateCallback.executeUpdate(sql, statement) : ((PreparedStatement) statement).executeUpdate();
}

@Override
protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}

private int accumulate(final Collection<ExecuteResult> results) {
int result = 0;
for (ExecuteResult each : results) {
result += ((UpdateResult) each).getUpdateCount();
}
return result;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.apache.shardingsphere.infra.executor.sql.execute.result.ExecuteResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.impl.driver.jdbc.type.stream.JDBCStreamQueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.update.UpdateResult;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.JDBCDriverType;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
Expand Down Expand Up @@ -277,18 +276,16 @@ public int executeUpdate() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
Optional<Integer> updatedCount = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate());
if (updatedCount.isPresent()) {
return updatedCount.get();
}
ExecutionContext executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
return accumulate(results);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
final int result = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, isNeedImplicitCommitTransaction, (StatementReplayCallback<PreparedStatement>) this::replay,
executionContext);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
return executeUpdateWithExecutionContext(executionContext);
parameterSets.addAll(executor.getParameterSets());
return result;
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -299,38 +296,6 @@ public int executeUpdate() throws SQLException {
}
}

private int useDriverToExecuteUpdate(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
return executor.getRegularExecutor().executeUpdate(executionGroupContext,
executionContext.getQueryContext(), executionContext.getRouteContext().getRouteUnits(), createExecuteUpdateCallback());
}

private int accumulate(final Collection<ExecuteResult> results) {
int result = 0;
for (ExecuteResult each : results) {
result += ((UpdateResult) each).getUpdateCount();
}
return result;
}

private JDBCExecutorCallback<Integer> createExecuteUpdateCallback() {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
return new JDBCExecutorCallback<Integer>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement, isExceptionThrown) {

@Override
protected Integer executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
return ((PreparedStatement) statement).executeUpdate();
}

@Override
protected Optional<Integer> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
return Optional.empty();
}
};
}

@Override
public boolean execute() throws SQLException {
try {
Expand Down Expand Up @@ -370,12 +335,6 @@ private boolean executeWithExecutionContext(final ExecutionContext executionCont
: useDriverToExecute(executionContext);
}

private int executeUpdateWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
? executeWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(executionContext), connection, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType())
: useDriverToExecuteUpdate(executionContext);
}

private boolean useDriverToExecute(final ExecutionContext executionContext) throws SQLException {
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = createExecutionGroupContext(executionContext);
cacheStatements(executionGroupContext.getInputGroups());
Expand Down
Loading