Skip to content

Commit

Permalink
Move execute to DriverExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed May 31, 2024
1 parent f82db4e commit 4f7dd61
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,16 @@
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sqlfederation.engine.SQLFederationEngine;
import org.apache.shardingsphere.sqlfederation.executor.context.SQLFederationContext;
import org.apache.shardingsphere.traffic.executor.TrafficExecutor;
import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.transaction.api.TransactionType;
import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback;
import org.apache.shardingsphere.transaction.rule.TransactionRule;

import java.sql.Connection;
import java.sql.PreparedStatement;
Expand Down Expand Up @@ -260,21 +264,21 @@ private List<ResultSet> getResultSets() throws SQLException {
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param updateCallback update callback
* @param isNeedImplicitCommitTransaction is need implicit commit transaction
* @param statementReplayCallback statement replay callback
* @param executionContext execution context
* @return updated row count
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public int executeUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final TrafficExecutorCallback<Integer> trafficCallback, final ExecuteUpdateCallback updateCallback, final StatementReplayCallback statementReplayCallback,
final boolean isNeedImplicitCommitTransaction, final ExecutionContext executionContext) throws SQLException {
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Integer> trafficCallback,
final ExecuteUpdateCallback updateCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
}
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(
connection, queryContext.getSqlStatementContext().getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
return database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()
? executeUpdate(database, updateCallback, queryContext.getSqlStatementContext(), executionContext, prepareEngine, isNeedImplicitCommitTransaction, statementReplayCallback)
: accumulate(rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback()));
Expand Down Expand Up @@ -362,18 +366,15 @@ private int accumulate(final Collection<ExecuteResult> results) {
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param isNeedImplicitCommitTransaction is need implicit commit transaction
* @param executeCallback execute callback
* @param statementReplayCallback statement replay callback
* @param executionContext execution context
* @return execute result
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public boolean executeAdvance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
final QueryContext queryContext, final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final TrafficExecutorCallback<Boolean> trafficCallback, final boolean isNeedImplicitCommitTransaction,
final ExecuteCallback executeCallback, final StatementReplayCallback statementReplayCallback, final ExecutionContext executionContext) throws SQLException {
public boolean executeAdvance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Boolean> trafficCallback,
final ExecuteCallback executeCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
executeType = ExecuteType.TRAFFIC;
Expand All @@ -385,10 +386,13 @@ public boolean executeAdvance(final ShardingSphereMetaData metaData, final Shard
prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId()));
return null != resultSet;
}
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
if (!database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) {
Collection<ExecuteResult> results = rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(
connection, queryContext.getSqlStatementContext().getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
return executeWithExecutionContext(database, executeCallback, executionContext, prepareEngine, isNeedImplicitCommitTransaction, statementReplayCallback);
}

Expand Down Expand Up @@ -434,6 +438,22 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement1, fina
};
}

private boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final SQLStatement sqlStatement, final boolean multiExecutionUnits) {
if (!connection.getAutoCommit()) {
return false;
}
TransactionType transactionType = connection.getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class).getDefaultType();
boolean isInTransaction = connection.getDatabaseConnectionManager().getConnectionTransaction().isInTransaction();
if (!TransactionType.isDistributedTransaction(transactionType) || isInTransaction) {
return false;
}
return isWriteDMLStatement(sqlStatement) && multiExecutionUnits;
}

private boolean isWriteDMLStatement(final SQLStatement sqlStatement) {
return sqlStatement instanceof DMLStatement && !(sqlStatement instanceof SelectStatement);
}

