Skip to content

Commit

Permalink
Refactor DriverExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed May 29, 2024
1 parent acdda07 commit 52c0a2d
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
import org.apache.shardingsphere.driver.jdbc.core.resultset.ShardingSphereResultSet;
import org.apache.shardingsphere.driver.jdbc.core.resultset.ShardingSphereResultSetUtils;
import org.apache.shardingsphere.driver.jdbc.core.statement.StatementReplayCallback;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
Expand Down Expand Up @@ -95,10 +96,10 @@ public final class DriverExecutor implements AutoCloseable {
private final KernelProcessor kernelProcessor;

@Getter
private final Collection<Statement> statements = new LinkedList<>();
private final List<Statement> statements = new ArrayList<>();

@Getter
private final Collection<List<Object>> parameterSets = new LinkedList<>();
private final List<List<Object>> parameterSets = new ArrayList<>();

public DriverExecutor(final ShardingSphereConnection connection) {
this.connection = connection;
Expand All @@ -122,12 +123,13 @@ public DriverExecutor(final ShardingSphereConnection connection) {
* @param prepareEngine prepare engine
* @param statement statement
* @param columnLabelAndIndexMap column label and index map
* @param statementReplayCallback statement replay callback
* @return result set
* @throws SQLException SQL exception
*/
public ResultSet executeAdvanceQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final Statement statement,
final Map<String, Integer> columnLabelAndIndexMap) throws SQLException {
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final Statement statement,
final Map<String, Integer> columnLabelAndIndexMap, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, getTrafficExecutorCallback(prepareEngine));
Expand All @@ -136,7 +138,7 @@ public ResultSet executeAdvanceQuery(final ShardingSphereMetaData metaData, fina
return sqlFederationEngine.executeQuery(
prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId()));
}
return doExecuteQuery(metaData, database, queryContext, prepareEngine, statement, columnLabelAndIndexMap);
return doExecuteQuery(metaData, database, queryContext, prepareEngine, statement, columnLabelAndIndexMap, statementReplayCallback);
}

private TrafficExecutorCallback<ResultSet> getTrafficExecutorCallback(final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) {
Expand All @@ -153,8 +155,8 @@ private ExecuteQueryCallback getExecuteQueryCallback(final ShardingSphereDatabas

private ShardingSphereResultSet doExecuteQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final Statement statement,
final Map<String, Integer> columnLabelAndIndexMap) throws SQLException {
List<QueryResult> queryResults = executeQuery0(metaData, database, queryContext, prepareEngine);
final Map<String, Integer> columnLabelAndIndexMap, final StatementReplayCallback statementReplayCallback) throws SQLException {
List<QueryResult> queryResults = executeQuery0(metaData, database, queryContext, prepareEngine, statementReplayCallback);
MergedResult mergedResult = mergeQuery(metaData, database, queryResults, queryContext.getSqlStatementContext());
boolean selectContainsEnhancedTable = queryContext.getSqlStatementContext() instanceof SelectStatementContext
&& ((SelectStatementContext) queryContext.getSqlStatementContext()).isContainsEnhancedTable();
Expand All @@ -166,7 +168,8 @@ private ShardingSphereResultSet doExecuteQuery(final ShardingSphereMetaData meta
}

private List<QueryResult> executeQuery0(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine) throws SQLException {
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
if (hasRawExecutionRule(database)) {
return rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext),
Expand All @@ -180,6 +183,7 @@ private List<QueryResult> executeQuery0(final ShardingSphereMetaData metaData, f
parameterSets.addAll(getParameterSets(each));
}
}
statementReplayCallback.replay(statements, parameterSets);
return regularExecutor.executeQuery(executionGroupContext, queryContext, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()));
}

Expand Down Expand Up @@ -305,6 +309,14 @@ public Optional<ResultSet> getAdvancedResultSet() {
}
}

/**
* Clear.
*/
public void clear() {
statements.clear();
parameterSets.clear();
}

