Skip to content

Commit

Permalink
Refactor BatchPreparedStatementExecutor
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Jun 1, 2024
1 parent 51d173e commit 21745a2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
Expand All @@ -50,7 +50,7 @@
*/
public final class BatchPreparedStatementExecutor {

private final ShardingSphereMetaData metaData;
private final ShardingSphereDatabase database;

private final JDBCExecutor jdbcExecutor;

Expand All @@ -61,13 +61,10 @@ public final class BatchPreparedStatementExecutor {

private int batchCount;

private final String databaseName;

public BatchPreparedStatementExecutor(final ShardingSphereMetaData metaData, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) {
this.databaseName = databaseName;
this.metaData = metaData;
public BatchPreparedStatementExecutor(final ShardingSphereDatabase database, final JDBCExecutor jdbcExecutor, final String processId) {
this.database = database;
this.jdbcExecutor = jdbcExecutor;
executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", "")));
executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, database.getName(), new Grantee("", "")));
batchExecutionUnits = new LinkedList<>();
}

Expand Down Expand Up @@ -135,8 +132,7 @@ private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> n
*/
public int[] executeBatch(final SQLStatementContext sqlStatementContext) throws SQLException {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(metaData.getDatabase(databaseName).getProtocolType(),
metaData.getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {
JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(database.getProtocolType(), database.getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {

@Override
protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
Expand All @@ -157,7 +153,7 @@ protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final S
}

private boolean isNeedAccumulate(final SQLStatementContext sqlStatementContext) {
for (DataNodeRuleAttribute each : metaData.getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
for (DataNodeRuleAttribute each : database.getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ private ShardingSpherePreparedStatement(final ShardingSphereConnection connectio
parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement);
statementOption = returnGeneratedKeys ? new StatementOption(true, columns) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.getDatabaseConnectionManager().getConnectionContext());
batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaData, jdbcExecutor, databaseName, connection.getProcessId());
batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(database, jdbcExecutor, connection.getProcessId());
kernelProcessor = new KernelProcessor();
statementsCacheable = isStatementsCacheable(metaData.getDatabase(databaseName).getRuleMetaData());
statementsCacheable = isStatementsCacheable(database.getRuleMetaData());
trafficRule = metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
selectContainsEnhancedTable = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsEnhancedTable();
statementManager = new StatementManager();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
package org.apache.shardingsphere.driver.executor.batch;

import lombok.SneakyThrows;
import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.driver.jdbc.core.connection.ShardingSphereConnection;
import org.apache.shardingsphere.infra.binder.context.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.executor.kernel.ExecutorEngine;
Expand All @@ -32,18 +30,13 @@
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes;
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
import org.apache.shardingsphere.mode.manager.ContextManager;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.traffic.rule.TrafficRule;
import org.apache.shardingsphere.traffic.rule.builder.DefaultTrafficRuleConfigurationBuilder;
import org.apache.shardingsphere.transaction.api.TransactionType;
import org.apache.shardingsphere.transaction.config.TransactionRuleConfiguration;
import org.apache.shardingsphere.transaction.rule.TransactionRule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -61,7 +54,6 @@
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;

Expand Down Expand Up @@ -90,33 +82,19 @@ class BatchPreparedStatementExecutorTest {
@BeforeEach
void setUp() {
SQLExecutorExceptionHandler.setExceptionThrown(true);
ShardingSphereConnection connection = new ShardingSphereConnection("foo_db", mockContextManager());
String processId = new UUID(ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong()).toString().replace("-", "");
executor = new BatchPreparedStatementExecutor(connection.getContextManager().getMetaDataContexts().getMetaData(),
new JDBCExecutor(executorEngine, connection.getDatabaseConnectionManager().getConnectionContext()), "foo_db", processId);
executor = new BatchPreparedStatementExecutor(mockDatabase(), new JDBCExecutor(executorEngine, mock(ConnectionContext.class, RETURNS_DEEP_STUBS)), processId);
when(sqlStatementContext.getTablesContext()).thenReturn(mock(TablesContext.class));
}

private ContextManager mockContextManager() {
ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
MetaDataContexts metaDataContexts = mockMetaDataContexts();
when(result.getMetaDataContexts()).thenReturn(metaDataContexts);
private ShardingSphereDatabase mockDatabase() {
ShardingSphereDatabase result = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(result.getName()).thenReturn("foo_db");
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(mockShardingRule()));
when(result.getRuleMetaData()).thenReturn(ruleMetaData);
return result;
}

private MetaDataContexts mockMetaDataContexts() {
MetaDataContexts result = mock(MetaDataContexts.class, RETURNS_DEEP_STUBS);
RuleMetaData globalRuleMetaData = new RuleMetaData(Arrays.asList(mockTransactionRule(), new TrafficRule(new DefaultTrafficRuleConfigurationBuilder().build()), mock(AuthorityRule.class)));
when(result.getMetaData().getGlobalRuleMetaData()).thenReturn(globalRuleMetaData);
RuleMetaData databaseRuleMetaData = new RuleMetaData(Collections.singleton(mockShardingRule()));
when(result.getMetaData().getDatabase("foo_db").getRuleMetaData()).thenReturn(databaseRuleMetaData);
return result;
}

private TransactionRule mockTransactionRule() {
return new TransactionRule(new TransactionRuleConfiguration(TransactionType.LOCAL.name(), "", new Properties()), Collections.emptyMap());
}

private ShardingRule mockShardingRule() {
ShardingRule result = mock(ShardingRule.class, RETURNS_DEEP_STUBS);
DataNodeRuleAttribute ruleAttribute = mock(DataNodeRuleAttribute.class);
Expand Down

0 comments on commit 21745a2

Please sign in to comment.