diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/DriverExecutor.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/DriverExecutor.java index 8c19be5880d20..04a7a4e0108a1 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/DriverExecutor.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/DriverExecutor.java @@ -24,6 +24,7 @@ import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection; import org.apache.shardingsphere.driver.jdbc.core.resultset.ShardingSphereResultSet; import org.apache.shardingsphere.driver.jdbc.core.resultset.ShardingSphereResultSetUtils; +import org.apache.shardingsphere.driver.jdbc.core.statement.StatementReplayCallback; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext; import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey; @@ -95,10 +96,10 @@ public final class DriverExecutor implements AutoCloseable { private final KernelProcessor kernelProcessor; @Getter - private final Collection statements = new LinkedList<>(); + private final List statements = new ArrayList<>(); @Getter - private final Collection> parameterSets = new LinkedList<>(); + private final List> parameterSets = new ArrayList<>(); public DriverExecutor(final ShardingSphereConnection connection) { this.connection = connection; @@ -122,12 +123,13 @@ public DriverExecutor(final ShardingSphereConnection connection) { * @param prepareEngine prepare engine * @param statement statement * @param columnLabelAndIndexMap column label and index map + * @param statementReplayCallback statement replay callback * @return result set * @throws SQLException SQL exception */ public ResultSet executeAdvanceQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext, - final DriverExecutionPrepareEngine prepareEngine, final Statement statement, - final Map columnLabelAndIndexMap) throws SQLException { + final DriverExecutionPrepareEngine prepareEngine, final Statement statement, + final Map columnLabelAndIndexMap, final StatementReplayCallback statementReplayCallback) throws SQLException { Optional trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext); if (trafficInstanceId.isPresent()) { return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, getTrafficExecutorCallback(prepareEngine)); @@ -136,7 +138,7 @@ public ResultSet executeAdvanceQuery(final ShardingSphereMetaData metaData, fina return sqlFederationEngine.executeQuery( prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId())); } - return doExecuteQuery(metaData, database, queryContext, prepareEngine, statement, columnLabelAndIndexMap); + return doExecuteQuery(metaData, database, queryContext, prepareEngine, statement, columnLabelAndIndexMap, statementReplayCallback); } private TrafficExecutorCallback getTrafficExecutorCallback(final DriverExecutionPrepareEngine prepareEngine) { @@ -153,8 +155,8 @@ private ExecuteQueryCallback getExecuteQueryCallback(final ShardingSphereDatabas private ShardingSphereResultSet doExecuteQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext, final DriverExecutionPrepareEngine prepareEngine, final Statement statement, - final Map columnLabelAndIndexMap) throws SQLException { - List queryResults = executeQuery0(metaData, database, queryContext, prepareEngine); + final Map columnLabelAndIndexMap, final StatementReplayCallback statementReplayCallback) throws SQLException { + List queryResults = executeQuery0(metaData, database, queryContext, prepareEngine, statementReplayCallback); MergedResult mergedResult = mergeQuery(metaData, database, queryResults, queryContext.getSqlStatementContext()); boolean selectContainsEnhancedTable = queryContext.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext) queryContext.getSqlStatementContext()).isContainsEnhancedTable(); @@ -166,7 +168,8 @@ private ShardingSphereResultSet doExecuteQuery(final ShardingSphereMetaData meta } private List executeQuery0(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext, - final DriverExecutionPrepareEngine prepareEngine) throws SQLException { + final DriverExecutionPrepareEngine prepareEngine, + final StatementReplayCallback statementReplayCallback) throws SQLException { ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext); if (hasRawExecutionRule(database)) { return rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), @@ -180,6 +183,7 @@ private List executeQuery0(final ShardingSphereMetaData metaData, f parameterSets.addAll(getParameterSets(each)); } } + statementReplayCallback.replay(statements, parameterSets); return regularExecutor.executeQuery(executionGroupContext, queryContext, getExecuteQueryCallback(database, queryContext, prepareEngine.getType())); } @@ -305,6 +309,14 @@ public Optional getAdvancedResultSet() { } } + /** + * Clear. + */ + public void clear() { + statements.clear(); + parameterSets.clear(); + } + @Override public void close() throws SQLException { sqlFederationEngine.close(); diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 0ff141692a86e..5f5d9d8443fa0 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -223,7 +223,8 @@ public ResultSet executeQuery() throws SQLException { handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); - currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap); + currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap, + (StatementReplayCallback) this::replay); if (currentResultSet instanceof ShardingSphereResultSet) { columnLabelAndIndexMap = ((ShardingSphereResultSet) currentResultSet).getColumnLabelAndIndexMap(); } @@ -231,7 +232,6 @@ public ResultSet executeQuery() throws SQLException { statements.add((PreparedStatement) each); } parameterSets.addAll(executor.getParameterSets()); - replay(); return currentResultSet; // CHECKSTYLE:OFF } catch (final RuntimeException ex) { @@ -257,7 +257,7 @@ private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLExcepti private void resetParameters() throws SQLException { parameterSets.clear(); parameterSets.add(getParameters()); - replaySetParameter(); + replaySetParameter(statements, parameterSets); } private DriverExecutionPrepareEngine createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) { @@ -277,7 +277,6 @@ public int executeUpdate() throws SQLException { QueryContext queryContext = createQueryContext(); handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); - findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); Optional updatedCount = executor.executeAdvanceUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); if (updatedCount.isPresent()) { @@ -343,7 +342,6 @@ public boolean execute() throws SQLException { QueryContext queryContext = createQueryContext(); handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); - findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); Optional advancedResult = executor.executeAdvance( metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute()); if (advancedResult.isPresent()) { @@ -456,8 +454,10 @@ private ExecutionContext createExecutionContext(final QueryContext queryContext) RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData(); ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName); SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext()); - return kernelProcessor.generateExecutionContext( + ExecutionContext result = kernelProcessor.generateExecutionContext( queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext()); + findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); + return result; } private ExecutionContext createExecutionContext(final QueryContext queryContext, final String trafficInstanceId) { @@ -492,17 +492,17 @@ private void cacheStatements(final Collection> parameterSets.add(eachInput.getExecutionUnit().getSqlUnit().getParameters()); }); } - replay(); + replay(statements, parameterSets); } - private void replay() throws SQLException { - replaySetParameter(); + private void replay(final List statements, final List> parameterSets) throws SQLException { + replaySetParameter(statements, parameterSets); for (Statement each : statements) { getMethodInvocationRecorder().replay(each); } } - private void replaySetParameter() throws SQLException { + private void replaySetParameter(final List statements, final List> parameterSets) throws SQLException { for (int i = 0; i < statements.size(); i++) { replaySetParameter(statements.get(i), parameterSets.get(i)); } @@ -513,6 +513,7 @@ private void clearPrevious() { statements.clear(); parameterSets.clear(); generatedValues.clear(); + executor.clear(); } private Optional findGeneratedKey() { @@ -549,9 +550,6 @@ public void addBatch() { Optional trafficInstanceId = connection.getTrafficInstanceId(trafficRule, queryContext); executionContext = trafficInstanceId .map(optional -> createExecutionContext(queryContext, optional)).orElseGet(() -> createExecutionContext(queryContext)); - if (!trafficInstanceId.isPresent()) { - findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues())); - } batchPreparedStatementExecutor.addBatchForExecutionUnits(executionContext.getExecutionUnits()); } finally { currentResultSet = null; diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java index 5565ddee90caf..bda16fec1d48f 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java @@ -151,9 +151,9 @@ public ResultSet executeQuery(final String sql) throws SQLException { connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName); ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); sqlStatementContext = queryContext.getSqlStatementContext(); - currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, null); + currentResultSet = executor.executeAdvanceQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, null, + (StatementReplayCallback) (statements, parameterSets) -> replay(statements)); statements.addAll(executor.getStatements()); - replay(); return currentResultSet; // CHECKSTYLE:OFF } catch (final RuntimeException ex) { @@ -441,7 +441,7 @@ private void cacheStatements(final Collection> for (ExecutionGroup each : executionGroups) { statements.addAll(each.getInputs().stream().map(JDBCExecutionUnit::getStorageResource).collect(Collectors.toList())); } - replay(); + replay(statements); } private JDBCExecutorCallback createExecuteCallback(final ExecuteCallback executeCallback, final SQLStatement sqlStatement) { @@ -461,7 +461,7 @@ protected Optional getSaneResult(final SQLStatement sqlStatement1, fina }; } - private void replay() throws SQLException { + private void replay(final List statements) throws SQLException { for (Statement each : statements) { getMethodInvocationRecorder().replay(each); } diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/StatementReplayCallback.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/StatementReplayCallback.java new file mode 100644 index 0000000000000..0ad5b4456c05b --- /dev/null +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/StatementReplayCallback.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.shardingsphere.driver.jdbc.core.statement; + +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; + +/** + * Statement replay callback. + * + * @param type of statement + */ +public interface StatementReplayCallback { + + /** + * Replay statement. + * + * @param statements statements + * @param parameterSets parameter sets + * @throws SQLException SQL exception + */ + void replay(List statements, List> parameterSets) throws SQLException; +}