diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 37fad67b25538..5e23492b00dc3 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -39,6 +39,7 @@ 1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126) 1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123) 1. SQL Binder: Support select with statement sql bind and add bind test case - [#34141](https://github.com/apache/shardingsphere/pull/34141) +1. SQL Binder: Support sql bind for select with current select projection reference - [#34151](https://github.com/apache/shardingsphere/pull/34151) ### Bug Fixes diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ColumnSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ColumnSegmentBinder.java index 6020cc3f1bbdf..d4f9cdeecbcb3 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ColumnSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/expression/type/ColumnSegmentBinder.java @@ -18,6 +18,7 @@ package org.apache.shardingsphere.infra.binder.engine.segment.expression.type; import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString; +import com.cedarsoftware.util.CaseInsensitiveSet; import com.google.common.base.Strings; import com.google.common.collect.Multimap; import lombok.AccessLevel; @@ -43,7 +44,6 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashSet; import java.util.List; import java.util.ListIterator; import java.util.Map; @@ -56,7 +56,7 @@ @NoArgsConstructor(access = AccessLevel.PRIVATE) public final class ColumnSegmentBinder { - private static final Collection EXCLUDE_BIND_COLUMNS = new LinkedHashSet<>(Arrays.asList( + private static final Collection EXCLUDE_BIND_COLUMNS = new CaseInsensitiveSet<>(Arrays.asList( "ROWNUM", "ROW_NUMBER", "ROWNUM_", "ROWID", "SYSDATE", "SYSTIMESTAMP", "CURRENT_TIMESTAMP", "LOCALTIMESTAMP", "UID", "USER", "NEXTVAL", "LEVEL")); private static final Map SEGMENT_TYPE_MESSAGES = Maps.of(SegmentType.PROJECTION, "field list", SegmentType.JOIN_ON, "on clause", SegmentType.JOIN_USING, "from clause", @@ -77,7 +77,7 @@ public final class ColumnSegmentBinder { public static ColumnSegment bind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext, final Multimap tableBinderContexts, final Multimap outerTableBinderContexts) { - if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue().toUpperCase())) { + if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue())) { return segment; } ColumnSegment result = copy(segment); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/ProjectionsSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/ProjectionsSegmentBinder.java index 83557beb1f937..9b3c92a48577b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/ProjectionsSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/ProjectionsSegmentBinder.java @@ -25,10 +25,13 @@ import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType; import org.apache.shardingsphere.infra.binder.engine.segment.expression.ExpressionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext; +import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext; import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ColumnProjectionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ShorthandProjectionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.SubqueryProjectionSegmentBinder; import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext; +import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils; +import org.apache.shardingsphere.infra.exception.kernel.metadata.ColumnNotFoundException; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment; @@ -39,8 +42,9 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment; import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment; +import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue; -import java.util.stream.Collectors; +import java.util.Collection; /** * Projections segment binder. @@ -63,16 +67,30 @@ public static ProjectionsSegment bind(final ProjectionsSegment segment, final SQ final Multimap outerTableBinderContexts) { ProjectionsSegment result = new ProjectionsSegment(segment.getStartIndex(), segment.getStopIndex()); result.setDistinctRow(segment.isDistinctRow()); - result.getProjections().addAll(segment.getProjections().stream() - .map(each -> bind(each, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList())); + for (ProjectionSegment each : segment.getProjections()) { + Multimap currentTableBinderContexts = createCurrentTableBinderContexts(binderContext, result.getProjections()); + result.getProjections().add(bind(binderContext, boundTableSegment, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts, each)); + } return result; } + private static ProjectionSegment bind(final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment, + final Multimap currentTableBinderContexts, + final Multimap tableBinderContexts, + final Multimap outerTableBinderContexts, + final ProjectionSegment projectionSegment) { + try { + return bind(projectionSegment, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts); + } catch (final ColumnNotFoundException ignored) { + return bind(projectionSegment, binderContext, boundTableSegment, currentTableBinderContexts, outerTableBinderContexts); + } + } + private static ProjectionSegment bind(final ProjectionSegment projectionSegment, final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment, final Multimap tableBinderContexts, final Multimap outerTableBinderContexts) { if (projectionSegment instanceof ColumnProjectionSegment) { - return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts); + return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts, outerTableBinderContexts); } if (projectionSegment instanceof ShorthandProjectionSegment) { return ShorthandProjectionSegmentBinder.bind((ShorthandProjectionSegment) projectionSegment, boundTableSegment, tableBinderContexts); @@ -125,4 +143,12 @@ private static AggregationProjectionSegment bindAggregationProjection(final Aggr aggregationSegment.getAliasSegment().ifPresent(result::setAlias); return result; } + + private static Multimap createCurrentTableBinderContexts(final SQLStatementBinderContext binderContext, + final Collection projections) { + Multimap result = LinkedHashMultimap.create(); + Collection subqueryProjections = SubqueryTableBindUtils.createSubqueryProjections(projections, new IdentifierValue(""), binderContext.getSqlStatement().getDatabaseType()); + result.put(new CaseInsensitiveString(""), new SimpleTableSegmentBinderContext(subqueryProjections)); + return result; + } } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/type/ColumnProjectionSegmentBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/type/ColumnProjectionSegmentBinder.java index 371105638533c..3dd3cff32cca2 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/type/ColumnProjectionSegmentBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/engine/segment/projection/type/ColumnProjectionSegmentBinder.java @@ -18,7 +18,6 @@ package org.apache.shardingsphere.infra.binder.engine.segment.projection.type; import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString; -import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Multimap; import lombok.AccessLevel; import lombok.NoArgsConstructor; @@ -41,11 +40,13 @@ public final class ColumnProjectionSegmentBinder { * @param segment table segment * @param binderContext SQL statement binder context * @param tableBinderContexts table binder contexts + * @param outerTableBinderContexts outer table binder contexts * @return bound column projection segment */ - public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment, - final SQLStatementBinderContext binderContext, final Multimap tableBinderContexts) { - ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create()); + public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment, final SQLStatementBinderContext binderContext, + final Multimap tableBinderContexts, + final Multimap outerTableBinderContexts) { + ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts); ColumnProjectionSegment result = new ColumnProjectionSegment(boundColumn); segment.getAliasSegment().ifPresent(result::setAlias); result.setVisible(segment.isVisible()); diff --git a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/execute/engine/driver/jdbc/JDBCExecutorCallback.java b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/execute/engine/driver/jdbc/JDBCExecutorCallback.java index f2451cd99be68..df54c8aa51575 100644 --- a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/execute/engine/driver/jdbc/JDBCExecutorCallback.java +++ b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/execute/engine/driver/jdbc/JDBCExecutorCallback.java @@ -29,6 +29,7 @@ import org.apache.shardingsphere.infra.executor.sql.hook.SQLExecutionHook; import org.apache.shardingsphere.infra.executor.sql.process.ProcessEngine; import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData; +import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement; import java.sql.SQLException; @@ -76,12 +77,17 @@ public final Collection execute(final Collection execution */ private T execute(final JDBCExecutionUnit jdbcExecutionUnit, final boolean isTrunkThread, final String processId) throws SQLException { SQLExecutorExceptionHandler.setExceptionThrown(isExceptionThrown); - DatabaseType storageType = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getStorageType(); - ConnectionProperties connectionProps = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getConnectionProperties(); + String dataSourceName = jdbcExecutionUnit.getExecutionUnit().getDataSourceName(); + // TODO use metadata to replace storageUnits to support multiple logic databases + StorageUnit storageUnit = resourceMetaData.getStorageUnits().containsKey(dataSourceName) + ? resourceMetaData.getStorageUnits().get(dataSourceName) + : resourceMetaData.getStorageUnits().values().iterator().next(); + DatabaseType storageType = storageUnit.getStorageType(); + ConnectionProperties connectionProps = storageUnit.getConnectionProperties(); SQLExecutionHook sqlExecutionHook = new SPISQLExecutionHook(); try { SQLUnit sqlUnit = jdbcExecutionUnit.getExecutionUnit().getSqlUnit(); - sqlExecutionHook.start(jdbcExecutionUnit.getExecutionUnit().getDataSourceName(), sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread); + sqlExecutionHook.start(dataSourceName, sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread); T result = executeSQL(sqlUnit.getSql(), jdbcExecutionUnit.getStorageResource(), jdbcExecutionUnit.getConnectionMode(), storageType); sqlExecutionHook.finishSuccess(); processEngine.completeSQLUnitExecution(jdbcExecutionUnit, processId); diff --git a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/prepare/driver/DriverExecutionPrepareEngine.java b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/prepare/driver/DriverExecutionPrepareEngine.java index 17dd4851ea175..463404ec0e312 100644 --- a/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/prepare/driver/DriverExecutionPrepareEngine.java +++ b/infra/executor/src/main/java/org/apache/shardingsphere/infra/executor/sql/prepare/driver/DriverExecutionPrepareEngine.java @@ -102,7 +102,8 @@ protected List> group(final String databaseName, final String @SuppressWarnings("unchecked") private ExecutionGroup createExecutionGroup(final String dataSourceName, final List executionUnits, final C connection, final ConnectionMode connectionMode) throws SQLException { List inputs = new LinkedList<>(); - DatabaseType databaseType = storageUnits.get(dataSourceName).getStorageType(); + // TODO use metadata to replace storageUnits to support multiple logic databases + DatabaseType databaseType = storageUnits.containsKey(dataSourceName) ? storageUnits.get(dataSourceName).getStorageType() : storageUnits.values().iterator().next().getStorageType(); for (ExecutionUnit each : executionUnits) { inputs.add((T) sqlExecutionUnitBuilder.build(each, statementManager, connection, connectionMode, option, databaseType)); } diff --git a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandlerTest.java b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandlerTest.java index 7d3e282e0796e..3fb046a90bdef 100644 --- a/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandlerTest.java +++ b/proxy/frontend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/frontend/mysql/command/query/text/query/MySQLMultiStatementsHandlerTest.java @@ -21,6 +21,7 @@ import org.apache.shardingsphere.infra.database.core.type.DatabaseType; import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode; import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption; +import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData; import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn; @@ -155,6 +156,7 @@ void assertExecuteWithMultiInsertOnDuplicateKey() throws SQLException { private ConnectionSession mockConnectionSession() throws SQLException { ConnectionSession result = mock(ConnectionSession.class, RETURNS_DEEP_STUBS); when(result.getCurrentDatabaseName()).thenReturn("foo_db"); + when(result.getUsedDatabaseName()).thenReturn("foo_db"); Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS); when(connection.getMetaData().getURL()).thenReturn("jdbc:mysql://127.0.0.1/db"); Statement statement = mock(Statement.class); @@ -171,11 +173,12 @@ private ConnectionSession mockConnectionSession() throws SQLException { private ContextManager mockContextManager() { ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS); - when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds")); + ResourceMetaData resourceMetaData = mock(ResourceMetaData.class, RETURNS_DEEP_STUBS); + when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData()).thenReturn(resourceMetaData); + when(resourceMetaData.getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds")); StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS); when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "FIXTURE")); - when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getStorageUnits()) - .thenReturn(Collections.singletonMap("foo_ds", storageUnit)); + when(resourceMetaData.getStorageUnits()).thenReturn(Collections.singletonMap("foo_ds", storageUnit)); when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "MySQL")); when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getRuleMetaData()) .thenReturn(new RuleMetaData(Collections.emptyList())); diff --git a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java index bfe13d152b8d6..88f942dd632c0 100644 --- a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java +++ b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/query/extended/bind/OpenGaussComBatchBindExecutorTest.java @@ -106,6 +106,7 @@ private ConnectionSession mockConnectionSession() throws SQLException { ConnectionSession result = mock(ConnectionSession.class); when(result.getConnectionContext()).thenReturn(new ConnectionContext(Collections::emptySet)); when(result.getCurrentDatabaseName()).thenReturn("foo_db"); + when(result.getUsedDatabaseName()).thenReturn("foo_db"); ConnectionContext connectionContext = mockConnectionContext(); when(result.getConnectionContext()).thenReturn(connectionContext); ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class); diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java index cb813d381c6d9..d2f3445f4bd9c 100644 --- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java +++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLAggregatedBatchedStatementsCommandExecutorTest.java @@ -114,6 +114,7 @@ private ConnectionSession mockConnectionSession() throws SQLException { SQLStatementContext sqlStatementContext = mock(InsertStatementContext.class); when(sqlStatementContext.getSqlStatement()).thenReturn(parserEngine.parse(SQL, false)); when(result.getCurrentDatabaseName()).thenReturn("foo_db"); + when(result.getUsedDatabaseName()).thenReturn("foo_db"); ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet); connectionContext.setCurrentDatabaseName("foo_db"); when(result.getConnectionContext()).thenReturn(connectionContext); diff --git a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java index 81943995dae02..5d7f1aa257b9d 100644 --- a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java +++ b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/query/extended/PostgreSQLBatchedStatementsExecutorTest.java @@ -154,6 +154,7 @@ private ContextManager mockContextManager() { private ConnectionSession mockConnectionSession() { ConnectionSession result = mock(ConnectionSession.class); when(result.getCurrentDatabaseName()).thenReturn("db"); + when(result.getUsedDatabaseName()).thenReturn("db"); when(result.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager); when(result.getStatementManager()).thenReturn(backendStatement); ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet); diff --git a/test/it/binder/src/test/resources/cases/dml/select.xml b/test/it/binder/src/test/resources/cases/dml/select.xml index 8caabdcb4d972..da40c0135de82 100644 --- a/test/it/binder/src/test/resources/cases/dml/select.xml +++ b/test/it/binder/src/test/resources/cases/dml/select.xml @@ -293,4 +293,41 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/it/binder/src/test/resources/sqls/dml/select.xml b/test/it/binder/src/test/resources/sqls/dml/select.xml index cca532f1c5f25..ced2c64e8b23a 100644 --- a/test/it/binder/src/test/resources/sqls/dml/select.xml +++ b/test/it/binder/src/test/resources/sqls/dml/select.xml @@ -20,4 +20,5 @@ +