Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TrafficExecutorCallback #31522

Merged
merged 3 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public ResultSet executeQuery(final ShardingSphereDatabase database, final Query
}

private TrafficExecutorCallback<ResultSet> getTrafficExecuteQueryCallback(final String jdbcDriverType) {
return JDBCDriverType.STATEMENT.equals(jdbcDriverType) ? Statement::executeQuery : ((statement, sql) -> ((PreparedStatement) statement).executeQuery());
return JDBCDriverType.STATEMENT.equals(jdbcDriverType) ? ((sql, statement) -> statement.executeQuery(sql)) : ((sql, statement) -> ((PreparedStatement) statement).executeQuery());
}

private ExecuteQueryCallback getExecuteQueryCallback(final ShardingSphereDatabase database, final QueryContext queryContext, final String jdbcDriverType) {
Expand Down Expand Up @@ -257,19 +257,18 @@ private List<ResultSet> getResultSets() throws SQLException {
* @param database database
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param updateCallback update callback
* @param statementReplayCallback statement replay callback
* @return updated row count
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public int executeUpdate(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Integer> trafficCallback,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final ExecuteUpdateCallback updateCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, updateCallback::executeUpdate);
}
ExecutionContext executionContext = createExecutionContext(database, queryContext);
return database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()
Expand Down Expand Up @@ -360,20 +359,19 @@ private int accumulate(final Collection<ExecuteResult> results) {
* @param database database
* @param queryContext query context
* @param prepareEngine prepare engine
* @param trafficCallback traffic callback
* @param executeCallback execute callback
* @param statementReplayCallback statement replay callback
* @return execute result
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public boolean executeAdvance(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Boolean> trafficCallback,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final ExecuteCallback executeCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
executeType = ExecuteType.TRAFFIC;
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, executeCallback::execute);
}
if (sqlFederationEngine.decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, metaData.getGlobalRuleMetaData())) {
executeType = ExecuteType.FEDERATION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ public int executeUpdate() throws SQLException {
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
final int result = executor.executeUpdate(database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, (StatementReplayCallback<PreparedStatement>) this::replay);
(sql, statement) -> ((PreparedStatement) statement).executeUpdate(), (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
Expand Down Expand Up @@ -289,8 +289,8 @@ public boolean execute() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
final boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
null, (StatementReplayCallback<PreparedStatement>) this::replay);
final boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database), (sql, statement) -> ((PreparedStatement) statement).execute(),
(StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement;
import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback;
import org.apache.shardingsphere.transaction.util.AutoCommitUtils;

import java.sql.Connection;
Expand Down Expand Up @@ -151,7 +150,7 @@ private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriver
@Override
public int executeUpdate(final String sql) throws SQLException {
try {
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL), Statement::executeUpdate);
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL));
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -168,8 +167,7 @@ public int executeUpdate(final String sql, final int autoGeneratedKeys) throws S
returnGeneratedKeys = true;
}
try {
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, autoGeneratedKeys),
(statement, actualSQL) -> statement.executeUpdate(actualSQL, autoGeneratedKeys));
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, autoGeneratedKeys));
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -184,7 +182,7 @@ public int executeUpdate(final String sql, final int autoGeneratedKeys) throws S
public int executeUpdate(final String sql, final int[] columnIndexes) throws SQLException {
returnGeneratedKeys = true;
try {
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, columnIndexes), (statement, actualSQL) -> statement.executeUpdate(actualSQL, columnIndexes));
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, columnIndexes));
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -199,7 +197,7 @@ public int executeUpdate(final String sql, final int[] columnIndexes) throws SQL
public int executeUpdate(final String sql, final String[] columnNames) throws SQLException {
returnGeneratedKeys = true;
try {
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, columnNames), (statement, actualSQL) -> statement.executeUpdate(actualSQL, columnNames));
return executeUpdate(sql, (actualSQL, statement) -> statement.executeUpdate(actualSQL, columnNames));
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -210,15 +208,15 @@ public int executeUpdate(final String sql, final String[] columnNames) throws SQ
}
}

