Skip to content

Commit

Permalink
Refactor SQLAuditEngine (#31529)
Browse files Browse the repository at this point in the history
* Refactor ShardingSphereStatement

* Rename DriverExecutor.getResultSet()

* Refactor SQLAuditEngine

* Refactor SQLAuditEngine
  • Loading branch information
terrymanu authored Jun 2, 2024
1 parent ce5b355 commit f7d7ce2
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,33 @@

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.executor.audit.SQLAuditor;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sharding.api.config.strategy.audit.ShardingAuditStrategyConfiguration;
import org.apache.shardingsphere.sharding.constant.ShardingOrder;
import org.apache.shardingsphere.sharding.rule.ShardingRule;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
* Sharding SQL auditor.
*/
public final class ShardingSQLAuditor implements SQLAuditor<ShardingRule> {

@Override
public void audit(final SQLStatementContext sqlStatementContext, final List<Object> params, final Grantee grantee, final RuleMetaData globalRuleMetaData,
final ShardingSphereDatabase database, final ShardingRule rule, final HintValueContext hintValueContext) {
Collection<ShardingAuditStrategyConfiguration> auditStrategies = getShardingAuditStrategies(sqlStatementContext, rule);
public void audit(final QueryContext queryContext, final Grantee grantee, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final ShardingRule rule) {
Collection<ShardingAuditStrategyConfiguration> auditStrategies = getShardingAuditStrategies(queryContext.getSqlStatementContext(), rule);
if (auditStrategies.isEmpty()) {
return;
}
Collection<String> disableAuditNames = hintValueContext.getDisableAuditNames();
Collection<String> disableAuditNames = queryContext.getHintValueContext().getDisableAuditNames();
for (ShardingAuditStrategyConfiguration auditStrategy : auditStrategies) {
for (String auditorName : auditStrategy.getAuditorNames()) {
if (!auditStrategy.isAllowHintDisable() || !disableAuditNames.contains(auditorName.toLowerCase())) {
rule.getAuditors().get(auditorName).check(sqlStatementContext, params, grantee, globalRuleMetaData, database);
rule.getAuditors().get(auditorName).check(queryContext.getSqlStatementContext(), queryContext.getParameters(), grantee, globalRuleMetaData, database);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.auditor;

import org.apache.shardingsphere.infra.binder.context.statement.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sharding.exception.audit.DMLWithoutShardingKeyException;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
Expand Down Expand Up @@ -81,15 +82,15 @@ void setUp() {
@Test
void assertCheckSuccess() {
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), grantee, globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}

@Test
void assertCheckSuccessByDisableAuditNames() {
when(auditStrategy.isAllowHintDisable()).thenReturn(true);
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext);
new ShardingSQLAuditor().audit(new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), grantee, globalRuleMetaData, databases.get("foo_db"), rule);
verify(rule.getAuditors().get("auditor_1"), times(0)).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}

Expand All @@ -98,8 +99,8 @@ void assertCheckFailed() {
ShardingAuditAlgorithm auditAlgorithm = rule.getAuditors().get("auditor_1");
RuleMetaData globalRuleMetaData = mock(RuleMetaData.class);
doThrow(new DMLWithoutShardingKeyException()).when(auditAlgorithm).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
DMLWithoutShardingKeyException ex = assertThrows(DMLWithoutShardingKeyException.class,
() -> new ShardingSQLAuditor().audit(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"), rule, hintValueContext));
DMLWithoutShardingKeyException ex = assertThrows(DMLWithoutShardingKeyException.class, () -> new ShardingSQLAuditor().audit(
new QueryContext(sqlStatementContext, "", Collections.emptyList(), hintValueContext), grantee, globalRuleMetaData, databases.get("foo_db"), rule));
assertThat(ex.getMessage(), is("Not allow DML operation without sharding conditions."));
verify(rule.getAuditors().get("auditor_1")).check(sqlStatementContext, Collections.emptyList(), grantee, globalRuleMetaData, databases.get("foo_db"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,15 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;

/**
Expand All @@ -41,22 +39,19 @@ public final class SQLAuditEngine {
/**
* Audit SQL.
*
* @param sqlStatementContext SQL statement context
* @param params SQL parameters
* @param queryContext query context
* @param globalRuleMetaData global rule meta data
* @param database database
* @param grantee grantee
* @param hintValueContext hint value context
*/
@SuppressWarnings({"rawtypes", "unchecked"})
public static void audit(final SQLStatementContext sqlStatementContext, final List<Object> params,
final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final Grantee grantee, final HintValueContext hintValueContext) {
public static void audit(final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database, final Grantee grantee) {
Collection<ShardingSphereRule> rules = new LinkedList<>(globalRuleMetaData.getRules());
if (null != database) {
rules.addAll(database.getRuleMetaData().getRules());
}
for (Entry<ShardingSphereRule, SQLAuditor> entry : OrderedSPILoader.getServices(SQLAuditor.class, rules).entrySet()) {
entry.getValue().audit(sqlStatementContext, params, grantee, globalRuleMetaData, database, entry.getKey(), hintValueContext);
entry.getValue().audit(queryContext, grantee, globalRuleMetaData, database, entry.getKey());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@

package org.apache.shardingsphere.infra.executor.audit;

import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.hint.HintValueContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.annotation.SingletonSPI;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPI;

import java.util.List;

/**
* SQL auditor.
*
Expand All @@ -39,14 +36,11 @@ public interface SQLAuditor<T extends ShardingSphereRule> extends OrderedSPI<T>
/**
* Audit SQL.
*
* @param sqlStatementContext SQL statement context
* @param params SQL parameters
* @param queryContext query context
* @param grantee grantee
* @param globalRuleMetaData global rule meta data
* @param database current database
* @param rule rule
* @param hintValueContext hint value context
*/
void audit(SQLStatementContext sqlStatementContext, List<Object> params, Grantee grantee, RuleMetaData globalRuleMetaData,
ShardingSphereDatabase database, T rule, HintValueContext hintValueContext);
void audit(QueryContext queryContext, Grantee grantee, RuleMetaData globalRuleMetaData, ShardingSphereDatabase database, T rule);
}
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ private Collection<List<Object>> getParameterSets(final ExecutionGroup<JDBCExecu
private ExecutionContext createExecutionContext(final ShardingSphereDatabase database, final QueryContext queryContext) throws SQLException {
clearStatements();
RuleMetaData globalRuleMetaData = metaData.getGlobalRuleMetaData();
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, database, null, queryContext.getHintValueContext());
SQLAuditEngine.audit(queryContext, globalRuleMetaData, database, null);
return kernelProcessor.generateExecutionContext(queryContext, database, globalRuleMetaData, metaData.getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,8 @@ public ShardingSpherePreparedStatement(final ShardingSphereConnection connection
this(connection, sql, resultSetType, resultSetConcurrency, resultSetHoldability, false, null);
}

private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql,
final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys,
final String[] columns) throws SQLException {
private ShardingSpherePreparedStatement(final ShardingSphereConnection connection, final String sql, final int resultSetType,
final int resultSetConcurrency, final int resultSetHoldability, final boolean returnGeneratedKeys, final String[] columns) throws SQLException {
ShardingSpherePreconditions.checkNotEmpty(sql, () -> new EmptySQLException().toSQLException());
this.connection = connection;
metaData = connection.getContextManager().getMetaDataContexts().getMetaData();
Expand Down Expand Up @@ -351,7 +350,7 @@ private List<QueryResult> getQueryResults(final List<ResultSet> resultSets) thro
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaData.getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaData.getDatabase(databaseName);
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
SQLAuditEngine.audit(queryContext, globalRuleMetaData, currentDatabase, null);
ExecutionContext result = kernelProcessor.generateExecutionContext(
queryContext, currentDatabase, globalRuleMetaData, metaData.getProps(), connection.getDatabaseConnectionManager().getConnectionContext());
findGeneratedKey().ifPresent(optional -> generatedValues.addAll(optional.getGeneratedValues()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ public static ProxyBackendHandler newInstance(final DatabaseType databaseType, f
ShardingSphereDatabase database = ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getDatabase(databaseName);
ShardingSpherePreconditions.checkState(new AuthorityChecker(authorityRule, connectionSession.getGrantee()).isAuthorized(databaseName),
() -> new UnknownDatabaseException(databaseName));
SQLAuditEngine.audit(sqlStatementContext, queryContext.getParameters(), ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData(),
database, connectionSession.getGrantee(), queryContext.getHintValueContext());
SQLAuditEngine.audit(queryContext, ProxyContext.getInstance().getContextManager().getMetaDataContexts().getMetaData().getGlobalRuleMetaData(), database, connectionSession.getGrantee());
backendHandler = DatabaseAdminBackendHandlerFactory.newInstance(databaseType, sqlStatementContext, connectionSession);
return backendHandler.orElseGet(() -> DatabaseBackendHandlerFactory.newInstance(queryContext, connectionSession, preferPreparedStatement));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ private Map<String, List<ExecutionUnit>> buildDataSourcesToExecutionUnits(final
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
SQLAuditEngine.audit(queryContext, globalRuleMetaData, currentDatabase, null);
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private QueryContext createQueryContext(final SQLStatementContext sqlStatementCo
private ExecutionContext createExecutionContext(final QueryContext queryContext) {
RuleMetaData globalRuleMetaData = metaDataContexts.getMetaData().getGlobalRuleMetaData();
ShardingSphereDatabase currentDatabase = metaDataContexts.getMetaData().getDatabase(connectionSession.getDatabaseName());
SQLAuditEngine.audit(queryContext.getSqlStatementContext(), queryContext.getParameters(), globalRuleMetaData, currentDatabase, null, queryContext.getHintValueContext());
SQLAuditEngine.audit(queryContext, globalRuleMetaData, currentDatabase, null);
return kernelProcessor.generateExecutionContext(queryContext, currentDatabase, globalRuleMetaData, metaDataContexts.getMetaData().getProps(), connectionSession.getConnectionContext());
}

Expand Down

0 comments on commit f7d7ce2

Please sign in to comment.