Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ShardingSpherePreparedStatement #31406

Merged
merged 13 commits into from
May 26, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import org.apache.shardingsphere.infra.connection.kernel.KernelProcessor;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.mysql.type.MySQLDatabaseType;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.SQLExceptionTransformEngine;
import org.apache.shardingsphere.infra.exception.kernel.syntax.EmptySQLException;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditEngine;
Expand Down Expand Up @@ -227,14 +226,15 @@ public ResultSet executeQuery() throws SQLException {
}
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext);
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
currentResultSet = executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeQuery());
currentResultSet = executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName,
trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).executeQuery());
return currentResultSet;
}
if (decide(queryContext, metaDataContexts.getMetaData().getDatabase(databaseName), metaDataContexts.getMetaData().getGlobalRuleMetaData())) {
if (decide(queryContext, database, metaDataContexts.getMetaData().getGlobalRuleMetaData())) {
currentResultSet = executeFederationQuery(queryContext);
return currentResultSet;
}
Expand Down Expand Up @@ -266,19 +266,19 @@ private boolean decide(final QueryContext queryContext, final ShardingSphereData
return executor.getSqlFederationEngine().decide(queryContext.getSqlStatementContext(), queryContext.getParameters(), database, globalRuleMetaData);
}

private void handleAutoCommit(final QueryContext queryContext) throws SQLException {
if (AutoCommitUtils.needOpenTransaction(queryContext.getSqlStatementContext().getSqlStatement())) {
private void handleAutoCommit(final SQLStatement sqlStatement) throws SQLException {
if (AutoCommitUtils.needOpenTransaction(sqlStatement)) {
connection.handleAutoCommit();
}
}

private JDBCExecutionUnit createTrafficExecutionUnit(final String trafficInstanceId, final QueryContext queryContext) throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine(database);
ExecutionUnit executionUnit = new ExecutionUnit(trafficInstanceId, new SQLUnit(queryContext.getSql(), queryContext.getParameters()));
ExecutionGroupContext<JDBCExecutionUnit> context =
prepareEngine.prepare(new RouteContext(), Collections.singleton(executionUnit), new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", "")));
ShardingSpherePreconditions.checkState(!context.getInputGroups().isEmpty() && !context.getInputGroups().iterator().next().getInputs().isEmpty(), EmptyTrafficExecutionUnitException::new);
return context.getInputGroups().iterator().next().getInputs().iterator().next();
return context.getInputGroups().stream().flatMap(each -> each.getInputs().stream()).findFirst().orElseThrow(EmptyTrafficExecutionUnitException::new);
}

private Optional<String> getInstanceIdAndSet(final QueryContext queryContext) {
Expand Down Expand Up @@ -319,17 +319,17 @@ private List<QueryResult> executeQuery0(final ExecutionContext executionContext)
}

private ResultSet executeFederationQuery(final QueryContext queryContext) {
PreparedStatementExecuteQueryCallback callback = new PreparedStatementExecuteQueryCallback(metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData(), sqlStatement, SQLExecutorExceptionHandler.isExceptionThrown());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
PreparedStatementExecuteQueryCallback callback = new PreparedStatementExecuteQueryCallback(database.getProtocolType(),
database.getResourceMetaData(), sqlStatement, SQLExecutorExceptionHandler.isExceptionThrown());
SQLFederationContext context = new SQLFederationContext(false, queryContext, metaDataContexts.getMetaData(), connection.getProcessId());
return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(), callback, context);
return executor.getSqlFederationEngine().executeQuery(createDriverExecutionPrepareEngine(database), callback, context);
}

private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriverExecutionPrepareEngine() {
private DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> createDriverExecutionPrepareEngine(final ShardingSphereDatabase database) {
int maxConnectionsSizePerQuery = metaDataContexts.getMetaData().getProps().<Integer>getValue(ConfigurationPropertyKey.MAX_CONNECTIONS_SIZE_PER_QUERY);
return new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager,
statementOption, metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules(),
metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData().getStorageUnits());
return new DriverExecutionPrepareEngine<>(JDBCDriverType.PREPARED_STATEMENT, maxConnectionsSizePerQuery, connection.getDatabaseConnectionManager(), statementManager, statementOption,
database.getRuleMetaData().getRules(), database.getResourceMetaData().getStorageUnits());
}

@Override
Expand All @@ -341,11 +341,12 @@ public int executeUpdate() throws SQLException {
}
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext);
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
return executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
return executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName,
trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).executeUpdate());
}
executionContext = createExecutionContext(queryContext);
if (hasRawExecutionRule()) {
Expand Down Expand Up @@ -405,11 +406,12 @@ public boolean execute() throws SQLException {
}
clearPrevious();
QueryContext queryContext = createQueryContext();
handleAutoCommit(queryContext);
handleAutoCommit(queryContext.getSqlStatementContext().getSqlStatement());
String trafficInstanceId = getInstanceIdAndSet(queryContext).orElse(null);
if (null != trafficInstanceId) {
JDBCExecutionUnit executionUnit = createTrafficExecutionUnit(trafficInstanceId, queryContext);
boolean result = executor.getTrafficExecutor().execute(executionUnit, (statement, sql) -> ((PreparedStatement) statement).execute());
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
boolean result = executor.getTrafficExecutor().execute(connection.getProcessId(), databaseName,
trafficInstanceId, queryContext, createDriverExecutionPrepareEngine(database), (statement, sql) -> ((PreparedStatement) statement).execute());
currentResultSet = executor.getTrafficExecutor().getResultSet();
return result;
}
Expand Down Expand Up @@ -484,7 +486,8 @@ protected Optional<Boolean> getSaneResult(final SQLStatement sqlStatement, final
}

private ExecutionGroupContext<JDBCExecutionUnit> createExecutionGroupContext(final ExecutionContext executionContext) throws SQLException {
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine();
ShardingSphereDatabase database = metaDataContexts.getMetaData().getDatabase(databaseName);
DriverExecutionPrepareEngine<JDBCExecutionUnit, Connection> prepareEngine = createDriverExecutionPrepareEngine(database);
return prepareEngine.prepare(executionContext.getRouteContext(), executionContext.getExecutionUnits(),
new ExecutionGroupReportContext(connection.getProcessId(), databaseName, new Grantee("", "")));
}
Expand Down
Loading