Skip to content

Commit

Permalink
Remove TablesContext#findTableNames method and use sql bind info to r…
Browse files Browse the repository at this point in the history
…eplace
  • Loading branch information
strongduanmu committed Dec 23, 2024
1 parent 1f3af26 commit 711f357
Show file tree
Hide file tree
Showing 27 changed files with 447 additions and 296 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -64,7 +62,7 @@ private boolean containsOrderByItem(final SelectStatementContext sqlStatementCon
public void check(final EncryptRule rule, final ShardingSphereDatabase database, final ShardingSphereSchema currentSchema, final SelectStatementContext sqlStatementContext) {
for (OrderByItem each : getOrderByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
checkColumnOrderByItem(rule, currentSchema, sqlStatementContext, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
checkColumnOrderByItem(rule, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
}
}
}
Expand All @@ -80,10 +78,8 @@ private Collection<OrderByItem> getOrderByItems(final SelectStatementContext sql
return result;
}

private void checkColumnOrderByItem(final EncryptRule rule, final ShardingSphereSchema schema, final SelectStatementContext sqlStatementContext, final ColumnSegment columnSegment) {
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
private void checkColumnOrderByItem(final EncryptRule rule, final ColumnSegment columnSegment) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getIdentifier().getValue();
ShardingSpherePreconditions.checkState(!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName), () -> new UnsupportedEncryptSQLException("ORDER BY"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
Expand All @@ -40,7 +39,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -60,13 +58,12 @@ public void check(final EncryptRule rule, final ShardingSphereDatabase database,
Collection<BinaryOperationExpression> joinConditions = SQLStatementContextExtractor.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
check(rule, currentSchema, (WhereAvailable) sqlStatementContext);
check(rule, (WhereAvailable) sqlStatementContext);
}

private void check(final EncryptRule rule, final ShardingSphereSchema schema, final WhereAvailable sqlStatementContext) {
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(sqlStatementContext.getColumnSegments(), schema);
private void check(final EncryptRule rule, final WhereAvailable sqlStatementContext) {
for (ColumnSegment each : sqlStatementContext.getColumnSegments()) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnExpressionTableNames.getOrDefault(each.getExpression(), ""));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = each.getIdentifier().getValue();
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnName) && includesLike(sqlStatementContext.getWhereSegments(), each)) {
String tableName = encryptTable.get().getTable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -43,13 +39,11 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -89,44 +83,36 @@ public final class EncryptConditionEngine {
* Create encrypt conditions.
*
* @param whereSegments where segments
* @param columnSegments column segments
* @param sqlStatementContext SQL statement context
* @param databaseName database name
* @return encrypt conditions
*/
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments, final Collection<ColumnSegment> columnSegments,
final SQLStatementContext sqlStatementContext, final String databaseName) {
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments) {
Collection<EncryptCondition> result = new LinkedList<>();
String defaultSchema = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
Map<String, String> expressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
for (WhereSegment each : whereSegments) {
Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr());
for (AndPredicate predicate : andPredicates) {
addEncryptConditions(result, predicate.getPredicates(), expressionTableNames);
addEncryptConditions(result, predicate.getPredicates());
}
}
return result;
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates) {
Collection<Integer> stopIndexes = new HashSet<>(predicates.size(), 1F);
for (ExpressionSegment each : predicates) {
if (stopIndexes.add(each.getStopIndex())) {
addEncryptConditions(encryptConditions, each, expressionTableNames);
addEncryptConditions(encryptConditions, each);
}
}
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression) {
if (!findNotContainsNullLiteralsExpression(expression).isPresent()) {
return;
}
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
ColumnSegmentBoundInfo columnBoundInfo = each.getColumnBoundInfo();
String tableName = columnBoundInfo.getOriginalTable().getValue().isEmpty() ? expressionTableNames.getOrDefault(each.getExpression(), "") : columnBoundInfo.getOriginalTable().getValue();
String tableName = each.getColumnBoundInfo().getOriginalTable().getValue();
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue())) {
createEncryptCondition(expression, tableName).ifPresent(encryptConditions::add);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.SQLTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
Expand Down Expand Up @@ -93,9 +92,7 @@ private Collection<EncryptCondition> createEncryptConditions(final EncryptRule r
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -42,7 +40,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -51,14 +48,10 @@
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext>, SchemaMetaDataAware {
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext && containsGroupByItem((SelectStatementContext) sqlStatementContext);
Expand All @@ -79,21 +72,18 @@ private boolean containsGroupByItem(final SelectStatementContext sqlStatementCon
@Override
public Collection<SQLToken> generateSQLTokens(final SelectStatementContext sqlStatementContext) {
Collection<SQLToken> result = new LinkedList<>();
ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
for (OrderByItem each : getGroupByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
ColumnSegment columnSegment = ((ColumnOrderByItemSegment) each.getSegment()).getColumn();
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
generateSQLToken(columnSegment, columnTableNames, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
generateSQLToken(columnSegment, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
}
}
return result;
}

private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final Map<String, String> columnTableNames, final DatabaseType databaseType) {
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
String columnName = columnSegment.getIdentifier().getValue();
private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final DatabaseType databaseType) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getColumnBoundInfo().getOriginalColumn().getValue();
if (!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName)) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ private SQLStatementContext mockSelectStatementContextWithLike() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "LIKE", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand All @@ -106,7 +105,6 @@ private SQLStatementContext mockSelectStatementContextWithEqual() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "=", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ void assertGenerateSQLTokenFromGenerateNewSQLToken() {
private Collection<EncryptCondition> getEncryptConditions(final UpdateStatementContext updateStatementContext) {
ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db", mock(), mock(), mock(), Collections.singleton(new ShardingSphereSchema("foo_db")));
return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), database)
.createEncryptConditions(updateStatementContext.getWhereSegments(), updateStatementContext.getColumnSegments(), updateStatementContext, "foo_db");
.createEncryptConditions(updateStatementContext.getWhereSegments());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.enums.OrderDirection;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand Down Expand Up @@ -56,7 +55,6 @@ class EncryptGroupByItemTokenGeneratorTest {
@BeforeEach
void setup() {
generator = new EncryptGroupByItemTokenGenerator(mockEncryptRule());
generator.setSchemas(Collections.singletonMap("test", mock(ShardingSphereSchema.class)));
}

private EncryptRule mockEncryptRule() {
Expand Down
Loading

0 comments on commit 711f357

Please sign in to comment.