diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 227cfc2fbe959..b7dc4c334e54e 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -41,7 +41,6 @@ import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor; import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType; -import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions; import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine; import org.apache.shardingsphere.infra.exception.kernel.syntax.EmptySQLException; import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine; @@ -227,14 +226,15 @@ public ResultSet executeQuery() throws SQLException { } clearPrevious(); QueryContext queryContext = createQueryContext(); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); - currentResultSet = executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeQuery()); + currentResultSet = executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName, + trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).executeQuery()); return currentResultSet; } - if (decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData())) { + if (decide(queryContext, database, metaDataContexts.getMetaData().getGlobalRuleMetaData())) { currentResultSet = executeFederationQuery(queryContext); return currentResultSet; } @@ -266,19 +266,19 @@ private boolean decide(final QueryContext queryContext, final ShardingSphereData return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData); } - private void handleAutoCommit(final QueryContext queryContext) throws SQLException { - if (AutoCommitUtils.needOpenTransaction(queryContext.getSqlStatementContext().getSqlStatement())) { + private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLException { + if (AutoCommitUtils.needOpenTransaction(sqlStatement)) { connection.handleAutoCommit(); } } private JDBCExecutionUnit createTrafficExecutionUnit(final String trafficInstanceId, final QueryContext queryContext) throws SQLException { - DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(database); ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); ExecutionGroupContext context = prepareEngine.prepare(new RouteContext(), Collections.singleton(executionUnit), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", ""))); - ShardingSpherePreconditions.checkState(!context.getInputGroups().isEmpty() && !context.getInputGroups().iterator().next().getInputs().isEmpty(), EmptyTrafficExecutionUnitException::new); - return context.getInputGroups().iterator().next().getInputs().iterator().next(); + return context.getInputGroups().stream().flatMap(each -> each.getInputs().stream()).findFirst().orElseThrow(EmptyTrafficExecutionUnitException::new); } private Optional getInstanceIdAndSet(final QueryContext queryContext) { @@ -319,17 +319,17 @@ private List executeQuery0(final ExecutionContext executionContext) } private ResultSet executeFederationQuery(final QueryContext queryContext) { - PreparedStatementExecuteQueryCallback callback = new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(), - metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement, SQLExecutorExceptionHandler.isExceptionThrown()); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + PreparedStatementExecuteQueryCallback callback = new PreparedStatementExecuteQueryCallback(database.getProtocolType(), + database.getResourceMetaData(), sqlStatement, SQLExecutorExceptionHandler.isExceptionThrown()); SQLFederationContext context = new SQLFederationContext(false, queryContext, metaDataContexts.getMetaData(), connection.getProcessId()); - return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(), callback, context); + return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(database), callback, context); } - private DriverExecutionPrepareEngine createDriverExecutionPrepareEngine() { + private DriverExecutionPrepareEngine createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) { int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY); - return new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager, - statementOption, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules(), - metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData().getStorageUnits()); + return new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager, statementOption, + database.getRuleMetaData().getRules(), database.getResourceMetaData().getStorageUnits()); } @Override @@ -341,11 +341,12 @@ public int executeUpdate() throws SQLException { } clearPrevious(); QueryContext queryContext = createQueryContext(); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); - return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + return executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName, + trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).executeUpdate()); } executionContext = createExecutionContext(queryContext); if (hasRawExecutionRule()) { @@ -405,11 +406,12 @@ public boolean execute() throws SQLException { } clearPrevious(); QueryContext queryContext = createQueryContext(); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); - boolean result = executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).execute()); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + boolean result = executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName, + trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute()); currentResultSet = executor.getTrafficExecutor().getResultSet(); return result; } @@ -484,7 +486,8 @@ protected Optional getSaneResult(final SQLStatement sqlStatement, final } private ExecutionGroupContext createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { - DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(database); return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", ""))); } diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java index abe0ed5e0a559..d2b2e12dfe2c0 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSphereStatement.java @@ -46,8 +46,6 @@ import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext; import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext; import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext; -import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit; -import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit; import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode; import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit; @@ -70,7 +68,6 @@ import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.metadata.user.Grantee; -import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute; import org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute; import org.apache.shardingsphere.infra.session.query.QueryContext; @@ -80,7 +77,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement; import org.apache.shardingsphere.sqlfederation.executor.context.SQLFederationContext; import org.apache.shardingsphere.traffic.engine.TrafficEngine; -import org.apache.shardingsphere.traffic.exception.EmptyTrafficExecutionUnitException; import org.apache.shardingsphere.traffic.executor.TrafficExecutorCallback; import org.apache.shardingsphere.traffic.rule.TrafficRule; import org.apache.shardingsphere.transaction.util.AutoCommitUtils; @@ -158,15 +154,17 @@ public ResultSet executeQuery(final String sql) throws SQLException { ShardingSpherePreconditions.checkNotEmpty(sql, () -> new EmptySQLException().toSQLException()); try { QueryContext queryContext = createQueryContext(sql); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName()); connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - currentResultSet = executor.getTrafficExecutor().execute(createTrafficExecutionUnit(trafficInstanceId, queryContext), Statement::executeQuery); + currentResultSet = executor.getTrafficExecutor().execute( + connection.getProcessId(), databaseName, trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), Statement::executeQuery); return currentResultSet; } - if (decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData())) { + if (decide(queryContext, database, metaDataContexts.getMetaData().getGlobalRuleMetaData())) { currentResultSet = executeFederationQuery(queryContext); return currentResultSet; } @@ -226,18 +224,17 @@ private List executeQuery0(final ExecutionContext executionContext) } private ResultSet executeFederationQuery(final QueryContext queryContext) { - StatementExecuteQueryCallback callback = new StatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(), - metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), queryContext.getSqlStatementContext().getSqlStatement(), - SQLExecutorExceptionHandler.isExceptionThrown()); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + StatementExecuteQueryCallback callback = new StatementExecuteQueryCallback(database.getProtocolType(), + database.getResourceMetaData(), queryContext.getSqlStatementContext().getSqlStatement(), SQLExecutorExceptionHandler.isExceptionThrown()); SQLFederationContext context = new SQLFederationContext(false, queryContext, metaDataContexts.getMetaData(), connection.getProcessId()); - return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(), callback, context); + return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(database), callback, context); } - private DriverExecutionPrepareEngine createDriverExecutionPrepareEngine() { + private DriverExecutionPrepareEngine createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) { int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY); return new DriverExecutionPrepareEngine<>(JDBCDriverType.STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager, statementOption, - metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules(), - metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData().getStorageUnits()); + database.getRuleMetaData().getRules(), database.getResourceMetaData().getStorageUnits()); } @Override @@ -311,13 +308,14 @@ private int executeUpdate(final ExecuteUpdateCallback updateCallback, final SQLS private int executeUpdate0(final String sql, final ExecuteUpdateCallback updateCallback, final TrafficExecutorCallback trafficCallback) throws SQLException { QueryContext queryContext = createQueryContext(sql); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName()); connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); - return executor.getTrafficExecutor().execute(executionUnit, trafficCallback); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + return executor.getTrafficExecutor().execute( + connection.getProcessId(), databaseName, trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback); } executionContext = createExecutionContext(queryContext); if (!metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) { @@ -416,13 +414,14 @@ public boolean execute(final String sql, final String[] columnNames) throws SQLE private boolean execute0(final String sql, final ExecuteCallback executeCallback, final TrafficExecutorCallback trafficCallback) throws SQLException { QueryContext queryContext = createQueryContext(sql); - handleAutoCommit(queryContext); + handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement()); databaseName = queryContext.getDatabaseNameFromSQLStatement().orElse(connection.getDatabaseName()); connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName); String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null); if (null != trafficInstanceId) { - JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext); - boolean result = executor.getTrafficExecutor().execute(executionUnit, trafficCallback); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + boolean result = executor.getTrafficExecutor().execute( + connection.getProcessId(), databaseName, trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback); currentResultSet = executor.getTrafficExecutor().getResultSet(); return result; } @@ -440,20 +439,12 @@ private boolean execute0(final String sql, final ExecuteCallback executeCallback return executeWithExecutionContext(executeCallback, executionContext); } - private void handleAutoCommit(final QueryContext queryContext) throws SQLException { - if (AutoCommitUtils.needOpenTransaction(queryContext.getSqlStatementContext().getSqlStatement())) { + private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLException { + if (AutoCommitUtils.needOpenTransaction(sqlStatement)) { connection.handleAutoCommit(); } } - private JDBCExecutionUnit createTrafficExecutionUnit(final String trafficInstanceId, final QueryContext queryContext) throws SQLException { - DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); - ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); - ExecutionGroupContext context = - prepareEngine.prepare(new RouteContext(), Collections.singletonList(executionUnit), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", ""))); - return context.getInputGroups().stream().flatMap(each -> each.getInputs().stream()).findFirst().orElseThrow(EmptyTrafficExecutionUnitException::new); - } - private void clearStatements() throws SQLException { for (Statement each : statements) { each.close(); @@ -495,7 +486,8 @@ private ExecutionContext createExecutionContext(final QueryContext queryContext) } private ExecutionGroupContext createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException { - DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(); + ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName); + DriverExecutionPrepareEngine prepareEngine = createDriverExecutionPrepareEngine(database); return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", ""))); } diff --git a/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/executor/TrafficExecutor.java b/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/executor/TrafficExecutor.java index b349a280cb14f..fddefd4740f5d 100644 --- a/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/executor/TrafficExecutor.java +++ b/kernel/traffic/core/src/main/java/org/apache/shardingsphere/traffic/executor/TrafficExecutor.java @@ -18,13 +18,23 @@ package org.apache.shardingsphere.traffic.executor; import lombok.Getter; +import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext; +import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupReportContext; +import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit; import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit; +import org.apache.shardingsphere.infra.executor.sql.prepare.driver.DriverExecutionPrepareEngine; +import org.apache.shardingsphere.infra.metadata.user.Grantee; +import org.apache.shardingsphere.infra.route.context.RouteContext; +import org.apache.shardingsphere.infra.session.query.QueryContext; +import org.apache.shardingsphere.traffic.exception.EmptyTrafficExecutionUnitException; +import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Collections; import java.util.List; /** @@ -54,6 +64,37 @@ public T execute(final JDBCExecutionUnit executionUnit, final TrafficExecuto return result; } + /** + * Execute. + * + * @param processId process ID + * @param databaseName database name + * @param trafficInstanceId traffic instance ID + * @param queryContext query context + * @param prepareEngine prepare engine + * @param callback callback + * @param return type + * @return execute result + * @throws SQLException SQL exception + */ + public T execute(final String processId, final String databaseName, final String trafficInstanceId, final QueryContext queryContext, + final DriverExecutionPrepareEngine prepareEngine, final TrafficExecutorCallback callback) throws SQLException { + JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(processId, databaseName, trafficInstanceId, queryContext, prepareEngine); + SQLUnit sqlUnit = executionUnit.getExecutionUnit().getSqlUnit(); + cacheStatement(sqlUnit.getParameters(), executionUnit.getStorageResource()); + T result = callback.execute(statement, sqlUnit.getSql()); + resultSet = statement.getResultSet(); + return result; + } + + private JDBCExecutionUnit createTrafficExecutionUnit(final String processId, final String databaseName, final String trafficInstanceId, final QueryContext queryContext, + final DriverExecutionPrepareEngine prepareEngine) throws SQLException { + ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters())); + ExecutionGroupContext context = + prepareEngine.prepare(new RouteContext(), Collections.singleton(executionUnit), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", ""))); + return context.getInputGroups().stream().flatMap(each -> each.getInputs().stream()).findFirst().orElseThrow(EmptyTrafficExecutionUnitException::new); + } + private void cacheStatement(final List params, final Statement statement) throws SQLException { this.statement = statement; setParameters(statement, params); diff --git a/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/executor/TrafficExecutorTest.java b/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/executor/TrafficExecutorTest.java deleted file mode 100644 index 7074ad28b8588..0000000000000 --- a/kernel/traffic/core/src/test/java/org/apache/shardingsphere/traffic/executor/TrafficExecutorTest.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.shardingsphere.traffic.executor; - -import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit; -import org.apache.shardingsphere.infra.executor.sql.context.SQLUnit; -import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit; -import org.junit.jupiter.api.Test; - -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Collections; - -import static org.mockito.Mockito.RETURNS_DEEP_STUBS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -class TrafficExecutorTest { - - @Test - void assertClose() throws SQLException { - Statement statement = mock(Statement.class, RETURNS_DEEP_STUBS); - try (TrafficExecutor trafficExecutor = new TrafficExecutor()) { - JDBCExecutionUnit executionUnit = mock(JDBCExecutionUnit.class); - when(executionUnit.getExecutionUnit()).thenReturn(new ExecutionUnit("oltp_proxy_instance_id", new SQLUnit("SELECT 1", Collections.emptyList()))); - when(executionUnit.getStorageResource()).thenReturn(statement); - trafficExecutor.execute(executionUnit, Statement::executeQuery); - } - verify(statement).close(); - verify(statement, times(0)).getConnection(); - } -}