Skip to content

Commit

Permalink
Add metaData as field for DriverExecutor (#31521)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Jun 1, 2024
1 parent f7e60f3 commit ae8eb76
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.attribute.raw.RawExecutionRuleAttribute;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.DMLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
Expand Down Expand Up @@ -95,6 +94,8 @@ public final class DriverExecutor implements AutoCloseable {

private final ShardingSphereConnection connection;

private final ShardingSphereMetaData metaData;

private final DriverJDBCExecutor regularExecutor;

private final RawExecutor rawExecutor;
Expand All @@ -115,21 +116,20 @@ public final class DriverExecutor implements AutoCloseable {

public DriverExecutor(final ShardingSphereConnection connection) {
this.connection = connection;
MetaDataContexts metaDataContexts = connection.getContextManager().getMetaDataContexts();
metaData = connection.getContextManager().getMetaDataContexts().getMetaData();
ExecutorEngine executorEngine = connection.getContextManager().getExecutorEngine();
JDBCExecutor jdbcExecutor = new JDBCExecutor(executorEngine, connection.getDatabaseConnectionManager().getConnectionContext());
regularExecutor = new DriverJDBCExecutor(connection.getDatabaseName(), connection.getContextManager(), jdbcExecutor);
rawExecutor = new RawExecutor(executorEngine, connection.getDatabaseConnectionManager().getConnectionContext());
String schemaName = new DatabaseTypeRegistry(metaDataContexts.getMetaData().getDatabase(connection.getDatabaseName()).getProtocolType()).getDefaultSchemaName(connection.getDatabaseName());
String schemaName = new DatabaseTypeRegistry(metaData.getDatabase(connection.getDatabaseName()).getProtocolType()).getDefaultSchemaName(connection.getDatabaseName());
trafficExecutor = new TrafficExecutor();
sqlFederationEngine = new SQLFederationEngine(connection.getDatabaseName(), schemaName, metaDataContexts.getMetaData(), metaDataContexts.getStatistics(), jdbcExecutor);
sqlFederationEngine = new SQLFederationEngine(connection.getDatabaseName(), schemaName, metaData, connection.getContextManager().getMetaDataContexts().getStatistics(), jdbcExecutor);
kernelProcessor = new KernelProcessor();
}

/**
* Execute query.
*
* @param metaData meta data
* @param database database
* @param queryContext query context
* @param prepareEngine prepare engine
Expand All @@ -140,7 +140,7 @@ public DriverExecutor(final ShardingSphereConnection connection) {
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public ResultSet executeQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
public ResultSet executeQuery(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final Statement statement,
final Map<String, Integer> columnLabelAndIndexMap, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
Expand All @@ -152,8 +152,8 @@ public ResultSet executeQuery(final ShardingSphereMetaData metaData, final Shard
return sqlFederationEngine.executeQuery(
prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId()));
}
List<QueryResult> queryResults = executePushDownQuery(metaData, database, queryContext, prepareEngine, statementReplayCallback);
MergedResult mergedResult = mergeQuery(metaData, database, queryResults, queryContext.getSqlStatementContext());
List<QueryResult> queryResults = executePushDownQuery(database, queryContext, prepareEngine, statementReplayCallback);
MergedResult mergedResult = mergeQuery(database, queryResults, queryContext.getSqlStatementContext());
boolean selectContainsEnhancedTable = queryContext.getSqlStatementContext() instanceof SelectStatementContext
&& ((SelectStatementContext) queryContext.getSqlStatementContext()).isContainsEnhancedTable();
List<ResultSet> resultSets = getResultSets();
Expand All @@ -176,12 +176,12 @@ private ExecuteQueryCallback getExecuteQueryCallback(final ShardingSphereDatabas
}

@SuppressWarnings({"rawtypes", "unchecked"})
private List<QueryResult> executePushDownQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
private List<QueryResult> executePushDownQuery(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine,
final StatementReplayCallback statementReplayCallback) throws SQLException {
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
ExecutionContext executionContext = createExecutionContext(database, queryContext);
if (hasRawExecutionRule(database)) {
return rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext),
return rawExecutor.execute(createRawExecutionGroupContext(database, executionContext),
queryContext, new RawSQLExecutorCallback()).stream().map(QueryResult.class::cast).collect(Collectors.toList());
}
ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(),
Expand Down Expand Up @@ -216,7 +216,7 @@ private Collection<List<Object>> getParameterSets(final ExecutionGroup<JDBCExecu
return result;
}

private ExecutionContext createExecutionContext(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext) throws SQLException {
private ExecutionContext createExecutionContext(final ShardingSphereDatabase database, final QueryContext queryContext) throws SQLException {
clearStatements();
RuleMetaData globalRuleMetaData = metaData.getGlobalRuleMetaData();
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, database, null, queryContext.getHintValueContext());
Expand All @@ -230,15 +230,13 @@ private void clearStatements() throws SQLException {
statements.clear();
}

private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext(final ShardingSphereMetaData metaData,
final ShardingSphereDatabase database, final ExecutionContext executionContext) throws SQLException {
private ExecutionGroupContext<RawSQLExecutionUnit> createRawExecutionGroupContext(final ShardingSphereDatabase database, final ExecutionContext executionContext) throws SQLException {
int maxConnectionsSizePerQuery = metaData.getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new RawExecutionPrepareEngine(maxConnectionsSizePerQuery, database.getRuleMetaData().getRules()).prepare(
executionContext.getRouteContext(), executionContext.getExecutionUnits(), new ExecutionGroupReportContext(connection.getProcessId(), database.getName(), new Grantee("", "")));
}

private MergedResult mergeQuery(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database,
final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
private MergedResult mergeQuery(final ShardingSphereDatabase database, final List<QueryResult> queryResults, final SQLStatementContext sqlStatementContext) throws SQLException {
MergeEngine mergeEngine = new MergeEngine(metaData.getGlobalRuleMetaData(), database, metaData.getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
return mergeEngine.merge(queryResults, sqlStatementContext);
}
Expand All @@ -256,7 +254,6 @@ private List<ResultSet> getResultSets() throws SQLException {
/**
* Execute update.
*
* @param metaData meta data
* @param database database
* @param queryContext query context
* @param prepareEngine prepare engine
Expand All @@ -267,20 +264,20 @@ private List<ResultSet> getResultSets() throws SQLException {
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public int executeUpdate(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
public int executeUpdate(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Integer> trafficCallback,
final ExecuteUpdateCallback updateCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
if (trafficInstanceId.isPresent()) {
return trafficExecutor.execute(connection.getProcessId(), database.getName(), trafficInstanceId.get(), queryContext, prepareEngine, trafficCallback);
}
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
ExecutionContext executionContext = createExecutionContext(database, queryContext);
return database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()
? executeUpdate(database, updateCallback, queryContext.getSqlStatementContext(), executionContext, prepareEngine,
isNeedImplicitCommitTransaction(connection,
queryContext.getSqlStatementContext().getSqlStatement(), executionContext.getExecutionUnits().size() > 1),
statementReplayCallback)
: accumulate(rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback()));
: accumulate(rawExecutor.execute(createRawExecutionGroupContext(database, executionContext), queryContext, new RawSQLExecutorCallback()));
}

@SuppressWarnings("rawtypes")
Expand Down Expand Up @@ -360,7 +357,6 @@ private int accumulate(final Collection<ExecuteResult> results) {
/**
* Execute advance.
*
* @param metaData meta data
* @param database database
* @param queryContext query context
* @param prepareEngine prepare engine
Expand All @@ -371,7 +367,7 @@ private int accumulate(final Collection<ExecuteResult> results) {
* @throws SQLException SQL exception
*/
@SuppressWarnings("rawtypes")
public boolean executeAdvance(final ShardingSphereMetaData metaData, final ShardingSphereDatabase database, final QueryContext queryContext,
public boolean executeAdvance(final ShardingSphereDatabase database, final QueryContext queryContext,
final DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine, final TrafficExecutorCallback<Boolean> trafficCallback,
final ExecuteCallback executeCallback, final StatementReplayCallback statementReplayCallback) throws SQLException {
Optional<String> trafficInstanceId = connection.getTrafficInstanceId(metaData.getGlobalRuleMetaData().getSingleRule(TrafficRule.class), queryContext);
Expand All @@ -385,9 +381,9 @@ public boolean executeAdvance(final ShardingSphereMetaData metaData, final Shard
prepareEngine, getExecuteQueryCallback(database, queryContext, prepareEngine.getType()), new SQLFederationContext(false, queryContext, metaData, connection.getProcessId()));
return null != resultSet;
}
ExecutionContext executionContext = createExecutionContext(metaData, database, queryContext);
ExecutionContext executionContext = createExecutionContext(database, queryContext);
if (!database.getRuleMetaData().getAttributes(RawExecutionRuleAttribute.class).isEmpty()) {
Collection<ExecuteResult> results = rawExecutor.execute(createRawExecutionGroupContext(metaData, database, executionContext), queryContext, new RawSQLExecutorCallback());
Collection<ExecuteResult> results = rawExecutor.execute(createRawExecutionGroupContext(database, executionContext), queryContext, new RawSQLExecutorCallback());
return results.iterator().next() instanceof QueryResult;
}
boolean isNeedImplicitCommitTransaction = isNeedImplicitCommitTransaction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ public ResultSet executeQuery() throws SQLException {
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
currentResultSet = executor.executeQuery(metaData, database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap,
currentResultSet = executor.executeQuery(database, queryContext, createDriverExecutionPrepareEngine(database), this, columnLabelAndIndexMap,
(StatementReplayCallback<PreparedStatement>) this::replay);
if (currentResultSet instanceof ShardingSphereResultSet) {
columnLabelAndIndexMap = ((ShardingSphereResultSet) currentResultSet).getColumnLabelAndIndexMap();
Expand Down Expand Up @@ -260,7 +260,7 @@ public int executeUpdate() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
final int result = executor.executeUpdate(metaData, database, queryContext, createDriverExecutionPrepareEngine(database),
final int result = executor.executeUpdate(database, queryContext, createDriverExecutionPrepareEngine(database),
(statement, sql) -> ((PreparedStatement) statement).executeUpdate(), null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
Expand Down Expand Up @@ -289,8 +289,7 @@ public boolean execute() throws SQLException {
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
final boolean result = executor.executeAdvance(
metaData, database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
final boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute(),
null, (StatementReplayCallback<PreparedStatement>) this::replay);
for (Statement each : executor.getStatements()) {
statements.add((PreparedStatement) each);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public ResultSet executeQuery(final String sql) throws SQLException {
connection.getDatabaseConnectionManager().getConnectionContext().setCurrentDatabase(databaseName);
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
currentResultSet = executor.executeQuery(metaData, database, queryContext, createDriverExecutionPrepareEngine(database), this, null,
currentResultSet = executor.executeQuery(database, queryContext, createDriverExecutionPrepareEngine(database), this, null,
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
return currentResultSet;
Expand Down Expand Up @@ -218,8 +218,7 @@ private int executeUpdate(final String sql, final ExecuteUpdateCallback updateCa
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
clearStatements();
int result = executor.executeUpdate(
metaData, database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback, updateCallback,
int result = executor.executeUpdate(database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback, updateCallback,
(StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
replay(statements);
Expand Down Expand Up @@ -288,7 +287,7 @@ private boolean execute0(final String sql, final ExecuteCallback executeCallback
ShardingSphereDatabase database = metaData.getDatabase(databaseName);
sqlStatementContext = queryContext.getSqlStatementContext();
clearStatements();
boolean result = executor.executeAdvance(metaData, database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback,
boolean result = executor.executeAdvance(database, queryContext, createDriverExecutionPrepareEngine(database), trafficCallback,
executeCallback, (StatementReplayCallback<Statement>) (statements, parameterSets) -> replay(statements));
statements.addAll(executor.getStatements());
return result;
Expand Down

0 comments on commit ae8eb76

Please sign in to comment.