private int executeUpdate(final String sql, final ExecuteUpdateCallback updateCallback, final TrafficExecutorCallback<Integer> trafficCallback) throws SQLException {
private int executeUpdate(final String sql, final ExecuteUpdateCallback updateCallback) throws SQLException {
QueryContext queryContext = createQueryContext(sql);
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
clearStatements();
int result = executor.executeUpdate(database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback, updateCallback,
int result = executor.executeUpdate(database, queryContext, createDriverExecutionPrepareEngine(database), updateCallback,
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
replay(statements);
Expand All @@ -228,7 +226,7 @@ private int executeUpdate(final String sql, final ExecuteUpdateCallback updateCa
@Override
public boolean execute(final String sql) throws SQLException {
try {
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL), Statement::execute);
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL));
// CHECKSTYLE:OFF
} catch (final SQLException ex) {
// CHECKSTYLE:ON
Expand All @@ -243,7 +241,7 @@ public boolean execute(final String sql, final int autoGeneratedKeys) throws SQL
if (RETURN_GENERATED_KEYS == autoGeneratedKeys) {
returnGeneratedKeys = true;
}
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, autoGeneratedKeys), (statement, actualSQL) -> statement.execute(actualSQL, autoGeneratedKeys));
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, autoGeneratedKeys));
// CHECKSTYLE:OFF
} catch (final SQLException ex) {
// CHECKSTYLE:ON
Expand All @@ -256,7 +254,7 @@ public boolean execute(final String sql, final int autoGeneratedKeys) throws SQL
public boolean execute(final String sql, final int[] columnIndexes) throws SQLException {
try {
returnGeneratedKeys = true;
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, columnIndexes), (statement, actualSQL) -> statement.execute(actualSQL, columnIndexes));
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, columnIndexes));
// CHECKSTYLE:OFF
} catch (final SQLException ex) {
// CHECKSTYLE:ON
Expand All @@ -269,7 +267,7 @@ public boolean execute(final String sql, final int[] columnIndexes) throws SQLEx
public boolean execute(final String sql, final String[] columnNames) throws SQLException {
try {
returnGeneratedKeys = true;
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, columnNames), (statement, actualSQL) -> statement.execute(actualSQL, columnNames));
return execute0(sql, (actualSQL, statement) -> statement.execute(actualSQL, columnNames));
// CHECKSTYLE:OFF
} catch (final SQLException ex) {
// CHECKSTYLE:ON
Expand All @@ -278,7 +276,7 @@ public boolean execute(final String sql, final String[] columnNames) throws SQLE
}
}

private boolean execute0(final String sql, final ExecuteCallback executeCallback, final TrafficExecutorCallback<Boolean> trafficCallback) throws SQLException {
private boolean execute0(final String sql, final ExecuteCallback executeCallback) throws SQLException {
currentResultSet = null;
QueryContext queryContext = createQueryContext(sql);
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
Expand All @@ -287,7 +285,7 @@ private boolean execute0(final String sql, final ExecuteCallback executeCallback
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
clearStatements();
boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback,
boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database),
executeCallback, (StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public <T> T execute(final String processId, final String databaseName, final St
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(processId, databaseName, trafficInstanceId, queryContext, prepareEngine);
SQLUnit sqlUnit = executionUnit.getExecutionUnit().getSqlUnit();
cacheStatement(sqlUnit.getParameters(), executionUnit.getStorageResource());
T result = callback.execute(statement, sqlUnit.getSql());
T result = callback.execute(sqlUnit.getSql(), statement);
resultSet = statement.getResultSet();
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ public interface TrafficExecutorCallback<T> {
/**
* Execute.
*
* @param statement statement
* @param sql SQL
* @param statement statement
* @return execution result
* @throws SQLException SQL exception
*/
T execute(Statement statement, String sql) throws SQLException;
T execute(String sql, Statement statement) throws SQLException;
}