Skip to content

Commit

Permalink
Support sql bind for select with current select projection reference (#…
Browse files Browse the repository at this point in the history
…34151)

* Support sql bind for select with current select projection reference

* update release note

* fix unit test

* fix unit test

* fix unit test

* fix unit test
  • Loading branch information
strongduanmu authored Dec 26, 2024
1 parent f4afcc3 commit bca6f32
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 18 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126)
1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123)
1. SQL Binder: Support select with statement sql bind and add bind test case - [#34141](https://github.com/apache/shardingsphere/pull/34141)
1. SQL Binder: Support sql bind for select with current select projection reference - [#34151](https://github.com/apache/shardingsphere/pull/34151)

### Bug Fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.binder.engine.segment.expression.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.cedarsoftware.util.CaseInsensitiveSet;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
Expand All @@ -43,7 +44,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
Expand All @@ -56,7 +56,7 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ColumnSegmentBinder {

private static final Collection<String> EXCLUDE_BIND_COLUMNS = new LinkedHashSet<>(Arrays.asList(
private static final Collection<String> EXCLUDE_BIND_COLUMNS = new CaseInsensitiveSet<>(Arrays.asList(
"ROWNUM", "ROW_NUMBER", "ROWNUM_", "ROWID", "SYSDATE", "SYSTIMESTAMP", "CURRENT_TIMESTAMP", "LOCALTIMESTAMP", "UID", "USER", "NEXTVAL", "LEVEL"));

private static final Map<SegmentType, String> SEGMENT_TYPE_MESSAGES = Maps.of(SegmentType.PROJECTION, "field list", SegmentType.JOIN_ON, "on clause", SegmentType.JOIN_USING, "from clause",
Expand All @@ -77,7 +77,7 @@ public final class ColumnSegmentBinder {
public static ColumnSegment bind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue().toUpperCase())) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue())) {
return segment;
}
ColumnSegment result = copy(segment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.expression.ExpressionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ColumnProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ShorthandProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.SubqueryProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
import org.apache.shardingsphere.infra.exception.kernel.metadata.ColumnNotFoundException;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
Expand All @@ -39,8 +42,9 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.stream.Collectors;
import java.util.Collection;

/**
* Projections segment binder.
Expand All @@ -63,16 +67,30 @@ public static ProjectionsSegment bind(final ProjectionsSegment segment, final SQ
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ProjectionsSegment result = new ProjectionsSegment(segment.getStartIndex(), segment.getStopIndex());
result.setDistinctRow(segment.isDistinctRow());
result.getProjections().addAll(segment.getProjections().stream()
.map(each -> bind(each, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList()));
for (ProjectionSegment each : segment.getProjections()) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts = createCurrentTableBinderContexts(binderContext, result.getProjections());
result.getProjections().add(bind(binderContext, boundTableSegment, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts, each));
}
return result;
}

private static ProjectionSegment bind(final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts,
final ProjectionSegment projectionSegment) {
try {
return bind(projectionSegment, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts);
} catch (final ColumnNotFoundException ignored) {
return bind(projectionSegment, binderContext, boundTableSegment, currentTableBinderContexts, outerTableBinderContexts);
}
}

private static ProjectionSegment bind(final ProjectionSegment projectionSegment, final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (projectionSegment instanceof ColumnProjectionSegment) {
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts);
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts, outerTableBinderContexts);
}
if (projectionSegment instanceof ShorthandProjectionSegment) {
return ShorthandProjectionSegmentBinder.bind((ShorthandProjectionSegment) projectionSegment, boundTableSegment, tableBinderContexts);
Expand Down Expand Up @@ -125,4 +143,12 @@ private static AggregationProjectionSegment bindAggregationProjection(final Aggr
aggregationSegment.getAliasSegment().ifPresent(result::setAlias);
return result;
}

private static Multimap<CaseInsensitiveString, TableSegmentBinderContext> createCurrentTableBinderContexts(final SQLStatementBinderContext binderContext,
final Collection<ProjectionSegment> projections) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> result = LinkedHashMultimap.create();
Collection<ProjectionSegment> subqueryProjections = SubqueryTableBindUtils.createSubqueryProjections(projections, new IdentifierValue(""), binderContext.getSqlStatement().getDatabaseType());
result.put(new CaseInsensitiveString(""), new SimpleTableSegmentBinderContext(subqueryProjections));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.shardingsphere.infra.binder.engine.segment.projection.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
Expand All @@ -41,11 +40,13 @@ public final class ColumnProjectionSegmentBinder {
* @param segment table segment
* @param binderContext SQL statement binder context
* @param tableBinderContexts table binder contexts
* @param outerTableBinderContexts outer table binder contexts
* @return bound column projection segment
*/
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment,
final SQLStatementBinderContext binderContext, final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create());
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts);
ColumnProjectionSegment result = new ColumnProjectionSegment(boundColumn);
segment.getAliasSegment().ifPresent(result::setAlias);
result.setVisible(segment.isVisible());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.shardingsphere.infra.executor.sql.hook.SQLExecutionHook;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessEngine;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;

import java.sql.SQLException;
Expand Down Expand Up @@ -76,12 +77,17 @@ public final Collection<T> execute(final Collection<JDBCExecutionUnit> execution
*/
private T execute(final JDBCExecutionUnit jdbcExecutionUnit, final boolean isTrunkThread, final String processId) throws SQLException {
SQLExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
DatabaseType storageType = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getStorageType();
ConnectionProperties connectionProps = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getConnectionProperties();
String dataSourceName = jdbcExecutionUnit.getExecutionUnit().getDataSourceName();
// TODO use metadata to replace storageUnits to support multiple logic databases
StorageUnit storageUnit = resourceMetaData.getStorageUnits().containsKey(dataSourceName)
? resourceMetaData.getStorageUnits().get(dataSourceName)
: resourceMetaData.getStorageUnits().values().iterator().next();
DatabaseType storageType = storageUnit.getStorageType();
ConnectionProperties connectionProps = storageUnit.getConnectionProperties();
SQLExecutionHook sqlExecutionHook = new SPISQLExecutionHook();
try {
SQLUnit sqlUnit = jdbcExecutionUnit.getExecutionUnit().getSqlUnit();
sqlExecutionHook.start(jdbcExecutionUnit.getExecutionUnit().getDataSourceName(), sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
sqlExecutionHook.start(dataSourceName, sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
T result = executeSQL(sqlUnit.getSql(), jdbcExecutionUnit.getStorageResource(), jdbcExecutionUnit.getConnectionMode(), storageType);
sqlExecutionHook.finishSuccess();
processEngine.completeSQLUnitExecution(jdbcExecutionUnit, processId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ protected List<ExecutionGroup<T>> group(final String databaseName, final String
@SuppressWarnings("unchecked")
private ExecutionGroup<T> createExecutionGroup(final String dataSourceName, final List<ExecutionUnit> executionUnits, final C connection, final ConnectionMode connectionMode) throws SQLException {
List<T> inputs = new LinkedList<>();
DatabaseType databaseType = storageUnits.get(dataSourceName).getStorageType();
// TODO use metadata to replace storageUnits to support multiple logic databases
DatabaseType databaseType = storageUnits.containsKey(dataSourceName) ? storageUnits.get(dataSourceName).getStorageType() : storageUnits.values().iterator().next().getStorageType();
for (ExecutionUnit each : executionUnits) {
inputs.add((T) sqlExecutionUnitBuilder.build(each, statementManager, connection, connectionMode, option, databaseType));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
Expand Down Expand Up @@ -155,6 +156,7 @@ void assertExecuteWithMultiInsertOnDuplicateKey() throws SQLException {
private ConnectionSession mockConnectionSession() throws SQLException {
ConnectionSession result = mock(ConnectionSession.class, RETURNS_DEEP_STUBS);
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
when(connection.getMetaData().getURL()).thenReturn("jdbc:mysql://127.0.0.1/db");
Statement statement = mock(Statement.class);
Expand All @@ -171,11 +173,12 @@ private ConnectionSession mockConnectionSession() throws SQLException {

private ContextManager mockContextManager() {
ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds"));
ResourceMetaData resourceMetaData = mock(ResourceMetaData.class, RETURNS_DEEP_STUBS);
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData()).thenReturn(resourceMetaData);
when(resourceMetaData.getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds"));
StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS);
when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "FIXTURE"));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getStorageUnits())
.thenReturn(Collections.singletonMap("foo_ds", storageUnit));
when(resourceMetaData.getStorageUnits()).thenReturn(Collections.singletonMap("foo_ds", storageUnit));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getRuleMetaData())
.thenReturn(new RuleMetaData(Collections.emptyList()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ private ConnectionSession mockConnectionSession() throws SQLException {
ConnectionSession result = mock(ConnectionSession.class);
when(result.getConnectionContext()).thenReturn(new ConnectionContext(Collections::emptySet));
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
ConnectionContext connectionContext = mockConnectionContext();
when(result.getConnectionContext()).thenReturn(connectionContext);
ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ private ConnectionSession mockConnectionSession() throws SQLException {
SQLStatementContext sqlStatementContext = mock(InsertStatementContext.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(parserEngine.parse(SQL, false));
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet);
connectionContext.setCurrentDatabaseName("foo_db");
when(result.getConnectionContext()).thenReturn(connectionContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ private ContextManager mockContextManager() {
private ConnectionSession mockConnectionSession() {
ConnectionSession result = mock(ConnectionSession.class);
when(result.getCurrentDatabaseName()).thenReturn("db");
when(result.getUsedDatabaseName()).thenReturn("db");
when(result.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
when(result.getStatementManager()).thenReturn(backendStatement);
ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet);
Expand Down
Loading

0 comments on commit bca6f32

Please sign in to comment.