Support sql bind for select with current select projection reference #34151

Merged: 6 commits, Dec 26, 2024
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -39,6 +39,7 @@
1. SQL Parser: Support MySQL update with statement parse - [#34126](https://github.com/apache/shardingsphere/pull/34126)
1. SQL Binder: Remove TablesContext#findTableNames method and implement select order by, group by bind logic - [#34123](https://github.com/apache/shardingsphere/pull/34123)
1. SQL Binder: Support select with statement sql bind and add bind test case - [#34141](https://github.com/apache/shardingsphere/pull/34141)
1. SQL Binder: Support sql bind for select with current select projection reference - [#34151](https://github.com/apache/shardingsphere/pull/34151)

### Bug Fixes

ColumnSegmentBinder.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.binder.engine.segment.expression.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.cedarsoftware.util.CaseInsensitiveSet;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
@@ -43,7 +44,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
@@ -56,7 +56,7 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ColumnSegmentBinder {

private static final Collection<String> EXCLUDE_BIND_COLUMNS = new LinkedHashSet<>(Arrays.asList(
private static final Collection<String> EXCLUDE_BIND_COLUMNS = new CaseInsensitiveSet<>(Arrays.asList(
"ROWNUM", "ROW_NUMBER", "ROWNUM_", "ROWID", "SYSDATE", "SYSTIMESTAMP", "CURRENT_TIMESTAMP", "LOCALTIMESTAMP", "UID", "USER", "NEXTVAL", "LEVEL"));

private static final Map<SegmentType, String> SEGMENT_TYPE_MESSAGES = Maps.of(SegmentType.PROJECTION, "field list", SegmentType.JOIN_ON, "on clause", SegmentType.JOIN_USING, "from clause",
@@ -77,7 +77,7 @@ public final class ColumnSegmentBinder {
public static ColumnSegment bind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue().toUpperCase())) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue())) {
return segment;
}
ColumnSegment result = copy(segment);
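Replacing the LinkedHashSet plus the explicit toUpperCase() call with CaseInsensitiveSet moves the case handling into the collection itself, so the binder can check identifiers as they arrive. A minimal sketch of that lookup behaviour, assuming the same com.cedarsoftware java-util dependency the new import refers to; the demo class and column values are hypothetical:

import com.cedarsoftware.util.CaseInsensitiveSet;

import java.util.Arrays;
import java.util.Collection;

public final class ExcludeBindColumnsDemo {

    // Same idea as EXCLUDE_BIND_COLUMNS above: membership checks ignore case for String elements.
    private static final Collection<String> EXCLUDED = new CaseInsensitiveSet<>(Arrays.asList("ROWNUM", "SYSDATE", "LEVEL"));

    public static void main(final String[] args) {
        // No toUpperCase() needed before the lookup.
        System.out.println(EXCLUDED.contains("rownum"));   // true
        System.out.println(EXCLUDED.contains("SysDate"));  // true
        System.out.println(EXCLUDED.contains("order_id")); // false
    }
}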
ProjectionsSegmentBinder.java
@@ -25,10 +25,13 @@
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.expression.ExpressionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ColumnProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ShorthandProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.SubqueryProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
import org.apache.shardingsphere.infra.exception.kernel.metadata.ColumnNotFoundException;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
@@ -39,8 +42,9 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.stream.Collectors;
import java.util.Collection;

/**
* Projections segment binder.
@@ -63,16 +67,30 @@ public static ProjectionsSegment bind(final ProjectionsSegment segment, final SQ
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ProjectionsSegment result = new ProjectionsSegment(segment.getStartIndex(), segment.getStopIndex());
result.setDistinctRow(segment.isDistinctRow());
result.getProjections().addAll(segment.getProjections().stream()
.map(each -> bind(each, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList()));
for (ProjectionSegment each : segment.getProjections()) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts = createCurrentTableBinderContexts(binderContext, result.getProjections());
result.getProjections().add(bind(binderContext, boundTableSegment, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts, each));
}
return result;
}

private static ProjectionSegment bind(final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts,
final ProjectionSegment projectionSegment) {
try {
return bind(projectionSegment, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts);
} catch (final ColumnNotFoundException ignored) {
return bind(projectionSegment, binderContext, boundTableSegment, currentTableBinderContexts, outerTableBinderContexts);
}
}

private static ProjectionSegment bind(final ProjectionSegment projectionSegment, final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (projectionSegment instanceof ColumnProjectionSegment) {
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts);
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts, outerTableBinderContexts);
}
if (projectionSegment instanceof ShorthandProjectionSegment) {
return ShorthandProjectionSegmentBinder.bind((ShorthandProjectionSegment) projectionSegment, boundTableSegment, tableBinderContexts);
@@ -125,4 +143,12 @@ private static AggregationProjectionSegment bindAggregationProjection(final Aggr
aggregationSegment.getAliasSegment().ifPresent(result::setAlias);
return result;
}

private static Multimap<CaseInsensitiveString, TableSegmentBinderContext> createCurrentTableBinderContexts(final SQLStatementBinderContext binderContext,
final Collection<ProjectionSegment> projections) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> result = LinkedHashMultimap.create();
Collection<ProjectionSegment> subqueryProjections = SubqueryTableBindUtils.createSubqueryProjections(projections, new IdentifierValue(""), binderContext.getSqlStatement().getDatabaseType());
result.put(new CaseInsensitiveString(""), new SimpleTableSegmentBinderContext(subqueryProjections));
return result;
}
}
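Read together, the new projection loop, the try/catch overload, and createCurrentTableBinderContexts implement the behaviour named in the PR title: each projection is bound against the from-clause tables first, and only when that raises ColumnNotFoundException is it re-bound against the projections already bound in the current select. That covers select lists that reference an earlier projection of the same statement, hypothetically something like SELECT order_id AS oid, oid + 1 FROM t_order. A simplified, self-contained sketch of the fallback pattern, with hypothetical names standing in for the real binder types:

import java.util.LinkedHashMap;
import java.util.Map;

public final class CurrentProjectionFallbackDemo {

    // Stand-in for the real ColumnNotFoundException caught in the overload above.
    static final class ColumnNotFoundException extends RuntimeException {

        ColumnNotFoundException(final String column) {
            super("Unknown column '" + column + "'");
        }
    }

    // Columns visible from the FROM clause, mapped to their bound (qualified) names.
    private final Map<String, String> tableColumns = new LinkedHashMap<>();

    // Projections already bound in the current SELECT, keyed by alias.
    private final Map<String, String> currentProjections = new LinkedHashMap<>();

    private String bindAgainstTables(final String column) {
        if (!tableColumns.containsKey(column)) {
            throw new ColumnNotFoundException(column);
        }
        return tableColumns.get(column);
    }

    private String bindAgainstCurrentProjections(final String column) {
        if (!currentProjections.containsKey(column)) {
            throw new ColumnNotFoundException(column);
        }
        return currentProjections.get(column);
    }

    // Try the from-clause scope first; on failure, fall back to the projections bound so far.
    String bind(final String aliasOrColumn) {
        try {
            return bindAgainstTables(aliasOrColumn);
        } catch (final ColumnNotFoundException ignored) {
            return bindAgainstCurrentProjections(aliasOrColumn);
        }
    }

    public static void main(final String[] args) {
        CurrentProjectionFallbackDemo demo = new CurrentProjectionFallbackDemo();
        demo.tableColumns.put("order_id", "t_order.order_id");
        // Simulate binding "SELECT order_id AS oid, oid + 1 FROM t_order" projection by projection.
        demo.currentProjections.put("oid", demo.bind("order_id"));
        System.out.println(demo.bind("oid")); // resolved via the current-select fallback
    }
}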
ColumnProjectionSegmentBinder.java
@@ -18,7 +18,6 @@
package org.apache.shardingsphere.infra.binder.engine.segment.projection.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
@@ -41,11 +40,13 @@ public final class ColumnProjectionSegmentBinder {
* @param segment table segment
* @param binderContext SQL statement binder context
* @param tableBinderContexts table binder contexts
* @param outerTableBinderContexts outer table binder contexts
* @return bound column projection segment
*/
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment,
final SQLStatementBinderContext binderContext, final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create());
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts);
ColumnProjectionSegment result = new ColumnProjectionSegment(boundColumn);
segment.getAliasSegment().ifPresent(result::setAlias);
result.setVisible(segment.isVisible());
@@ -29,6 +29,7 @@
import org.apache.shardingsphere.infra.executor.sql.hook.SQLExecutionHook;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessEngine;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;

import java.sql.SQLException;
@@ -76,12 +77,17 @@ public final Collection<T> execute(final Collection<JDBCExecutionUnit> execution
*/
private T execute(final JDBCExecutionUnit jdbcExecutionUnit, final boolean isTrunkThread, final String processId) throws SQLException {
SQLExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
DatabaseType storageType = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getStorageType();
ConnectionProperties connectionProps = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getConnectionProperties();
String dataSourceName = jdbcExecutionUnit.getExecutionUnit().getDataSourceName();
// TODO use metadata to replace storageUnits to support multiple logic databases
StorageUnit storageUnit = resourceMetaData.getStorageUnits().containsKey(dataSourceName)
? resourceMetaData.getStorageUnits().get(dataSourceName)
: resourceMetaData.getStorageUnits().values().iterator().next();
DatabaseType storageType = storageUnit.getStorageType();
ConnectionProperties connectionProps = storageUnit.getConnectionProperties();
SQLExecutionHook sqlExecutionHook = new SPISQLExecutionHook();
try {
SQLUnit sqlUnit = jdbcExecutionUnit.getExecutionUnit().getSqlUnit();
sqlExecutionHook.start(jdbcExecutionUnit.getExecutionUnit().getDataSourceName(), sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
sqlExecutionHook.start(dataSourceName, sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
T result = executeSQL(sqlUnit.getSql(), jdbcExecutionUnit.getStorageResource(), jdbcExecutionUnit.getConnectionMode(), storageType);
sqlExecutionHook.finishSuccess();
processEngine.completeSQLUnitExecution(jdbcExecutionUnit, processId);
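This lookup and the similar one in the createExecutionGroup change below follow the same shape: resolve the storage unit by its data source name, and when the current logic database does not register that name, fall back to the first registered storage unit until the TODO about metadata-based lookup is addressed. A minimal sketch of that lookup, using a plain map of storage type names as a stand-in for ResourceMetaData#getStorageUnits():

import java.util.LinkedHashMap;
import java.util.Map;

public final class StorageUnitLookupDemo {

    // Resolve the storage type for a data source name, falling back to the first registered unit when the name is unknown.
    static String resolveStorageType(final Map<String, String> storageUnits, final String dataSourceName) {
        return storageUnits.containsKey(dataSourceName)
                ? storageUnits.get(dataSourceName)
                : storageUnits.values().iterator().next();
    }

    public static void main(final String[] args) {
        Map<String, String> storageUnits = new LinkedHashMap<>();
        storageUnits.put("foo_ds", "MySQL");
        System.out.println(resolveStorageType(storageUnits, "foo_ds")); // direct hit
        System.out.println(resolveStorageType(storageUnits, "bar_ds")); // falls back to the first registered unit
    }
}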
Expand Up @@ -102,7 +102,8 @@ protected List<ExecutionGroup<T>> group(final String databaseName, final String
@SuppressWarnings("unchecked")
private ExecutionGroup<T> createExecutionGroup(final String dataSourceName, final List<ExecutionUnit> executionUnits, final C connection, final ConnectionMode connectionMode) throws SQLException {
List<T> inputs = new LinkedList<>();
DatabaseType databaseType = storageUnits.get(dataSourceName).getStorageType();
// TODO use metadata to replace storageUnits to support multiple logic databases
DatabaseType databaseType = storageUnits.containsKey(dataSourceName) ? storageUnits.get(dataSourceName).getStorageType() : storageUnits.values().iterator().next().getStorageType();
for (ExecutionUnit each : executionUnits) {
inputs.add((T) sqlExecutionUnitBuilder.build(each, statementManager, connection, connectionMode, option, databaseType));
}
@@ -21,6 +21,7 @@
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.prepare.driver.jdbc.StatementOption;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
@@ -155,6 +156,7 @@ void assertExecuteWithMultiInsertOnDuplicateKey() throws SQLException {
private ConnectionSession mockConnectionSession() throws SQLException {
ConnectionSession result = mock(ConnectionSession.class, RETURNS_DEEP_STUBS);
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
Connection connection = mock(Connection.class, RETURNS_DEEP_STUBS);
when(connection.getMetaData().getURL()).thenReturn("jdbc:mysql://127.0.0.1/db");
Statement statement = mock(Statement.class);
@@ -171,11 +173,12 @@ private ConnectionSession mockConnectionSession() throws SQLException {

private ContextManager mockContextManager() {
ContextManager result = mock(ContextManager.class, RETURNS_DEEP_STUBS);
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds"));
ResourceMetaData resourceMetaData = mock(ResourceMetaData.class, RETURNS_DEEP_STUBS);
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData()).thenReturn(resourceMetaData);
when(resourceMetaData.getAllInstanceDataSourceNames()).thenReturn(Collections.singletonList("foo_ds"));
StorageUnit storageUnit = mock(StorageUnit.class, RETURNS_DEEP_STUBS);
when(storageUnit.getStorageType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "FIXTURE"));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getResourceMetaData().getStorageUnits())
.thenReturn(Collections.singletonMap("foo_ds", storageUnit));
when(resourceMetaData.getStorageUnits()).thenReturn(Collections.singletonMap("foo_ds", storageUnit));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getProtocolType()).thenReturn(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
when(result.getMetaDataContexts().getMetaData().getDatabase("foo_db").getRuleMetaData())
.thenReturn(new RuleMetaData(Collections.emptyList()));
@@ -106,6 +106,7 @@ private ConnectionSession mockConnectionSession() throws SQLException {
ConnectionSession result = mock(ConnectionSession.class);
when(result.getConnectionContext()).thenReturn(new ConnectionContext(Collections::emptySet));
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
ConnectionContext connectionContext = mockConnectionContext();
when(result.getConnectionContext()).thenReturn(connectionContext);
ProxyDatabaseConnectionManager databaseConnectionManager = mock(ProxyDatabaseConnectionManager.class);
@@ -114,6 +114,7 @@ private ConnectionSession mockConnectionSession() throws SQLException {
SQLStatementContext sqlStatementContext = mock(InsertStatementContext.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(parserEngine.parse(SQL, false));
when(result.getCurrentDatabaseName()).thenReturn("foo_db");
when(result.getUsedDatabaseName()).thenReturn("foo_db");
ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet);
connectionContext.setCurrentDatabaseName("foo_db");
when(result.getConnectionContext()).thenReturn(connectionContext);
@@ -154,6 +154,7 @@ private ContextManager mockContextManager() {
private ConnectionSession mockConnectionSession() {
ConnectionSession result = mock(ConnectionSession.class);
when(result.getCurrentDatabaseName()).thenReturn("db");
when(result.getUsedDatabaseName()).thenReturn("db");
when(result.getDatabaseConnectionManager()).thenReturn(databaseConnectionManager);
when(result.getStatementManager()).thenReturn(backendStatement);
ConnectionContext connectionContext = new ConnectionContext(Collections::emptySet);