From 21745a283f60ce625247eb2d49621be4e9ef08c2 Mon Sep 17 00:00:00 2001 From: terrymanu Date: Sat, 1 Jun 2024 15:45:12 +0800 Subject: [PATCH] Refactor BatchPreparedStatementExecutor --- .../batch/BatchPreparedStatementExecutor.java | 18 ++++----- .../ShardingSpherePreparedStatement.java | 5 ++- .../BatchPreparedStatementExecutorTest.java | 38 ++++--------------- 3 files changed, 18 insertions(+), 43 deletions(-) diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutor.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutor.java index 4acb9ff047e4f..d60344af42bab 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutor.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutor.java @@ -29,7 +29,7 @@ import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback; -import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.user.Grantee; import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute; import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement; @@ -50,7 +50,7 @@ */ public final class BatchPreparedStatementExecutor { - private final ShardingSphereMetaData metaData; + private final ShardingSphereDatabase database; private final JDBCExecutor jdbcExecutor; @@ -61,13 +61,10 @@ public final class BatchPreparedStatementExecutor { private int batchCount; - private final String databaseName; - - public BatchPreparedStatementExecutor(final ShardingSphereMetaData metaData, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) { - this.databaseName = databaseName; - this.metaData = metaData; + public BatchPreparedStatementExecutor(final ShardingSphereDatabase database, final JDBCExecutor jdbcExecutor, final String processId) { + this.database = database; this.jdbcExecutor = jdbcExecutor; - executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", ""))); + executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, database.getName(), new Grantee("", ""))); batchExecutionUnits = new LinkedList<>(); } @@ -135,8 +132,7 @@ private void handleNewBatchExecutionUnits(final Collection n */ public int[] executeBatch(final SQLStatementContext sqlStatementContext) throws SQLException { boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown(); - JDBCExecutorCallback callback = new JDBCExecutorCallback(metaData.getDatabase(databaseName).getProtocolType(), - metaData.getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) { + JDBCExecutorCallback callback = new JDBCExecutorCallback(database.getProtocolType(), database.getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) { @Override protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException { @@ -157,7 +153,7 @@ protected Optional getSaneResult(final SQLStatement sqlStatement, final S } private boolean isNeedAccumulate(final SQLStatementContext sqlStatementContext) { - for (DataNodeRuleAttribute each : metaData.getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) { + for (DataNodeRuleAttribute each : database.getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) { if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) { return true; } diff --git a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java index 0ce53932acce2..a10c9e17b2a10 100644 --- a/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java +++ b/jdbc/src/main/java/org/apache/shardingsphere/driver/jdbc/core/statement/ShardingSpherePreparedStatement.java @@ -184,10 +184,11 @@ private ShardingSpherePreparedStatement(final ShardingSphereConnection connectio parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement); statementOption = returnGeneratedKeys ? new StatementOption(true, columns) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability); executor = new DriverExecutor(connection); + ShardingSphereDatabase database = metaData.getDatabase(databaseName); JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.getDatabaseConnectionManager().getConnectionContext()); - batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaData, jdbcExecutor, databaseName, connection.getProcessId()); + batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(database, jdbcExecutor, connection.getProcessId()); kernelProcessor = new KernelProcessor(); - statementsCacheable = isStatementsCacheable(metaData.getDatabase(databaseName).getRuleMetaData()); + statementsCacheable = isStatementsCacheable(database.getRuleMetaData()); trafficRule = metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class); selectContainsEnhancedTable = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsEnhancedTable(); statementManager = new StatementManager(); diff --git a/jdbc/src/test/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutorTest.java b/jdbc/src/test/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutorTest.java index 25c086062014e..7174aac285632 100644 --- a/jdbc/src/test/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutorTest.java +++ b/jdbc/src/test/java/org/apache/shardingsphere/driver/executor/batch/BatchPreparedStatementExecutorTest.java @@ -18,8 +18,6 @@ package org.apache.shardingsphere.driver.executor.batch; import lombok.SneakyThrows; -import org.apache.shardingsphere.authority.rule.AuthorityRule; -import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection; import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext; import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext; import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine; @@ -32,18 +30,13 @@ import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit; import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor; +import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.metadata.user.Grantee; import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes; import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute; -import org.apache.shardingsphere.mode.manager.ContextManager; -import org.apache.shardingsphere.mode.metadata.MetaDataContexts; +import org.apache.shardingsphere.infra.session.connection.ConnectionContext; import org.apache.shardingsphere.sharding.rule.ShardingRule; -import org.apache.shardingsphere.traffic.rule.TrafficRule; -import org.apache.shardingsphere.traffic.rule.builder.DefaultTrafficRuleConfigurationBuilder; -import org.apache.shardingsphere.transaction.api.TransactionType; -import org.apache.shardingsphere.transaction.config.TransactionRuleConfiguration; -import org.apache.shardingsphere.transaction.rule.TransactionRule; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -61,7 +54,6 @@ import java.util.Collections; import java.util.LinkedList; import java.util.List; -import java.util.Properties; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; @@ -90,33 +82,19 @@ class BatchPreparedStatementExecutorTest { @BeforeEach void setUp() { SQLExecutorExceptionHandler.setExceptionThrown(true); - ShardingSphereConnection connection = new ShardingSphereConnection("foo_db", mockContextManager()); String processId = new UUID(ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong()).toString().replace("-", ""); - executor = new BatchPreparedStatementExecutor(connection.getContextManager().getMetaDataContexts().getMetaData(), - new JDBCExecutor(executorEngine, connection.getDatabaseConnectionManager().getConnectionContext()), "foo_db", processId); + executor = new BatchPreparedStatementExecutor(mockDatabase(), new JDBCExecutor(executorEngine, mock(ConnectionContext.class, RETURNS_DEEP_STUBS)), processId); when(sqlStatementContext.getTablesContext()).thenReturn(mock(TablesContext.class)); } - private ContextManager mockContextManager() { - ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS); - MetaDataContexts metaDataContexts = mockMetaDataContexts(); - when(result.getMetaDataContexts()).thenReturn(metaDataContexts); + private ShardingSphereDatabase mockDatabase() { + ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS); + when(result.getName()).thenReturn("foo_db"); + RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(mockShardingRule())); + when(result.getRuleMetaData()).thenReturn(ruleMetaData); return result; } - private MetaDataContexts mockMetaDataContexts() { - MetaDataContexts result = mock(MetaDataContexts.class, RETURNS_DEEP_STUBS); - RuleMetaData globalRuleMetaData = new RuleMetaData(Arrays.asList(mockTransactionRule(), new TrafficRule(new DefaultTrafficRuleConfigurationBuilder().build()), mock(AuthorityRule.class))); - when(result.getMetaData().getGlobalRuleMetaData()).thenReturn(globalRuleMetaData); - RuleMetaData databaseRuleMetaData = new RuleMetaData(Collections.singleton(mockShardingRule())); - when(result.getMetaData().getDatabase("foo_db").getRuleMetaData()).thenReturn(databaseRuleMetaData); - return result; - } - - private TransactionRule mockTransactionRule() { - return new TransactionRule(new TransactionRuleConfiguration(TransactionType.LOCAL.name(), "", new Properties()), Collections.emptyMap()); - } - private ShardingRule mockShardingRule() { ShardingRule result = mock(ShardingRule.class, RETURNS_DEEP_STUBS); DataNodeRuleAttribute ruleAttribute = mock(DataNodeRuleAttribute.class);