Skip to content

Commit

Permalink
Refactor ColumnSegmentBoundInfo and correct pass to SQLBindEngine dat…
Browse files Browse the repository at this point in the history
…abaseName (#34023)

* Refactor ColumnSegmentBoundInfo and correct pass to SQLBindEngine databaseName

* Refactor ColumnSegmentBoundInfo and correct pass to SQLBindEngine databaseName

* fix unit test

* fix unit test

* fix unit test

* fix unit test
  • Loading branch information
strongduanmu authored Dec 12, 2024
1 parent 5439c67 commit e42f4e5
Show file tree
Hide file tree
Showing 30 changed files with 112 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private boolean containsEncryptProjectionInCombineSegment(final EncryptRule rule

private ColumnSegmentBoundInfo getColumnSegmentBoundInfo(final Projection projection) {
return projection instanceof ColumnProjection
? new ColumnSegmentBoundInfo(null, null, ((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) projection).getOriginalColumn())
? new ColumnSegmentBoundInfo(null, ((ColumnProjection) projection).getOriginalTable(), ((ColumnProjection) projection).getOriginalColumn())
: new ColumnSegmentBoundInfo(new IdentifierValue(projection.getColumnLabel()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -66,9 +67,11 @@ void assertCheckWithDifferentEncryptorsInJoinCondition() {

private SQLStatementContext mockSelectStatementContextWithDifferentEncryptorsInJoinCondition() {
ColumnSegment leftColumn = new ColumnSegment(0, 0, new IdentifierValue("user_name"));
leftColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("t_user"), new IdentifierValue("user_name")));
leftColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
ColumnSegment rightColumn = new ColumnSegment(0, 0, new IdentifierValue("user_id"));
rightColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("t_user"), new IdentifierValue("user_id")));
rightColumn.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"), new IdentifierValue("user_id")));
SelectStatementContext result = mock(SelectStatementContext.class);
when(result.getJoinConditions()).thenReturn(Collections.singleton(new BinaryOperationExpression(0, 0, leftColumn, rightColumn, "=", "")));
return result;
Expand All @@ -82,8 +85,8 @@ void assertCheckWithNotMatchedLikeQueryEncryptor() {

private SQLStatementContext mockSelectStatementContextWithLike() {
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("user_name"));
columnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("t_user"), new IdentifierValue("user_name")));
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
Expand All @@ -100,8 +103,8 @@ void assertCheckSuccess() {

private SQLStatementContext mockSelectStatementContextWithEqual() {
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("user_name"));
columnSegment.setColumnBoundInfo(
new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("t_user"), new IdentifierValue("user_name")));
columnSegment.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("t_user"),
new IdentifierValue("user_name")));
SelectStatementContext result = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(result.getTablesContext().findTableNames(Collections.singleton(columnSegment), null)).thenReturn(Collections.singletonMap("user_name", "t_user"));
when(result.getColumnSegments()).thenReturn(Collections.singleton(columnSegment));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -91,16 +92,16 @@ private SelectStatementContext mockSelectStatementContext() {
when(combineSegment.getLeft().getStartIndex()).thenReturn(0);
when(combineSegment.getRight().getStartIndex()).thenReturn(1);
when(result.getSqlStatement().getCombine()).thenReturn(Optional.of(combineSegment));
ColumnProjection leftColumn1 = new ColumnProjection(new IdentifierValue("f"), new IdentifierValue("foo_col_1"), null, databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col_1")));
ColumnProjection leftColumn2 = new ColumnProjection(new IdentifierValue("f"), new IdentifierValue("foo_col_2"), null, databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col_2")));
ColumnProjection leftColumn1 = new ColumnProjection(new IdentifierValue("f"), new IdentifierValue("foo_col_1"), null, databaseType, null, null,
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue("")), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col_1")));
ColumnProjection leftColumn2 = new ColumnProjection(new IdentifierValue("f"), new IdentifierValue("foo_col_2"), null, databaseType, null, null,
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue("")), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col_2")));
SelectStatementContext leftSelectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(leftSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(leftColumn1, leftColumn2));
ColumnProjection rightColumn1 = new ColumnProjection(new IdentifierValue("b"), new IdentifierValue("bar_col_1"), null, databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("bar_tbl"), new IdentifierValue("bar_col_1")));
ColumnProjection rightColumn2 = new ColumnProjection(new IdentifierValue("b"), new IdentifierValue("bar_col_2"), null, databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue(""), new IdentifierValue("bar_tbl"), new IdentifierValue("bar_col_2")));
ColumnProjection rightColumn1 = new ColumnProjection(new IdentifierValue("b"), new IdentifierValue("bar_col_1"), null, databaseType, null, null,
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue("")), new IdentifierValue("bar_tbl"), new IdentifierValue("bar_col_1")));
ColumnProjection rightColumn2 = new ColumnProjection(new IdentifierValue("b"), new IdentifierValue("bar_col_2"), null, databaseType, null, null,
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue(""), new IdentifierValue("")), new IdentifierValue("bar_tbl"), new IdentifierValue("bar_col_2")));
SelectStatementContext rightSelectStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(rightSelectStatementContext.getProjectionsContext().getExpandProjections()).thenReturn(Arrays.asList(rightColumn1, rightColumn2));
Map<Integer, SelectStatementContext> subqueryContexts = new LinkedHashMap<>(2, 1F);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -107,8 +108,9 @@ void assertGetValueWithoutEncryptColumn() throws SQLException {

@Test
void assertGetValueWithEncryptColumn() throws SQLException {
ColumnProjection columnProjection = new ColumnProjection(new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col"), new IdentifierValue("foo_alias"), databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col")));
ColumnProjection columnProjection =
new ColumnProjection(new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col"), new IdentifierValue("foo_alias"), databaseType, null, null, new ColumnSegmentBoundInfo(
new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col")));
when(selectStatementContext.findColumnProjection(1)).thenReturn(Optional.of(columnProjection));
when(selectStatementContext.getTablesContext().getSchemaName()).thenReturn(Optional.of("foo_schema"));
EncryptAlgorithm encryptAlgorithm = mock(EncryptAlgorithm.class);
Expand All @@ -122,8 +124,9 @@ void assertGetValueWithEncryptColumn() throws SQLException {

@Test
void assertGetValueFailed() throws SQLException {
ColumnProjection columnProjection = new ColumnProjection(new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col"), new IdentifierValue("foo_alias"), databaseType,
null, null, new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema"), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col")));
ColumnProjection columnProjection =
new ColumnProjection(new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col"), new IdentifierValue("foo_alias"), databaseType, null, null, new ColumnSegmentBoundInfo(
new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")), new IdentifierValue("foo_tbl"), new IdentifierValue("foo_col")));
when(selectStatementContext.findColumnProjection(1)).thenReturn(Optional.of(columnProjection));
when(selectStatementContext.getTablesContext().getSchemaName()).thenReturn(Optional.of("foo_schema"));
EncryptAlgorithm encryptAlgorithm = mock(EncryptAlgorithm.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.ColumnProjection;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -59,13 +60,13 @@ void assertInsertSelectIsSame() {

private ColumnProjection getSelectProjection(final String pwd, final IdentifierValue databaseValue, final IdentifierValue schemaValue) {
return new ColumnProjection(new IdentifierValue("table2"), new IdentifierValue(pwd), null, null, null, null,
new ColumnSegmentBoundInfo(databaseValue, schemaValue, new IdentifierValue("table2"), new IdentifierValue(pwd)));
new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(databaseValue, schemaValue), new IdentifierValue("table2"), new IdentifierValue(pwd)));
}

private ColumnSegment getInsertColumnSegment(final IdentifierValue databaseValue, final IdentifierValue schemaValue, final String tableName, final String columnName) {
ColumnSegment result = mock(ColumnSegment.class);
when(result.getColumnBoundInfo())
.thenReturn(new ColumnSegmentBoundInfo(databaseValue, schemaValue, new IdentifierValue(tableName), new IdentifierValue(columnName)));
.thenReturn(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(databaseValue, schemaValue), new IdentifierValue(tableName), new IdentifierValue(columnName)));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.predicate.WhereSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.InsertStatement;
Expand Down Expand Up @@ -128,10 +129,10 @@ private static InsertStatement createInsertSelectStatement(final boolean contain
InsertStatement result = new MySQLInsertStatement();
result.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_user"))));
ColumnSegment userIdColumn = new ColumnSegment(0, 0, new IdentifierValue("user_id"));
userIdColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db"), new IdentifierValue("t_user"),
userIdColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"),
new IdentifierValue("user_id")));
ColumnSegment userNameColumn = new ColumnSegment(0, 0, new IdentifierValue("user_name"));
userNameColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db"),
userNameColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")),
new IdentifierValue("t_user"), new IdentifierValue("user_name")));
List<ColumnSegment> insertColumns = Arrays.asList(userIdColumn, userNameColumn);
if (containsInsertColumns) {
Expand All @@ -145,7 +146,7 @@ private static InsertStatement createInsertSelectStatement(final boolean contain
ProjectionsSegment projections = new ProjectionsSegment(0, 0);
projections.getProjections().add(new ColumnProjectionSegment(userIdColumn));
ColumnSegment statusColumn = new ColumnSegment(0, 0, new IdentifierValue("status"));
statusColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db"), new IdentifierValue("t_user"),
statusColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("t_user"),
new IdentifierValue("status")));
projections.getProjections().add(new ColumnProjectionSegment(statusColumn));
selectStatement.setProjections(projections);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.AliasSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
Expand Down Expand Up @@ -80,7 +81,7 @@ void assertGenerateSQLTokensWhenOwnerMatchTableAlias() {
SimpleTableSegment doctorTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor")));
doctorTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
ColumnSegment column = new ColumnSegment(0, 0, new IdentifierValue("mobile"));
column.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db"), new IdentifierValue("doctor"),
column.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("doctor"),
new IdentifierValue("mobile")));
column.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
ProjectionsSegment projections = mock(ProjectionsSegment.class);
Expand All @@ -102,7 +103,7 @@ void assertGenerateSQLTokensWhenOwnerMatchTableAliasForSameTable() {
SimpleTableSegment doctorTable = new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("doctor")));
doctorTable.setAlias(new AliasSegment(0, 0, new IdentifierValue("a")));
ColumnSegment column = new ColumnSegment(0, 0, new IdentifierValue("mobile"));
column.setColumnBoundInfo(new ColumnSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db"), new IdentifierValue("doctor"),
column.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_db")), new IdentifierValue("doctor"),
new IdentifierValue("mobile")));
column.setOwner(new OwnerSegment(0, 0, new IdentifierValue("a")));
ProjectionsSegment projections = mock(ProjectionsSegment.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ private static ColumnSegmentBoundInfo createColumnSegmentBoundInfo(final ColumnS
: segmentOriginalTable;
IdentifierValue segmentOriginalColumn = segment.getColumnBoundInfo().getOriginalColumn();
IdentifierValue originalColumn = Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundInfo().getOriginalColumn()).orElse(segmentOriginalColumn);
return new ColumnSegmentBoundInfo(originalDatabase, originalSchema, originalTable, originalColumn);
return new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(originalDatabase, originalSchema), originalTable, originalColumn);
}

/**
Expand Down
Loading

0 comments on commit e42f4e5

Please sign in to comment.