@Override
public void close() throws SQLException {
sqlFederationEngine.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,15 @@ public ResultSet executeQuery() throws SQLException {
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap);
currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap,
(StatementReplayCallback<PreparedStatement>) this::replay);
if (currentResultSet instanceof ShardingSphereResultSet) {
columnLabelAndIndexMap = ((ShardingSphereResultSet) currentResultSet).getColumnLabelAndIndexMap();
}
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
}
parameterSets.addAll(executor.getParameterSets());
replay();
return currentResultSet;
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
Expand All @@ -257,7 +257,7 @@ private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLExcepti
private void resetParameters() throws SQLException {
parameterSets.clear();
parameterSets.add(getParameters());
replaySetParameter();
replaySetParameter(statements, parameterSets);
}

private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) {
Expand All @@ -277,7 +277,6 @@ public int executeUpdate() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
Optional<Integer> updatedCount = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate());
if (updatedCount.isPresent()) {
Expand Down Expand Up @@ -343,7 +342,6 @@ public boolean execute() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
Optional<Boolean> advancedResult = executor.executeAdvance(
metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute());
if (advancedResult.isPresent()) {
Expand Down Expand Up @@ -456,8 +454,10 @@ private ExecutionContext createExecutionContext(final QueryContext queryContext)
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
return kernelProcessor.generateExecutionContext(
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
}

private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) {
Expand Down Expand Up @@ -492,17 +492,17 @@ private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>>
parameterSets.add(eachInput.getExecutionUnit().getSqlUnit().getParameters());
});
}
replay();
replay(statements, parameterSets);
}

private void replay() throws SQLException {
replaySetParameter();
private void replay(final List<PreparedStatement> statements, final List<List<Object>> parameterSets) throws SQLException {
replaySetParameter(statements, parameterSets);
for (Statement each : statements) {
getMethodInvocationRecorder().replay(each);
}
}

private void replaySetParameter() throws SQLException {
private void replaySetParameter(final List<PreparedStatement> statements, final List<List<Object>> parameterSets) throws SQLException {
for (int i = 0; i < statements.size(); i++) {
replaySetParameter(statements.get(i), parameterSets.get(i));
}
Expand All @@ -513,6 +513,7 @@ private void clearPrevious() {
statements.clear();
parameterSets.clear();
generatedValues.clear();
executor.clear();
}

private Optional<GeneratedKeyContext> findGeneratedKey() {
Expand Down Expand Up @@ -549,9 +550,6 @@ public void addBatch() {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(trafficRule, queryContext);
executionContext = trafficInstanceId
.map(optional -> createExecutionContext(queryContext, optional)).orElseGet(() -> createExecutionContext(queryContext));
if (!trafficInstanceId.isPresent()) {
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
}
batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits());
} finally {
currentResultSet = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ public ResultSet executeQuery(final String sql) throws SQLException {
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, null);
currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, null,
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
replay();
return currentResultSet;
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
Expand Down Expand Up @@ -441,7 +441,7 @@ private void cacheStatements(final Collection<ExecutionGroup<JDBCExecutionUnit>>
for (ExecutionGroup<JDBCExecutionUnit> each : executionGroups) {
statements.addAll(each.getInputs().stream().map(JDBCExecutionUnit::getStorageResource).collect(Collectors.toList()));
}
replay();
replay(statements);
}

private JDBCExecutorCallback<Boolean> createExecuteCallback(final ExecuteCallback executeCallback, final SQLStatement sqlStatement) {
Expand All @@ -461,7 +461,7 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement1, fina
};
}

private void replay() throws SQLException {
private void replay(final List<Statement> statements) throws SQLException {
for (Statement each : statements) {
getMethodInvocationRecorder().replay(each);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.driver.jdbc.core.statement;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;

/**
* Statement replay callback.
*
* @param <T> type of statement
*/
public interface StatementReplayCallback<T extends Statement> {

/**
* Replay statement.
*
* @param statements statements
* @param parameterSets parameter sets
* @throws SQLException SQL exception
*/
void replay(List<T> statements, List<List<Object>> parameterSets) throws SQLException;
}

0 comments on commit 52c0a2d

Please sign in to comment.