Skip to content

Commit

Permalink
Support sql bind for select with current select projection reference
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Dec 25, 2024
1 parent f4afcc3 commit c5d644b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.binder.engine.segment.expression.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.cedarsoftware.util.CaseInsensitiveSet;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
Expand All @@ -43,7 +44,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
Expand All @@ -56,7 +56,7 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class ColumnSegmentBinder {

private static final Collection<String> EXCLUDE_BIND_COLUMNS = new LinkedHashSet<>(Arrays.asList(
private static final Collection<String> EXCLUDE_BIND_COLUMNS = new CaseInsensitiveSet<>(Arrays.asList(
"ROWNUM", "ROW_NUMBER", "ROWNUM_", "ROWID", "SYSDATE", "SYSTIMESTAMP", "CURRENT_TIMESTAMP", "LOCALTIMESTAMP", "UID", "USER", "NEXTVAL", "LEVEL"));

private static final Map<SegmentType, String> SEGMENT_TYPE_MESSAGES = Maps.of(SegmentType.PROJECTION, "field list", SegmentType.JOIN_ON, "on clause", SegmentType.JOIN_USING, "from clause",
Expand All @@ -77,7 +77,7 @@ public final class ColumnSegmentBinder {
public static ColumnSegment bind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue().toUpperCase())) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue())) {
return segment;
}
ColumnSegment result = copy(segment);
Expand Down Expand Up @@ -116,7 +116,7 @@ private static Collection<TableSegmentBinderContext> getTableSegmentBinderContex
if (!binderContext.getJoinTableProjectionSegments().isEmpty() && isNeedUseJoinTableProjectionBind(segment, parentSegmentType, binderContext)) {
return Collections.singleton(new SimpleTableSegmentBinderContext(binderContext.getJoinTableProjectionSegments()));
}
return tableBinderContexts.values();
return tableBinderContexts.values().isEmpty() ? outerTableBinderContexts.values() : tableBinderContexts.values();
}