/**
* Get advanced result set.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,8 @@
import org.apache.shardingsphere.infra.database.core.metadata.database.DialectDatabaseMetaData;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.transaction.api.TransactionType;
import org.apache.shardingsphere.transaction.implicit.ImplicitTransactionCallback;
import org.apache.shardingsphere.transaction.rule.TransactionRule;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.SQLWarning;
Expand All @@ -61,39 +53,6 @@ public abstract class AbstractStatementAdapter extends WrapperAdapter implements

private boolean closeOnCompletion;

protected final boolean isNeedImplicitCommitTransaction(final ShardingSphereConnection connection, final SQLStatement sqlStatement, final boolean multiExecutionUnits) {
if (!connection.getAutoCommit()) {
return false;
}
TransactionType transactionType = connection.getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData().getSingleRule(TransactionRule.class).getDefaultType();
boolean isInTransaction = connection.getDatabaseConnectionManager().getConnectionTransaction().isInTransaction();
if (!TransactionType.isDistributedTransaction(transactionType) || isInTransaction) {
return false;
}
return isWriteDMLStatement(sqlStatement) && multiExecutionUnits;
}

protected final <T> T executeWithImplicitCommitTransaction(final ImplicitTransactionCallback<T> callback, final Connection connection, final DatabaseType databaseType) throws SQLException {
T result;
try {
connection.setAutoCommit(false);
result = callback.execute();
connection.commit();
// CHECKSTYLE:OFF
} catch (final Exception ex) {
// CHECKSTYLE:ON
connection.rollback();
throw SQLExceptionTransformEngine.toSQLException(ex, databaseType);
} finally {
connection.setAutoCommit(true);
}
return result;
}

private boolean isWriteDMLStatement(final SQLStatement sqlStatement) {
return sqlStatement instanceof DMLStatement && !(sqlStatement instanceof SelectStatement);
}

protected final void handleExceptionInTransaction(final ShardingSphereConnection connection, final MetaDataContexts metaDataContexts) {
if (connection.getDatabaseConnectionManager().getConnectionTransaction().isInTransaction()) {
DatabaseType databaseType = metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,9 @@ public int executeUpdate() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
ExecutionContext executionContext = createExecutionContext(queryContext);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
int result = executor.executeUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, (StatementReplayCallback<PreparedStatement>) this::replay,
isNeedImplicitCommitTransaction, executionContext);
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
Expand All @@ -290,11 +288,9 @@ public boolean execute() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
ExecutionContext executionContext = createExecutionContext(queryContext);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
boolean result = executor.executeAdvance(
metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
isNeedImplicitCommitTransaction, null, (StatementReplayCallback<PreparedStatement>) this::replay, executionContext);
null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,10 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.engine.SQLBindEngine;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.infra.exception.kernel.syntax.EmptySQLException;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.impl.driver.jdbc.type.stream.JDBCStreamQueryResult;
Expand All @@ -52,7 +49,6 @@
import org.apache.shardingsphere.infra.merge.MergeEngine;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
Expand Down Expand Up @@ -91,8 +87,6 @@ public final class ShardingSphereStatement extends AbstractStatementAdapter {
@Getter(AccessLevel.PROTECTED)
private final DriverExecutor executor;

private final KernelProcessor kernelProcessor;

@Getter(AccessLevel.PROTECTED)
private final StatementManager statementManager;

Expand Down Expand Up @@ -120,7 +114,6 @@ public ShardingSphereStatement(final ShardingSphereConnection connection, final
statements = new LinkedList<>();
statementOption = new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
kernelProcessor = new KernelProcessor();
statementManager = new StatementManager();
batchStatementExecutor = new BatchStatementExecutor(this);
databaseName = connection.getDatabaseName();
Expand Down Expand Up @@ -224,11 +217,10 @@ private int executeUpdate(final String sql, final ExecuteUpdateCallback updateCa
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
ExecutionContext executionContext = createExecutionContext(queryContext);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
clearStatements();
int result = executor.executeUpdate(
metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback, updateCallback,
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements), isNeedImplicitCommitTransaction, executionContext);
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
replay(statements);
return result;
Expand Down Expand Up @@ -295,10 +287,8 @@ private boolean execute0(final String sql, final ExecuteCallback executeCallback
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
ExecutionContext executionContext = createExecutionContext(queryContext);
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1);
boolean result = executor.executeAdvance(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback,
isNeedImplicitCommitTransaction, executeCallback, (StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements), executionContext);
executeCallback, (StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
return result;
}
Expand Down Expand Up @@ -341,15 +331,6 @@ private QueryContext createQueryContext(final String originSQL) {
return new QueryContext(sqlStatementContext, sql, Collections.emptyList(), hintValueContext);
}

private ExecutionContext createExecutionContext(final QueryContext queryContext) throws SQLException {
clearStatements();
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(),
connection.getDatabaseConnectionManager().getConnectionContext());
}

private void replay(final List<Statement> statements) throws SQLException {
for (Statement each : statements) {
getMethodInvocationRecorder().replay(each);
Expand Down

0 comments on commit 4f7dd61

Please sign in to comment.