Refactor ShardingSphereStatement.metaData
terrymanu committed Jun 1, 2024
1 parent 382d2f0 commit 08b9f4e
Showing 3 changed files with 42 additions and 42 deletions.
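The same pattern repeats across all three files: the field that used to hold MetaDataContexts now holds ShardingSphereMetaData, every metaDataContexts.getMetaData().getDatabase(...) chain collapses to metaData.getDatabase(...), and callers unwrap the metadata once when constructing the object. A minimal, hypothetical sketch of that pattern with stand-in types (ExecutorSketch and lookupDatabase are illustrative names, not the real ShardingSphere API; the real ShardingSphereMetaData, ShardingSphereDatabase and MetaDataContexts classes are the ones imported in the diff below):

    // Stand-in types so the sketch compiles on its own.
    final class ShardingSphereDatabase {
    }

    final class ShardingSphereMetaData {

        ShardingSphereDatabase getDatabase(final String name) {
            return new ShardingSphereDatabase();
        }
    }

    final class ExecutorSketch {

        // Before this commit: private final MetaDataContexts metaDataContexts;
        private final ShardingSphereMetaData metaData;

        private final String databaseName;

        ExecutorSketch(final ShardingSphereMetaData metaData, final String databaseName) {
            this.metaData = metaData;                  // was: this.metaDataContexts = metaDataContexts;
            this.databaseName = databaseName;
        }

        ShardingSphereDatabase lookupDatabase() {
            // was: metaDataContexts.getMetaData().getDatabase(databaseName)
            return metaData.getDatabase(databaseName);
        }
    }

The constructor change is what forces the call-site updates in ShardingSpherePreparedStatement and in the test at the end of this diff.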
BatchPreparedStatementExecutor.java
@@ -29,9 +29,9 @@
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
+ import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
- import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;

import java.sql.SQLException;
@@ -50,7 +50,7 @@
*/
public final class BatchPreparedStatementExecutor {

- private final MetaDataContexts metaDataContexts;
+ private final ShardingSphereMetaData metaData;

private final JDBCExecutor jdbcExecutor;

@@ -63,9 +63,9 @@ public final class BatchPreparedStatementExecutor {

private final String databaseName;

- public BatchPreparedStatementExecutor(final MetaDataContexts metaDataContexts, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) {
+ public BatchPreparedStatementExecutor(final ShardingSphereMetaData metaData, final JDBCExecutor jdbcExecutor, final String databaseName, final String processId) {
this.databaseName = databaseName;
- this.metaDataContexts = metaDataContexts;
+ this.metaData = metaData;
this.jdbcExecutor = jdbcExecutor;
executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>(), new ExecutionGroupReportContext(processId, databaseName, new Grantee("", "")));
batchExecutionUnits = new LinkedList<>();
@@ -135,8 +135,8 @@ private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> n
*/
public int[] executeBatch(final SQLStatementContext sqlStatementContext) throws SQLException {
boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
- JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
- metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {
+ JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(metaData.getDatabase(databaseName).getProtocolType(),
+ metaData.getDatabase(databaseName).getResourceMetaData(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {

@Override
protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
@@ -157,7 +157,7 @@ protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final S
}

private boolean isNeedAccumulate(final SQLStatementContext sqlStatementContext) {
- for (DataNodeRuleAttribute each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
+ for (DataNodeRuleAttribute each : metaData.getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
return true;
}
ShardingSpherePreparedStatement.java
@@ -58,6 +58,7 @@
import org.apache.shardingsphere.infra.hint.SQLHintUtils;
import org.apache.shardingsphere.infra.merge.MergeEngine;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
+ import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
@@ -66,7 +67,6 @@
import org.apache.shardingsphere.infra.rule.attribute.datanode.DataNodeRuleAttribute;
import org.apache.shardingsphere.infra.rule.attribute.resoure.StorageConnectorReusableRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
- import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.parser.rule.SQLParserRule;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dal.DALStatement;
@@ -96,7 +96,7 @@ public final class ShardingSpherePreparedStatement extends AbstractPreparedState
@Getter
private final ShardingSphereConnection connection;

- private final MetaDataContexts metaDataContexts;
+ private final ShardingSphereMetaData metaData;

private final String sql;

@@ -170,25 +170,25 @@ private ShardingSpherePreparedStatement(final ShardingSphereConnection connectio
throw new EmptySQLException().toSQLException();
}
this.connection = connection;
- metaDataContexts = connection.getContextManager().getMetaDataContexts();
+ metaData = connection.getContextManager().getMetaDataContexts().getMetaData();
hintValueContext = SQLHintUtils.extractHint(sql);
this.sql = SQLHintUtils.removeHint(sql);
statements = new ArrayList<>();
parameterSets = new ArrayList<>();
- SQLParserRule sqlParserRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
- SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType());
+ SQLParserRule sqlParserRule = metaData.getGlobalRuleMetaData().getSingleRule(SQLParserRule.class);
+ SQLParserEngine sqlParserEngine = sqlParserRule.getSQLParserEngine(metaData.getDatabase(connection.getDatabaseName()).getProtocolType());
SQLStatement sqlStatement = sqlParserEngine.parse(this.sql, true);
- sqlStatementContext = new SQLBindEngine(metaDataContexts.getMetaData(), connection.getDatabaseName(), hintValueContext).bind(sqlStatement, Collections.emptyList());
+ sqlStatementContext = new SQLBindEngine(metaData, connection.getDatabaseName(), hintValueContext).bind(sqlStatement, Collections.emptyList());
databaseName = sqlStatementContext.getTablesContext().getDatabaseName().orElse(connection.getDatabaseName());
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
parameterMetaData = new ShardingSphereParameterMetaData(sqlStatement);
statementOption = returnGeneratedKeys ? new StatementOption(true, columns) : new StatementOption(resultSetType, resultSetConcurrency, resultSetHoldability);
executor = new DriverExecutor(connection);
JDBCExecutor jdbcExecutor = new JDBCExecutor(connection.getContextManager().getExecutorEngine(), connection.getDatabaseConnectionManager().getConnectionContext());
- batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaDataContexts, jdbcExecutor, databaseName, connection.getProcessId());
+ batchPreparedStatementExecutor = new BatchPreparedStatementExecutor(metaData, jdbcExecutor, databaseName, connection.getProcessId());
kernelProcessor = new KernelProcessor();
- statementsCacheable = isStatementsCacheable(metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData());
- trafficRule = metaDataContexts.getMetaData().getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
+ statementsCacheable = isStatementsCacheable(metaData.getDatabase(databaseName).getRuleMetaData());
+ trafficRule = metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class);
selectContainsEnhancedTable = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsEnhancedTable();
statementManager = new StatementManager();
}
@@ -207,9 +207,9 @@ public ResultSet executeQuery() throws SQLException {
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
- ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
+ ShardingSphereDatabase database = metaData.getDatabase(databaseName);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
- currentResultSet = executor.executeQuery(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap,
+ currentResultSet = executor.executeQuery(metaData, database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap,
(StatementReplayCallback<PreparedStatement>) this::replay);
if (currentResultSet instanceof ShardingSphereResultSet) {
columnLabelAndIndexMap = ((ShardingSphereResultSet) currentResultSet).getColumnLabelAndIndexMap();
@@ -222,8 +222,8 @@ public ResultSet executeQuery() throws SQLException {
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
- handleExceptionInTransaction(connection, metaDataContexts);
- throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
+ handleExceptionInTransaction(connection, metaData);
+ throw SQLExceptionTransformEngine.toSQLException(ex, metaData.getDatabase(databaseName).getProtocolType());
} finally {
batchPreparedStatementExecutor.clear();
clearParameters();
@@ -243,7 +243,7 @@ private void resetParameters() throws SQLException {
}

private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) {
- int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
+ int maxConnectionsSizePerQuery = metaData.getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager, statementOption,
database.getRuleMetaData().getRules(), database.getResourceMetaData().getStorageUnits());
}
@@ -258,8 +258,8 @@ public int executeUpdate() throws SQLException {
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
- ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
- final int result = executor.executeUpdate(metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database),
+ ShardingSphereDatabase database = metaData.getDatabase(databaseName);
+ final int result = executor.executeUpdate(metaData, database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
@@ -270,8 +270,8 @@ public int executeUpdate() throws SQLException {
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
handleExceptionInTransaction(connection, metaDataContexts);
throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
handleExceptionInTransaction(connection, metaData);
throw SQLExceptionTransformEngine.toSQLException(ex, metaData.getDatabase(databaseName).getProtocolType());
} finally {
clearBatch();
}
@@ -287,9 +287,9 @@ public boolean execute() throws SQLException {
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
- ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
+ ShardingSphereDatabase database = metaData.getDatabase(databaseName);
final boolean result = executor.executeAdvance(
- metaDataContexts.getMetaData(), database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
+ metaData, database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
@@ -300,8 +300,8 @@ public boolean execute() throws SQLException {
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
- handleExceptionInTransaction(connection, metaDataContexts);
- throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
+ handleExceptionInTransaction(connection, metaData);
+ throw SQLExceptionTransformEngine.toSQLException(ex, metaData.getDatabase(databaseName).getProtocolType());
} finally {
clearBatch();
}
@@ -351,11 +351,11 @@ private List<QueryResult> getQueryResults(final List<ResultSet> resultSets) thro
}

private ExecutionContext createExecutionContext(final QueryContext queryContext) {
- RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
- ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(databaseName);
+ RuleMetaData globalRuleMetaData = metaData.getGlobalRuleMetaData();
+ ShardingSphereDatabase currentDatabase = metaData.getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
- queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
+ queryContext, currentDatabase, globalRuleMetaData, metaData.getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
}
@@ -374,8 +374,8 @@ private QueryContext createQueryContext() {
}

private MergedResult mergeQuery(final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
- MergeEngine mergeEngine = new MergeEngine(metaDataContexts.getMetaData().getGlobalRuleMetaData(), metaDataContexts.getMetaData().getDatabase(databaseName),
- metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
+ MergeEngine mergeEngine = new MergeEngine(metaData.getGlobalRuleMetaData(), metaData.getDatabase(databaseName),
+ metaData.getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
return mergeEngine.merge(queryResults, sqlStatementContext);
}

@@ -424,7 +424,7 @@ public ResultSet getGeneratedKeys() throws SQLException {
}

private String getGeneratedKeysColumnName(final String columnName) {
- return metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType() instanceof MySQLDatabaseType ? "GENERATED_KEY" : columnName;
+ return metaData.getDatabase(databaseName).getProtocolType() instanceof MySQLDatabaseType ? "GENERATED_KEY" : columnName;
}

@Override
@@ -452,8 +452,8 @@ public int[] executeBatch() throws SQLException {
// CHECKSTYLE:OFF
} catch (final RuntimeException ex) {
// CHECKSTYLE:ON
- handleExceptionInTransaction(connection, metaDataContexts);
- throw SQLExceptionTransformEngine.toSQLException(ex, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType());
+ handleExceptionInTransaction(connection, metaData);
+ throw SQLExceptionTransformEngine.toSQLException(ex, metaData.getDatabase(databaseName).getProtocolType());
} finally {
clearBatch();
}
@@ -474,10 +474,10 @@ private int[] doExecuteBatch(final BatchPreparedStatementExecutor batchExecutor)
}

private void initBatchPreparedStatementExecutor(final BatchPreparedStatementExecutor batchExecutor) throws SQLException {
- DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, metaDataContexts.getMetaData().getProps()
+ DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, metaData.getProps()
.<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY), connection.getDatabaseConnectionManager(), statementManager, statementOption,
- metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules(),
- metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData().getStorageUnits());
+ metaData.getDatabase(databaseName).getRuleMetaData().getRules(),
+ metaData.getDatabase(databaseName).getResourceMetaData().getStorageUnits());
List<ExecutionUnit> executionUnits = new ArrayList<>(batchExecutor.getBatchExecutionUnits().size());
for (BatchExecutionUnit each : batchExecutor.getBatchExecutionUnits()) {
ExecutionUnit executionUnit = each.getExecutionUnit();
@@ -523,7 +523,7 @@ public int getResultSetHoldability() {

@Override
public boolean isAccumulate() {
- for (DataNodeRuleAttribute each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
+ for (DataNodeRuleAttribute each : metaData.getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
return true;
}
Test for BatchPreparedStatementExecutor
@@ -92,7 +92,7 @@ void setUp() {
SQLExecutorExceptionHandler.setExceptionThrown(true);
ShardingSphereConnection connection = new ShardingSphereConnection("foo_db", mockContextManager());
String processId = new UUID(ThreadLocalRandom.current().nextLong(), ThreadLocalRandom.current().nextLong()).toString().replace("-", "");
- executor = new BatchPreparedStatementExecutor(connection.getContextManager().getMetaDataContexts(),
+ executor = new BatchPreparedStatementExecutor(connection.getContextManager().getMetaDataContexts().getMetaData(),
new JDBCExecutor(executorEngine, connection.getDatabaseConnectionManager().getConnectionContext()), "foo_db", processId);
when(sqlStatementContext.getTablesContext()).thenReturn(mock(TablesContext.class));
}
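As the test change above shows, existing call sites adapt by unwrapping the metadata once: connection.getContextManager().getMetaDataContexts().getMetaData() is passed where the whole MetaDataContexts used to go. A hedged caller-side sketch (FooCaller and buildExecutor are illustrative names, imports omitted; it assumes the ShardingSphere driver classes that appear in the diffs above are on the classpath):

    // Illustrative only: mirrors the call-site adaptation seen in the test's setUp().
    final class FooCaller {

        BatchPreparedStatementExecutor buildExecutor(final ShardingSphereConnection connection,
                                                     final JDBCExecutor jdbcExecutor, final String processId) {
            // Before this commit, the whole MetaDataContexts was passed through:
            // new BatchPreparedStatementExecutor(connection.getContextManager().getMetaDataContexts(), ...)
            return new BatchPreparedStatementExecutor(
                    connection.getContextManager().getMetaDataContexts().getMetaData(), jdbcExecutor, "foo_db", processId);
        }
    }

One plausible reading of the design choice: the executor and the prepared statement now depend only on the metadata they actually read, rather than on the wider MetaDataContexts container.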