Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ShardingSphereStatement #31436

Merged
merged 3 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ public ResultSet executeQuery() throws SQLException {
currentResultSet = advancedResultSet.get();
return currentResultSet;
}
executionContext = createExecutionContext(queryContext);
ExecutionContext executionContext = createExecutionContext(queryContext);
currentResultSet = doExecuteQuery(executionContext);
return currentResultSet;
// CHECKSTYLE:OFF
Expand All @@ -245,12 +245,12 @@ public ResultSet executeQuery() throws SQLException {

private ShardingSphereResultSet doExecuteQuery(final ExecutionContext executionContext) throws SQLException {
List<QueryResult> queryResults = executeQuery0(executionContext);
MergedResult mergedResult = mergeQuery(queryResults, executionContext.getSqlStatementContext());
MergedResult mergedResult = mergeQuery(queryResults, sqlStatementContext);
List<ResultSet> resultSets = getResultSets();
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
}
return new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, executionContext.getSqlStatementContext(), columnLabelAndIndexMap);
return new ShardingSphereResultSet(resultSets, mergedResult, this, selectContainsEnhancedTable, sqlStatementContext, columnLabelAndIndexMap);
}

private List<QueryResult> executeQuery0(final ExecutionContext executionContext) throws SQLException {
Expand Down Expand Up @@ -303,7 +303,7 @@ public int executeUpdate() throws SQLException {
if (updatedCount.isPresent()) {
return updatedCount.get();
}
executionContext = createExecutionContext(queryContext);
ExecutionContext executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
Expand Down Expand Up @@ -368,7 +368,7 @@ public boolean execute() throws SQLException {
if (advancedResult.isPresent()) {
return advancedResult.get();
}
executionContext = createExecutionContext(queryContext);
ExecutionContext executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Collection<ExecuteResult> results =
executor.getRawExecutor().execute(createRawExecutionGroupContext(executionContext), executionContext.getQueryContext(), new RawSQLExecutorCallback());
Expand All @@ -386,13 +386,13 @@ public boolean execute() throws SQLException {
}

private boolean executeWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContext.getSqlStatementContext().getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
return isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
? executeWithImplicitCommitTransaction(() -> useDriverToExecute(executionContext), connection, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType())
: useDriverToExecute(executionContext);
}

private int executeUpdateWithExecutionContext(final ExecutionContext executionContext) throws SQLException {
return isNeedImplicitCommitTransaction(connection, executionContext.getSqlStatementContext().getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
return isNeedImplicitCommitTransaction(connection, sqlStatementContext.getSqlStatement(), executionContext.getExecutionUnits().size() > 1)
? executeWithImplicitCommitTransaction(() -> useDriverToExecuteUpdate(executionContext), connection, metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType())
: useDriverToExecuteUpdate(executionContext);
}
Expand Down Expand Up @@ -437,12 +437,11 @@ public ResultSet getResultSet() throws SQLException {
if (advancedResultSet.isPresent()) {
return advancedResultSet.get();
}
if (executionContext.getSqlStatementContext() instanceof SelectStatementContext || executionContext.getSqlStatementContext().getSqlStatement() instanceof DALStatement) {
if (sqlStatementContext instanceof SelectStatementContext || sqlStatementContext.getSqlStatement() instanceof DALStatement) {
List<ResultSet> resultSets = getResultSets();
if (resultSets.isEmpty()) {
return currentResultSet;
}
SQLStatementContext sqlStatementContext = executionContext.getSqlStatementContext();
MergedResult mergedResult = mergeQuery(getQueryResults(resultSets), sqlStatementContext);
if (null == columnLabelAndIndexMap) {
columnLabelAndIndexMap = ShardingSphereResultSetUtils.createColumnLabelAndIndexMap(sqlStatementContext, selectContainsEnhancedTable, resultSets.get(0).getMetaData());
Expand Down Expand Up @@ -478,7 +477,7 @@ private ExecutionContext createExecutionContext(final QueryContext queryContext)
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey(result).ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
return result;
}

Expand Down Expand Up @@ -537,18 +536,16 @@ private void clearPrevious() {
generatedValues.clear();
}

private Optional<GeneratedKeyContext> findGeneratedKey(final ExecutionContext executionContext) {
return executionContext.getSqlStatementContext() instanceof InsertStatementContext
? ((InsertStatementContext) executionContext.getSqlStatementContext()).getGeneratedKeyContext()
: Optional.empty();
private Optional<GeneratedKeyContext> findGeneratedKey() {
return sqlStatementContext instanceof InsertStatementContext ? ((InsertStatementContext) sqlStatementContext).getGeneratedKeyContext() : Optional.empty();
}

@Override
public ResultSet getGeneratedKeys() throws SQLException {
if (null != currentBatchGeneratedKeysResultSet) {
return currentBatchGeneratedKeysResultSet;
}
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey(executionContext);
Optional<GeneratedKeyContext> generatedKey = findGeneratedKey();
if (generatedKey.isPresent() && statementOption.isReturnGeneratedKeys() && !generatedValues.isEmpty()) {
return new GeneratedKeysResultSet(getGeneratedKeysColumnName(generatedKey.get().getColumnName()), generatedValues.iterator(), this);
}
Expand Down Expand Up @@ -599,7 +596,7 @@ public int[] executeBatch() throws SQLException {

private int[] doExecuteBatch(final BatchPreparedStatementExecutor batchExecutor) throws SQLException {
initBatchPreparedStatementExecutor(batchExecutor);
int[] result = batchExecutor.executeBatch(executionContext.getSqlStatementContext());
int[] result = batchExecutor.executeBatch(sqlStatementContext);
if (statementOption.isReturnGeneratedKeys() && generatedValues.isEmpty()) {
List<Statement> batchPreparedStatementExecutorStatements = batchExecutor.getStatements();
for (Statement statement : batchPreparedStatementExecutorStatements) {
Expand Down Expand Up @@ -662,7 +659,7 @@ public int getResultSetHoldability() {
@Override
public boolean isAccumulate() {
for (DataNodeRuleAttribute each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getAttributes(DataNodeRuleAttribute.class)) {
if (each.isNeedAccumulate(executionContext.getSqlStatementContext().getTablesContext().getTableNames())) {
if (each.isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
return true;
}
}
Expand Down
Loading
Loading