Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Dec 16, 2024
1 parent 4ed7986 commit dab374c
Show file tree
Hide file tree
Showing 21 changed files with 76 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static CombineSegment bind(final CombineSegment segment, final SQLStateme
private static SubquerySegment bindSubquerySegment(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveMap.CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SubquerySegment result = new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), segment.getText());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
result.setSelect(new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), subqueryBinderContext));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private static Optional<ColumnSegment> findInputColumnSegment(final ColumnSegmen
}
}
if (!isFindInputColumn) {
result = findInputColumnSegmentByVariables(segment, binderContext.getVariableNames()).orElse(null);
result = findInputColumnSegmentByVariables(segment, binderContext.getSqlStatement().getVariableNames()).orElse(null);
isFindInputColumn = null != result;
}
if (!isFindInputColumn) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public final class SubquerySegmentBinder {
*/
public static SubquerySegment bind(final SubquerySegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(segment.getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext selectBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSelect());
selectBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
SelectStatement boundSelectStatement = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSelect(), selectBinderContext);
return new SubquerySegment(segment.getStartIndex(), segment.getStopIndex(), boundSelectStatement, segment.getText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ public static JoinTableSegment bind(final JoinTableSegment segment, final SQLSta
result.setDerivedUsing(bindUsingColumns(derivedUsingColumns, tableBinderContexts));
result.getDerivedUsing().forEach(each -> binderContext.getUsingColumnNames().add(each.getIdentifier().getValue()));
}
result.getDerivedJoinTableProjectionSegments().addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
result.getDerivedJoinTableProjectionSegments()
.addAll(getDerivedJoinTableProjectionSegments(result, binderContext.getSqlStatement().getDatabaseType(), usingColumnsByNaturalJoin, tableBinderContexts));
binderContext.getJoinTableProjectionSegments().addAll(result.getDerivedJoinTableProjectionSegments());
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.ddl.CreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
Expand Down Expand Up @@ -101,7 +102,7 @@ private static void fillPivotColumnNamesInBinderContext(final SimpleTableSegment
}

private static IdentifierValue getDatabaseName(final SimpleTableSegment segment, final SQLStatementBinderContext binderContext) {
DialectDatabaseMetaData dialectDatabaseMetaData = new DatabaseTypeRegistry(binderContext.getDatabaseType()).getDialectDatabaseMetaData();
DialectDatabaseMetaData dialectDatabaseMetaData = new DatabaseTypeRegistry(binderContext.getSqlStatement().getDatabaseType()).getDialectDatabaseMetaData();
Optional<OwnerSegment> owner = dialectDatabaseMetaData.getDefaultSchema().isPresent() ? segment.getOwner().flatMap(OwnerSegment::getOwner) : segment.getOwner();
return new IdentifierValue(owner.map(optional -> optional.getIdentifier().getValue()).orElse(binderContext.getCurrentDatabaseName()));
}
Expand All @@ -111,15 +112,18 @@ private static IdentifierValue getSchemaName(final SimpleTableSegment segment, f
return segment.getOwner().get().getIdentifier();
}
// TODO getSchemaName according to search path
DatabaseType databaseType = binderContext.getDatabaseType();
DatabaseType databaseType = binderContext.getSqlStatement().getDatabaseType();
if ((databaseType instanceof PostgreSQLDatabaseType || databaseType instanceof OpenGaussDatabaseType) && SYSTEM_CATALOG_TABLES.contains(segment.getTableName().getIdentifier().getValue())) {
return new IdentifierValue(PG_CATALOG);
}
return new IdentifierValue(new DatabaseTypeRegistry(databaseType).getDefaultSchemaName(binderContext.getCurrentDatabaseName()));
}

private static void checkTableExists(final SQLStatementBinderContext binderContext, final String databaseName, final String schemaName, final String tableName) {
if ("dual".equalsIgnoreCase(tableName)) {
if (binderContext.getSqlStatement() instanceof CreateTableStatement) {
return;
}
if ("DUAL".equalsIgnoreCase(tableName)) {
return;
}
if (SystemSchemaManager.isSystemTable(schemaName, tableName)) {
Expand Down Expand Up @@ -147,7 +151,7 @@ private static SimpleTableSegmentBinderContext createSimpleTableSegmentBinderCon
final IdentifierValue databaseName, final IdentifierValue schemaName,
final SQLStatementBinderContext binderContext, final IdentifierValue tableName) {
Collection<ProjectionSegment> projectionSegments = new LinkedList<>();
QuoteCharacter quoteCharacter = new DatabaseTypeRegistry(binderContext.getDatabaseType()).getDialectDatabaseMetaData().getQuoteCharacter();
QuoteCharacter quoteCharacter = new DatabaseTypeRegistry(binderContext.getSqlStatement().getDatabaseType()).getDialectDatabaseMetaData().getQuoteCharacter();
for (ShardingSphereColumn each : schema.getTable(tableName.getValue()).getAllColumns()) {
ColumnProjectionSegment columnProjectionSegment = new ColumnProjectionSegment(createColumnSegment(segment, databaseName, schemaName, each, quoteCharacter, tableName));
columnProjectionSegment.setVisible(each.isVisible());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ public static SubqueryTableSegment bind(final SubqueryTableSegment segment, fina
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
fillPivotColumnNamesInBinderContext(segment, binderContext);
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(segment.getSubquery().getSelect(), binderContext.getMetaData(), binderContext.getCurrentDatabaseName());
SQLStatementBinderContext subqueryBinderContext = new SQLStatementBinderContext(binderContext.getMetaData(), binderContext.getCurrentDatabaseName(), segment.getSubquery().getSelect());
subqueryBinderContext.getExternalTableBinderContexts().putAll(binderContext.getExternalTableBinderContexts());
SelectStatement boundSubSelect = new SelectStatementBinder(outerTableBinderContexts).bind(segment.getSubquery().getSelect(), subqueryBinderContext);
SubquerySegment boundSubquerySegment = new SubquerySegment(segment.getSubquery().getStartIndex(), segment.getSubquery().getStopIndex(), boundSubSelect, segment.getSubquery().getText());
IdentifierValue subqueryTableName = segment.getAliasSegment().map(AliasSegment::getIdentifier).orElseGet(() -> new IdentifierValue(""));
SubqueryTableSegment result = new SubqueryTableSegment(segment.getStartIndex(), segment.getStopIndex(), boundSubquerySegment);
segment.getAliasSegment().ifPresent(result::setAlias);
tableBinderContexts.put(new CaseInsensitiveString(subqueryTableName.getValue()), new SimpleTableSegmentBinderContext(
SubqueryTableBindUtils.createSubqueryProjections(boundSubSelect.getProjections().getProjections(), subqueryTableName, binderContext.getDatabaseType())));
SubqueryTableBindUtils.createSubqueryProjections(boundSubSelect.getProjections().getProjections(), subqueryTableName, binderContext.getSqlStatement().getDatabaseType())));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
Expand All @@ -43,9 +42,7 @@ public final class SQLStatementBinderContext {

private final String currentDatabaseName;

private final DatabaseType databaseType;

private final Collection<String> variableNames;
private final SQLStatement sqlStatement;

private final Collection<String> usingColumnNames = new CaseInsensitiveSet<>();

Expand All @@ -54,11 +51,4 @@ public final class SQLStatementBinderContext {
private final Multimap<CaseInsensitiveString, TableSegmentBinderContext> externalTableBinderContexts = LinkedHashMultimap.create();

private final Collection<String> pivotColumnNames = new CaseInsensitiveSet<>();

public SQLStatementBinderContext(final SQLStatement sqlStatement, final ShardingSphereMetaData metaData, final String currentDatabaseName) {
this.metaData = metaData;
this.currentDatabaseName = currentDatabaseName;
databaseType = sqlStatement.getDatabaseType();
variableNames = new CaseInsensitiveSet<>(sqlStatement.getVariableNames());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public final class DDLStatementBindEngine {
* @return bound DDL statement
*/
public DDLStatement bind(final DDLStatement statement) {
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(statement, metaData, currentDatabaseName);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(metaData, currentDatabaseName, statement);
if (statement instanceof CursorStatement) {
return new CursorStatementBinder().bind((CursorStatement) statement, binderContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public final class DMLStatementBindEngine {
* @return bound DML statement
*/
public DMLStatement bind(final DMLStatement statement) {
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(statement, metaData, currentDatabaseName);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(metaData, currentDatabaseName, statement);
if (statement instanceof SelectStatement) {
return new SelectStatementBinder().bind((SelectStatement) statement, binderContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.junit.jupiter.api.Test;

Expand All @@ -45,6 +46,7 @@
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class ColumnSegmentBinderTest {

Expand All @@ -62,8 +64,9 @@ void assertBindWithMultiTablesJoinAndNoOwner() {
new IdentifierValue("t_order_item"), new IdentifierValue("item_id")));
tableBinderContexts.put(new CaseInsensitiveString("t_order_item"), new SimpleTableSegmentBinderContext(Collections.singleton(new ColumnProjectionSegment(boundItemIdColumn))));
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("order_id"));
SQLStatementBinderContext binderContext =
new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", databaseType, Collections.emptySet());
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.getDatabaseType()).thenReturn(databaseType);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", selectStatement);
ColumnSegment actual = ColumnSegmentBinder.bind(columnSegment, SegmentType.JOIN_ON, binderContext, tableBinderContexts, LinkedHashMultimap.create());
assertNotNull(actual.getColumnBoundInfo());
assertNull(actual.getOtherUsingColumnBoundInfo());
Expand All @@ -84,8 +87,9 @@ void assertBindFromOuterTable() {
boundOrderItemStatusColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")),
new IdentifierValue("t_order_item"), new IdentifierValue("status")));
outerTableBinderContexts.put(new CaseInsensitiveString("t_order_item"), new SimpleTableSegmentBinderContext(Collections.singleton(new ColumnProjectionSegment(boundOrderItemStatusColumn))));
SQLStatementBinderContext binderContext =
new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", databaseType, Collections.emptySet());
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.getDatabaseType()).thenReturn(databaseType);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", selectStatement);
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("status"));
ColumnSegment actual = ColumnSegmentBinder.bind(columnSegment, SegmentType.PROJECTION, binderContext, LinkedHashMultimap.create(), outerTableBinderContexts);
assertNotNull(actual.getColumnBoundInfo());
Expand All @@ -107,8 +111,9 @@ void assertBindWithSameTableAliasAndSameProjection() {
boundOrderItemColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")),
new IdentifierValue("t_order_item"), new IdentifierValue("status")));
tableBinderContexts.put(new CaseInsensitiveString("temp"), new SimpleTableSegmentBinderContext(Collections.singleton(new ColumnProjectionSegment(boundOrderItemColumn))));
SQLStatementBinderContext binderContext =
new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", databaseType, Collections.emptySet());
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.getDatabaseType()).thenReturn(databaseType);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", selectStatement);
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("status"));
columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("temp")));
assertThrows(AmbiguousColumnException.class, () -> ColumnSegmentBinder.bind(columnSegment, SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create()));
Expand All @@ -125,8 +130,9 @@ void assertBindWithSameTableAliasAndDifferentProjection() {
boundOrderItemColumn.setColumnBoundInfo(new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(new IdentifierValue("foo_db"), new IdentifierValue("foo_schema")),
new IdentifierValue("t_order_item"), new IdentifierValue("status")));
tableBinderContexts.put(new CaseInsensitiveString("temp"), new SimpleTableSegmentBinderContext(Collections.singleton(new ColumnProjectionSegment(boundOrderItemColumn))));
SQLStatementBinderContext binderContext =
new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", databaseType, Collections.emptySet());
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.getDatabaseType()).thenReturn(databaseType);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", selectStatement);
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("status"));
columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("temp")));
ColumnSegment actual = ColumnSegmentBinder.bind(columnSegment, SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create());
Expand All @@ -151,8 +157,9 @@ void assertBindOwner() {
tableBinderContexts.put(new CaseInsensitiveString("t_order_item"), new SimpleTableSegmentBinderContext(Collections.singleton(new ColumnProjectionSegment(boundItemIdColumn))));
ColumnSegment columnSegment = new ColumnSegment(0, 0, new IdentifierValue("order_id"));
columnSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("t_order")));
SQLStatementBinderContext binderContext =
new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", databaseType, Collections.emptySet());
SelectStatement selectStatement = mock(SelectStatement.class);
when(selectStatement.getDatabaseType()).thenReturn(databaseType);
SQLStatementBinderContext binderContext = new SQLStatementBinderContext(mock(ShardingSphereMetaData.class), "foo_db", selectStatement);
ColumnSegment actual = ColumnSegmentBinder.bind(columnSegment, SegmentType.JOIN_ON, binderContext, tableBinderContexts, LinkedHashMultimap.create());
assertTrue(actual.getOwner().isPresent());
assertTrue(actual.getOwner().get().getTableBoundInfo().isPresent());
Expand Down
Loading

0 comments on commit dab374c

Please sign in to comment.