Skip to content

Commit

Permalink
Refactor get column to get column segment in encrypt condition
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingZC committed Dec 19, 2024
1 parent 055041f commit e3901e0
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.shardingsphere.encrypt.rewrite.condition;

import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;

import java.util.Map;

/**
Expand All @@ -25,11 +27,11 @@
public interface EncryptCondition {

/**
* Get column name.
* Get column segment.
*
* @return column name
* @return column segment
*/
String getColumnName();
ColumnSegment getColumnSegment();

/**
* Get table name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
Expand All @@ -41,8 +43,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.extractor.ExpressionExtractor;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;

import java.util.Collection;
import java.util.HashSet;
Expand Down Expand Up @@ -122,7 +123,8 @@ private void addEncryptConditions(final Collection<EncryptCondition> encryptCond
return;
}
for (ColumnSegment each : ColumnExtractor.extract(expression)) {
String tableName = expressionTableNames.getOrDefault(each.getExpression(), "");
ColumnSegmentBoundInfo columnBoundInfo = each.getColumnBoundInfo();
String tableName = columnBoundInfo.getOriginalTable().getValue().isEmpty() ? expressionTableNames.getOrDefault(each.getExpression(), "") : columnBoundInfo.getOriginalTable().getValue();
Optional<EncryptTable> encryptTable = rule.findEncryptTable(tableName);
if (encryptTable.isPresent() && encryptTable.get().isEncryptColumn(each.getIdentifier().getValue())) {
createEncryptCondition(expression, tableName).ifPresent(encryptConditions::add);
Expand Down Expand Up @@ -184,8 +186,8 @@ private Optional<EncryptCondition> createCompareEncryptCondition(final String ta
}

private EncryptBinaryCondition createEncryptBinaryOperationCondition(final String tableName, final BinaryOperationExpression expression, final ExpressionSegment compareRightValue) {
String columnName = ((ColumnSegment) expression.getLeft()).getIdentifier().getValue();
return new EncryptBinaryCondition(columnName, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue);
ColumnSegment columnSegment = (ColumnSegment) expression.getLeft();
return new EncryptBinaryCondition(columnSegment, tableName, expression.getOperator(), compareRightValue.getStartIndex(), expression.getStopIndex(), compareRightValue);
}

private static Optional<EncryptCondition> createInEncryptCondition(final String tableName, final InExpression inExpression, final ExpressionSegment inRightValue) {
Expand All @@ -201,7 +203,7 @@ private static Optional<EncryptCondition> createInEncryptCondition(final String
if (expressionSegments.isEmpty()) {
return Optional.empty();
}
String columnName = ((ColumnSegment) inExpression.getLeft()).getIdentifier().getValue();
return Optional.of(new EncryptInCondition(columnName, tableName, inRightValue.getStartIndex(), inRightValue.getStopIndex(), expressionSegments));
ColumnSegment columnSegment = (ColumnSegment) inExpression.getLeft();
return Optional.of(new EncryptInCondition(columnSegment, tableName, inRightValue.getStartIndex(), inRightValue.getStopIndex(), expressionSegments));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.Getter;
import lombok.ToString;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
Expand All @@ -37,7 +38,7 @@
@ToString
public final class EncryptBinaryCondition implements EncryptCondition {

private final String columnName;
private final ColumnSegment columnSegment;

private final String tableName;

Expand All @@ -53,8 +54,9 @@ public final class EncryptBinaryCondition implements EncryptCondition {

private final Map<Integer, Object> positionValueMap = new LinkedHashMap<>();

public EncryptBinaryCondition(final String columnName, final String tableName, final String operator, final int startIndex, final int stopIndex, final ExpressionSegment expressionSegment) {
this.columnName = columnName;
public EncryptBinaryCondition(final ColumnSegment columnSegment, final String tableName, final String operator, final int startIndex, final int stopIndex,
final ExpressionSegment expressionSegment) {
this.columnSegment = columnSegment;
this.tableName = tableName;
this.operator = operator;
this.startIndex = startIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.Getter;
import lombok.ToString;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
Expand All @@ -37,7 +38,7 @@
@ToString
public final class EncryptInCondition implements EncryptCondition {

private final String columnName;
private final ColumnSegment columnSegment;

private final String tableName;

Expand All @@ -49,8 +50,8 @@ public final class EncryptInCondition implements EncryptCondition {

private final Map<Integer, Object> positionValueMap = new LinkedHashMap<>();

public EncryptInCondition(final String columnName, final String tableName, final int startIndex, final int stopIndex, final List<ExpressionSegment> expressionSegments) {
this.columnName = columnName;
public EncryptInCondition(final ColumnSegment columnSegment, final String tableName, final int startIndex, final int stopIndex, final List<ExpressionSegment> expressionSegments) {
this.columnSegment = columnSegment;
this.tableName = tableName;
this.startIndex = startIndex;
this.stopIndex = stopIndex;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ private Collection<EncryptCondition> createEncryptConditions(final EncryptRule r
return createEncryptConditions(rule, sqlRewriteContext, sqlStatementContext);
}

private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext, final SQLStatementContext sqlStatementContext) {
private Collection<EncryptCondition> createEncryptConditions(final EncryptRule rule, final SQLRewriteContext sqlRewriteContext,
final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof WhereAvailable)) {
return Collections.emptyList();
}
Collection<SelectStatementContext> allSubqueryContexts = SQLStatementContextExtractor.getAllSubqueryContexts(sqlStatementContext);
Collection<WhereSegment> whereSegments = SQLStatementContextExtractor.getWhereSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
Collection<ColumnSegment> columnSegments = SQLStatementContextExtractor.getColumnSegments((WhereAvailable) sqlStatementContext, allSubqueryContexts);
return new EncryptConditionEngine(
rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext, sqlRewriteContext.getDatabase().getName());
return new EncryptConditionEngine(rule, sqlRewriteContext.getDatabase()).createEncryptConditions(whereSegments, columnSegments, sqlStatementContext,
sqlRewriteContext.getDatabase().getName());
}

private void rewriteParameters(final SQLRewriteContext sqlRewriteContext, final Collection<ParameterRewriter> parameterRewriters) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public void rewrite(final ParameterBuilder paramBuilder, final SQLStatementConte
}

private List<Object> getEncryptedValues(final String schemaName, final EncryptCondition encryptCondition, final List<Object> originalValues) {
String tableName = encryptCondition.getTableName();
String columnName = encryptCondition.getColumnName();
String tableName = encryptCondition.getColumnSegment().getColumnBoundInfo().getOriginalTable().getValue();
String columnName = encryptCondition.getColumnSegment().getIdentifier().getValue();
EncryptTable encryptTable = rule.getEncryptTable(tableName);
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equals(((EncryptBinaryCondition) encryptCondition).getOperator()) && encryptColumn.getLikeQuery().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.DatabaseAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptConditionsAware;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
Expand All @@ -28,6 +27,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.DatabaseAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.ParametersAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.exception.metadata.MissingMatchedEncryptQueryAlgorithmException;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.DatabaseAware;
import org.apache.shardingsphere.encrypt.rewrite.aware.EncryptConditionsAware;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptCondition;
import org.apache.shardingsphere.encrypt.rewrite.condition.EncryptConditionValues;
Expand All @@ -40,6 +39,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.DatabaseAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.aware.ParametersAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
Expand All @@ -49,6 +49,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Predicate right value token generator for encrypt.
Expand Down Expand Up @@ -77,7 +78,8 @@ public Collection<SQLToken> generateSQLTokens(final SQLStatementContext sqlState
String schemaName = ((TableAvailable) sqlStatementContext).getTablesContext().getSchemaName()
.orElseGet(() -> new DatabaseTypeRegistry(sqlStatementContext.getDatabaseType()).getDefaultSchemaName(database.getName()));
for (EncryptCondition each : encryptConditions) {
rule.findEncryptTable(each.getTableName()).ifPresent(optional -> result.add(generateSQLToken(schemaName, optional, each)));
Optional<EncryptTable> encryptTable = rule.findEncryptTable(each.getTableName());
encryptTable.ifPresent(optional -> result.add(generateSQLToken(schemaName, optional, each)));
}
return result;
}
Expand All @@ -98,15 +100,16 @@ private SQLToken generateSQLToken(final String schemaName, final EncryptTable en
}

private List<Object> getEncryptedValues(final String schemaName, final EncryptTable encryptTable, final EncryptCondition encryptCondition, final List<Object> originalValues) {
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(encryptCondition.getColumnName());
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(encryptCondition.getColumnSegment().getIdentifier().getValue());
if (encryptCondition instanceof EncryptBinaryCondition && "LIKE".equalsIgnoreCase(((EncryptBinaryCondition) encryptCondition).getOperator())) {
LikeQueryColumnItem likeQueryColumnItem = encryptColumn.getLikeQuery()
.orElseThrow(() -> new MissingMatchedEncryptQueryAlgorithmException(encryptTable.getTable(), encryptCondition.getColumnName(), "LIKE"));
return likeQueryColumnItem.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnName(), originalValues);
.orElseThrow(() -> new MissingMatchedEncryptQueryAlgorithmException(encryptTable.getTable(), encryptCondition.getColumnSegment().getIdentifier().getValue(), "LIKE"));
return likeQueryColumnItem.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues);
}
return encryptColumn.getAssistedQuery()
.map(optional -> optional.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnName(), originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnName(), originalValues));
.map(optional -> optional.encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnSegment().getIdentifier().getValue(), originalValues))
.orElseGet(() -> encryptColumn.getCipher().encrypt(database.getName(), schemaName, encryptCondition.getTableName(), encryptCondition.getColumnSegment().getIdentifier().getValue(),
originalValues));
}

private Map<Integer, Object> getPositionValues(final Collection<Integer> valuePositions, final List<Object> encryptValues) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.shardingsphere.encrypt.rewrite.condition.impl;

import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;

import java.util.Collections;
Expand All @@ -34,14 +36,14 @@ class EncryptBinaryConditionTest {

@Test
void assertNewInstanceWithParameterMarkerExpression() {
EncryptBinaryCondition actual = new EncryptBinaryCondition("col", null, null, 0, 0, new ParameterMarkerExpressionSegment(0, 0, 1));
EncryptBinaryCondition actual = new EncryptBinaryCondition(new ColumnSegment(0, 0, new IdentifierValue("col")), null, null, 0, 0, new ParameterMarkerExpressionSegment(0, 0, 1));
assertThat(actual.getPositionIndexMap(), is(Collections.singletonMap(0, 1)));
assertTrue(actual.getPositionValueMap().isEmpty());
}

@Test
void assertNewInstanceWithLiteralExpression() {
EncryptBinaryCondition actual = new EncryptBinaryCondition("col", null, null, 0, 0, new LiteralExpressionSegment(0, 0, "foo"));
EncryptBinaryCondition actual = new EncryptBinaryCondition(new ColumnSegment(0, 0, new IdentifierValue("col")), null, null, 0, 0, new LiteralExpressionSegment(0, 0, "foo"));
assertTrue(actual.getPositionIndexMap().isEmpty());
assertThat(actual.getPositionValueMap(), is(Collections.singletonMap(0, "foo")));
}
Expand All @@ -52,7 +54,7 @@ void assertNewInstanceWithConcatFunctionExpression() {
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "foo"));
functionSegment.getParameters().add(new ParameterMarkerExpressionSegment(0, 0, 1));
functionSegment.getParameters().add(mock(ExpressionSegment.class));
EncryptBinaryCondition actual = new EncryptBinaryCondition("col", null, null, 0, 0, functionSegment);
EncryptBinaryCondition actual = new EncryptBinaryCondition(new ColumnSegment(0, 0, new IdentifierValue("col")), null, null, 0, 0, functionSegment);
assertThat(actual.getPositionIndexMap(), is(Collections.singletonMap(1, 1)));
assertThat(actual.getPositionValueMap(), is(Collections.singletonMap(0, "foo")));
}
Expand All @@ -63,7 +65,7 @@ void assertNewInstanceWithNotConcatFunctionExpression() {
functionSegment.getParameters().add(new LiteralExpressionSegment(0, 0, "foo"));
functionSegment.getParameters().add(new ParameterMarkerExpressionSegment(0, 0, 1));
functionSegment.getParameters().add(mock(ExpressionSegment.class));
EncryptBinaryCondition actual = new EncryptBinaryCondition("col", null, null, 0, 0, functionSegment);
EncryptBinaryCondition actual = new EncryptBinaryCondition(new ColumnSegment(0, 0, new IdentifierValue("col")), null, null, 0, 0, functionSegment);
assertTrue(actual.getPositionIndexMap().isEmpty());
assertTrue(actual.getPositionValueMap().isEmpty());
}
Expand Down
Loading

0 comments on commit e3901e0

Please sign in to comment.