private static Collection<TableSegmentBinderContext> getTableBinderContextByOwner(final String owner, final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
import org.apache.shardingsphere.infra.binder.engine.segment.expression.ExpressionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ColumnProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.ShorthandProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.type.SubqueryProjectionSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
import org.apache.shardingsphere.infra.exception.kernel.metadata.ColumnNotFoundException;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationDistinctProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
Expand All @@ -39,8 +42,9 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.stream.Collectors;
import java.util.Collection;

/**
* Projections segment binder.
Expand All @@ -63,16 +67,30 @@ public static ProjectionsSegment bind(final ProjectionsSegment segment, final SQ
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ProjectionsSegment result = new ProjectionsSegment(segment.getStartIndex(), segment.getStopIndex());
result.setDistinctRow(segment.isDistinctRow());
result.getProjections().addAll(segment.getProjections().stream()
.map(each -> bind(each, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList()));
for (ProjectionSegment each : segment.getProjections()) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts = createCurrentTableBinderContexts(binderContext, result.getProjections());
result.getProjections().add(bind(binderContext, boundTableSegment, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts, each));
}
return result;
}

private static ProjectionSegment bind(final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> currentTableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts,
final ProjectionSegment projectionSegment) {
try {
return bind(projectionSegment, binderContext, boundTableSegment, tableBinderContexts, outerTableBinderContexts);
} catch (final ColumnNotFoundException ignored) {
return bind(projectionSegment, binderContext, boundTableSegment, currentTableBinderContexts, outerTableBinderContexts);
}
}

private static ProjectionSegment bind(final ProjectionSegment projectionSegment, final SQLStatementBinderContext binderContext, final TableSegment boundTableSegment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (projectionSegment instanceof ColumnProjectionSegment) {
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts);
return ColumnProjectionSegmentBinder.bind((ColumnProjectionSegment) projectionSegment, binderContext, tableBinderContexts, outerTableBinderContexts);
}
if (projectionSegment instanceof ShorthandProjectionSegment) {
return ShorthandProjectionSegmentBinder.bind((ShorthandProjectionSegment) projectionSegment, boundTableSegment, tableBinderContexts);
Expand Down Expand Up @@ -125,4 +143,12 @@ private static AggregationProjectionSegment bindAggregationProjection(final Aggr
aggregationSegment.getAliasSegment().ifPresent(result::setAlias);
return result;
}

private static Multimap<CaseInsensitiveString, TableSegmentBinderContext> createCurrentTableBinderContexts(final SQLStatementBinderContext binderContext,
final Collection<ProjectionSegment> projections) {
Multimap<CaseInsensitiveString, TableSegmentBinderContext> result = LinkedHashMultimap.create();
Collection<ProjectionSegment> subqueryProjections = SubqueryTableBindUtils.createSubqueryProjections(projections, new IdentifierValue(""), binderContext.getSqlStatement().getDatabaseType());
result.put(new CaseInsensitiveString(""), new SimpleTableSegmentBinderContext(subqueryProjections));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.shardingsphere.infra.binder.engine.segment.projection.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
Expand All @@ -41,11 +40,13 @@ public final class ColumnProjectionSegmentBinder {
* @param segment table segment
* @param binderContext SQL statement binder context
* @param tableBinderContexts table binder contexts
* @param outerTableBinderContexts outer table binder contexts
* @return bound column projection segment
*/
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment,
final SQLStatementBinderContext binderContext, final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, LinkedHashMultimap.create());
public static ColumnProjectionSegment bind(final ColumnProjectionSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ColumnSegment boundColumn = ColumnSegmentBinder.bind(segment.getColumn(), SegmentType.PROJECTION, binderContext, tableBinderContexts, outerTableBinderContexts);
ColumnProjectionSegment result = new ColumnProjectionSegment(boundColumn);
segment.getAliasSegment().ifPresent(result::setAlias);
result.setVisible(segment.isVisible());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.shardingsphere.infra.executor.sql.hook.SQLExecutionHook;
import org.apache.shardingsphere.infra.executor.sql.process.ProcessEngine;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;

import java.sql.SQLException;
Expand Down Expand Up @@ -76,12 +77,17 @@ public final Collection<T> execute(final Collection<JDBCExecutionUnit> execution
*/
private T execute(final JDBCExecutionUnit jdbcExecutionUnit, final boolean isTrunkThread, final String processId) throws SQLException {
SQLExecutorExceptionHandler.setExceptionThrown(isExceptionThrown);
DatabaseType storageType = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getStorageType();
ConnectionProperties connectionProps = resourceMetaData.getStorageUnits().get(jdbcExecutionUnit.getExecutionUnit().getDataSourceName()).getConnectionProperties();
String dataSourceName = jdbcExecutionUnit.getExecutionUnit().getDataSourceName();
// TODO use metadata to replace storageUnits to support multiple logic databases
StorageUnit storageUnit = resourceMetaData.getStorageUnits().containsKey(dataSourceName)
? resourceMetaData.getStorageUnits().get(dataSourceName)
: resourceMetaData.getStorageUnits().values().iterator().next();
DatabaseType storageType = storageUnit.getStorageType();
ConnectionProperties connectionProps = storageUnit.getConnectionProperties();
SQLExecutionHook sqlExecutionHook = new SPISQLExecutionHook();
try {
SQLUnit sqlUnit = jdbcExecutionUnit.getExecutionUnit().getSqlUnit();
sqlExecutionHook.start(jdbcExecutionUnit.getExecutionUnit().getDataSourceName(), sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
sqlExecutionHook.start(dataSourceName, sqlUnit.getSql(), sqlUnit.getParameters(), connectionProps, isTrunkThread);
T result = executeSQL(sqlUnit.getSql(), jdbcExecutionUnit.getStorageResource(), jdbcExecutionUnit.getConnectionMode(), storageType);
sqlExecutionHook.finishSuccess();
processEngine.completeSQLUnitExecution(jdbcExecutionUnit, processId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ protected List<ExecutionGroup<T>> group(final String databaseName, final String
@SuppressWarnings("unchecked")
private ExecutionGroup<T> createExecutionGroup(final String dataSourceName, final List<ExecutionUnit> executionUnits, final C connection, final ConnectionMode connectionMode) throws SQLException {
List<T> inputs = new LinkedList<>();
DatabaseType databaseType = storageUnits.get(dataSourceName).getStorageType();
// TODO use metadata to replace storageUnits to support multiple logic databases
DatabaseType databaseType = storageUnits.containsKey(dataSourceName) ? storageUnits.get(dataSourceName).getStorageType() : storageUnits.values().iterator().next().getStorageType();
for (ExecutionUnit each : executionUnits) {
inputs.add((T) sqlExecutionUnitBuilder.build(each, statementManager, connection, connectionMode, option, databaseType));
}
Expand Down
37 changes: 37 additions & 0 deletions test/it/binder/src/test/resources/cases/dml/select.xml
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,41 @@
</simple-table>
</from>
</select>

<select sql-case-id="select_with_current_select_projection_reference">
<projections start-index="7" stop-index="58">
<column-projection name="order_id" start-index="7" stop-index="25" alias="orderId">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="order_id" start-delimiter="`" end-delimiter="`" />
</column-bound>
</column-projection>
<subquery-projection start-index="28" stop-index="43" alias="tempOrderId" text="(SELECT orderId)">
<subquery>
<select>
<projections start-index="36" stop-index="42">
<column-projection name="orderId" start-index="36" stop-index="42">
<column-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
<original-table name="t_order" />
<original-column name="order_id" start-delimiter="`" end-delimiter="`" />
</column-bound>
</column-projection>
</projections>
</select>
</subquery>
</subquery-projection>
</projections>
<from start-index="60" stop-index="71">
<simple-table name="t_order" start-index="65" stop-index="71">
<table-bound>
<original-database name="foo_db_1" />
<original-schema name="foo_db_1" />
</table-bound>
</simple-table>
</from>
</select>
</sql-parser-test-cases>
1 change: 1 addition & 0 deletions test/it/binder/src/test/resources/sqls/dml/select.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
<sql-case id="select_with_shorthand_projection" value="SELECT * FROM t_order o" db-types="MySQL"/>
<sql-case id="select_with_group_by_order_by" value="SELECT order_id, COUNT(1) count FROM t_order o GROUP BY order_id HAVING count > 1 ORDER BY order_id" db-types="MySQL"/>
<sql-case id="select_with_with_clause" value="WITH t_order_tmp AS (SELECT * FROM t_order o) SELECT * FROM t_order_tmp" db-types="MySQL"/>
<sql-case id="select_with_current_select_projection_reference" value="SELECT order_id AS orderId, (SELECT orderId) AS tempOrderId FROM t_order" db-types="MySQL"/>
</sql-cases>

0 comments on commit c5d644b

Please sign in to comment.