Skip to content

Commit

Permalink
Remove TablesContext#findTableNames method and implement select order…
Browse files Browse the repository at this point in the history
… by, group by bind logic (#34123)

* Remove TablesContext#findTableNames method and use sql bind info to replace

* fix unit test

* optimize select statement binder

* update release note
  • Loading branch information
strongduanmu authored Dec 24, 2024
1 parent d187256 commit cf0b19a
Show file tree
Hide file tree
Showing 29 changed files with 498 additions and 298 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
1. SQL Binder: Add sql bind logic for create table statement - [#34074](https://github.com/apache/shardingsphere/pull/34074)
1. SQL Binder: Support create index statement sql bind - [#34112](https://github.com/apache/shardingsphere/pull/34112)
1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126)
1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -64,7 +62,7 @@ private boolean containsOrderByItem(final SelectStatementContext sqlStatementCon
public void check(final EncryptRule rule, final ShardingSphereDatabase database, final ShardingSphereSchema currentSchema, final SelectStatementContext sqlStatementContext) {
for (OrderByItem each : getOrderByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
checkColumnOrderByItem(rule, currentSchema, sqlStatementContext, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
checkColumnOrderByItem(rule, ((ColumnOrderByItemSegment) each.getSegment()).getColumn());
}
}
}
Expand All @@ -80,10 +78,8 @@ private Collection<OrderByItem> getOrderByItems(final SelectStatementContext sql
return result;
}

private void checkColumnOrderByItem(final EncryptRule rule, final ShardingSphereSchema schema, final SelectStatementContext sqlStatementContext, final ColumnSegment columnSegment) {
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
private void checkColumnOrderByItem(final EncryptRule rule, final ColumnSegment columnSegment) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getIdentifier().getValue();
ShardingSpherePreconditions.checkState(!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName), () -> new UnsupportedEncryptSQLException("ORDER BY"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.apache.shardingsphere.infra.binder.context.extractor.SQLStatementContextExtractor;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.binder.context.type.WhereAvailable;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
Expand All @@ -40,7 +39,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -60,13 +58,12 @@ public void check(final EncryptRule rule, final ShardingSphereDatabase database,
Collection<BinaryOperationExpression> joinConditions = SQLStatementContextExtractor.getJoinConditions((WhereAvailable) sqlStatementContext, allSubqueryContexts);
ShardingSpherePreconditions.checkState(JoinConditionsEncryptorComparator.isSame(joinConditions, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in join condition"));
check(rule, currentSchema, (WhereAvailable) sqlStatementContext);
check(rule, (WhereAvailable) sqlStatementContext);
}

private void check(final EncryptRule rule, final ShardingSphereSchema schema, final WhereAvailable sqlStatementContext) {
Map<String, String> columnExpressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(sqlStatementContext.getColumnSegments(), schema);
private void check(final EncryptRule rule, final WhereAvailable sqlStatementContext) {
for (ColumnSegment each : sqlStatementContext.getColumnSegments()) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnExpressionTableNames.getOrDefault(each.getExpression(), ""));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = each.getIdentifier().getValue();
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(columnName) && includesLike(sqlStatementContext.getWhereSegments(), each)) {
String tableName = encryptTable.get().getTable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.encrypt.rule.table.EncryptTable;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.TableAvailable;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -43,13 +39,11 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
Expand Down Expand Up @@ -89,44 +83,36 @@ public final class EncryptConditionEngine {
* Create encrypt conditions.
*
* @param whereSegments where segments
* @param columnSegments column segments
* @param sqlStatementContext SQL statement context
* @param databaseName database name
* @return encrypt conditions
*/
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments, final Collection<ColumnSegment> columnSegments,
final SQLStatementContext sqlStatementContext, final String databaseName) {
public Collection<EncryptCondition> createEncryptConditions(final Collection<WhereSegment> whereSegments) {
Collection<EncryptCondition> result = new LinkedList<>();
String defaultSchema = new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(databaseName);
ShardingSphereSchema schema = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName().map(database::getSchema).orElseGet(() -> database.getSchema(defaultSchema));
Map<String, String> expressionTableNames = ((TableAvailable) sqlStatementContext).getTablesContext().findTableNames(columnSegments, schema);
for (WhereSegment each : whereSegments) {
Collection<AndPredicate> andPredicates = ExpressionExtractor.extractAndPredicates(each.getExpr());
for (AndPredicate predicate : andPredicates) {
addEncryptConditions(result, predicate.getPredicates(), expressionTableNames);
addEncryptConditions(result, predicate.getPredicates());
}
}
return result;
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final Collection<ExpressionSegment> predicates) {
Collection<Integer> stopIndexes = new HashSet<>(predicates.size(), 1F);
for (ExpressionSegment each : predicates) {
if (stopIndexes.add(each.getStopIndex())) {
addEncryptConditions(encryptConditions, each, expressionTableNames);
addEncryptConditions(encryptConditions, each);
}
}
}

private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression, final Map<String, String> expressionTableNames) {
private void addEncryptConditions(final Collection<EncryptCondition> encryptConditions, final ExpressionSegment expression) {
if (!findNotContainsNullLiteralsExpression(expression).isPresent()) {
return;
}
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
ColumnSegmentBoundInfo columnBoundInfo = each.getColumnBoundInfo();
String tableName = columnBoundInfo.getOriginalTable().getValue().isEmpty() ? expressionTableNames.getOrDefault(each.getExpression(), "") : columnBoundInfo.getOriginalTable().getValue();
String tableName = each.getColumnBoundInfo().getOriginalTable().getValue();
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getColumnBoundInfo().getOriginalColumn().getValue())) {
createEncryptCondition(expression, tableName).ifPresent(encryptConditions::add);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.shardingsphere.infra.rewrite.parameter.rewriter.ParameterRewritersBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.builder.SQLTokenGeneratorBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;

import java.util.Collection;
Expand Down Expand Up @@ -93,9 +92,7 @@ private Collection<EncryptCondition> createEncryptConditions(final EncryptRule r
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments);
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
Expand All @@ -42,7 +40,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -51,14 +48,10 @@
@HighFrequencyInvocation
@RequiredArgsConstructor
@Setter
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext>, SchemaMetaDataAware {
public final class EncryptGroupByItemTokenGenerator implements CollectionSQLTokenGenerator<SelectStatementContext> {

private final EncryptRule rule;

private Map<String, ShardingSphereSchema> schemas;

private ShardingSphereSchema defaultSchema;

@Override
public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) {
return sqlStatementContext instanceof SelectStatementContext && containsGroupByItem((SelectStatementContext) sqlStatementContext);
Expand All @@ -79,21 +72,18 @@ private boolean containsGroupByItem(final SelectStatementContext sqlStatementCon
@Override
public Collection<SQLToken> generateSQLTokens(final SelectStatementContext sqlStatementContext) {
Collection<SQLToken> result = new LinkedList<>();
ShardingSphereSchema schema = sqlStatementContext.getTablesContext().getSchemaName().map(schemas::get).orElseGet(() -> defaultSchema);
for (OrderByItem each : getGroupByItems(sqlStatementContext)) {
if (each.getSegment() instanceof ColumnOrderByItemSegment) {
ColumnSegment columnSegment = ((ColumnOrderByItemSegment) each.getSegment()).getColumn();
Map<String, String> columnTableNames = sqlStatementContext.getTablesContext().findTableNames(Collections.singleton(columnSegment), schema);
generateSQLToken(columnSegment, columnTableNames, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
generateSQLToken(columnSegment, sqlStatementContext.getDatabaseType()).ifPresent(result::add);
}
}
return result;
}

private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final Map<String, String> columnTableNames, final DatabaseType databaseType) {
String tableName = columnTableNames.getOrDefault(columnSegment.getExpression(), "");
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
String columnName = columnSegment.getIdentifier().getValue();
private Optional<SubstitutableColumnNameToken> generateSQLToken(final ColumnSegment columnSegment, final DatabaseType databaseType) {
Optional<EncryptTable> encryptTable = rule.findEncryptTable(columnSegment.getColumnBoundInfo().getOriginalTable().getValue());
String columnName = columnSegment.getColumnBoundInfo().getOriginalColumn().getValue();
if (!encryptTable.isPresent() || !encryptTable.get().isEncryptColumn(columnName)) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.order.item.ColumnOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.AliasSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
Expand Down Expand Up @@ -116,6 +118,8 @@ private SelectStatementContext mockSelectStatementContext(final String tableName
simpleTableSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("foo_col"));
columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
columnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue(tableName), new IdentifierValue("foo_col")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getDatabaseType()).thenReturn(databaseType);
ColumnOrderByItemSegment columnOrderByItemSegment = new ColumnOrderByItemSegment(columnSegment, OrderDirection.ASC, NullsOrderType.FIRST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ private SQLStatementContext mockSelectStatementContextWithLike() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "LIKE", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand All @@ -106,7 +105,6 @@ private SQLStatementContext mockSelectStatementContextWithEqual() {
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
when(result.getWhereSegments()).thenReturn(Collections.singleton(new WhereSegment(0, 0, new BinaryOperationExpression(0, 0, columnSegment, columnSegment, "=", ""))));
when(result.getSubqueryContexts()).thenReturn(Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ void assertGenerateSQLTokenFromGenerateNewSQLToken() {
private Collection<EncryptCondition> getEncryptConditions(final UpdateStatementContext updateStatementContext) {
ShardingSphereDatabase database = new ShardingSphereDatabase("foo_db", mock(), mock(), mock(), Collections.singleton(new ShardingSphereSchema("foo_db")));
return new EncryptConditionEngine(EncryptGeneratorFixtureBuilder.createEncryptRule(), database)
.createEncryptConditions(updateStatementContext.getWhereSegments(), updateStatementContext.getColumnSegments(), updateStatementContext, "foo_db");
.createEncryptConditions(updateStatementContext.getWhereSegments());
}
}
Loading

0 comments on commit cf0b19a

Please sign in to comment.