Skip to content

Commit

Permalink
Support select with statement sql bind and add bind test case
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu committed Dec 24, 2024
1 parent 5a96091 commit ff550f8
Show file tree
Hide file tree
Showing 10 changed files with 414 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

package org.apache.shardingsphere.encrypt.checker.sql;

import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptInsertSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.orderby.EncryptOrderByItemSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.predicate.EncryptPredicateColumnSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptInsertSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.checker.sql.projection.EncryptSelectProjectionSupportedChecker;
import org.apache.shardingsphere.encrypt.constant.EncryptOrder;
import org.apache.shardingsphere.encrypt.rule.EncryptRule;
import org.apache.shardingsphere.infra.checker.SupportedSQLChecker;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
import org.apache.shardingsphere.infra.database.core.metadata.database.DialectDatabaseMetaData;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
Expand Down Expand Up @@ -150,6 +151,12 @@ private static SimpleTableSegmentBinderContext createSimpleTableBinderContext(fi
if (binderContext.getSqlStatement() instanceof CreateTableStatement) {
return new SimpleTableSegmentBinderContext(createProjectionSegments((CreateTableStatement) binderContext.getSqlStatement(), databaseName, schemaName, tableName));
}
CaseInsensitiveString caseInsensitiveTableName = new CaseInsensitiveString(tableName.getValue());
if (binderContext.getExternalTableBinderContexts().containsKey(caseInsensitiveTableName)) {
TableSegmentBinderContext tableSegmentBinderContext = binderContext.getExternalTableBinderContexts().get(caseInsensitiveTableName).iterator().next();
return new SimpleTableSegmentBinderContext(
SubqueryTableBindUtils.createSubqueryProjections(tableSegmentBinderContext.getProjectionSegments(), tableName, binderContext.getSqlStatement().getDatabaseType()));
}
return new SimpleTableSegmentBinderContext(Collections.emptyList());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.with;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.type.SubqueryTableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;

import java.util.stream.Collectors;

/**
* Common table expression segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class CommonTableExpressionSegmentBinder {

/**
* Bind common table expression segment.
*
* @param segment common table expression segment
* @param binderContext SQL statement binder context
* @param recursive recursive
* @return bound common table expression segment
*/
public static CommonTableExpressionSegment bind(final CommonTableExpressionSegment segment, final SQLStatementBinderContext binderContext, final boolean recursive) {
if (recursive && segment.getAliasName().isPresent()) {
binderContext.getExternalTableBinderContexts().put(new CaseInsensitiveString(segment.getAliasName().get()),
new SimpleTableSegmentBinderContext(segment.getColumns().stream().map(ColumnProjectionSegment::new).collect(Collectors.toList())));
}
SubqueryTableSegment subqueryTableSegment = new SubqueryTableSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getSubquery());
subqueryTableSegment.setAlias(segment.getAliasSegment());
SubqueryTableSegment boundSubquerySegment =
SubqueryTableSegmentBinder.bind(subqueryTableSegment, binderContext, LinkedHashMultimap.create(), binderContext.getExternalTableBinderContexts());
CommonTableExpressionSegment result = new CommonTableExpressionSegment(
segment.getStartIndex(), segment.getStopIndex(), boundSubquerySegment.getAliasSegment().orElse(null), boundSubquerySegment.getSubquery());
// TODO bind with columns
result.getColumns().addAll(segment.getColumns());
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.infra.binder.engine.segment.with;

import com.cedarsoftware.util.CaseInsensitiveMap;
import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.type.SimpleTableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.WithSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.ColumnSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.bound.TableSegmentBoundInfo;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
* With segment binder.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class WithSegmentBinder {

/**
* Bind with segment.
*
* @param segment with segment
* @param binderContext SQL statement binder context
* @param externalTableBinderContexts external table binder contexts
* @return bound with segment
*/
public static WithSegment bind(final WithSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> externalTableBinderContexts) {
Collection<CommonTableExpressionSegment> boundCommonTableExpressions = new LinkedList<>();
for (CommonTableExpressionSegment each : segment.getCommonTableExpressions()) {
CommonTableExpressionSegment boundCommonTableExpression = CommonTableExpressionSegmentBinder.bind(each, binderContext, segment.isRecursive());
boundCommonTableExpressions.add(boundCommonTableExpression);
if (segment.isRecursive() && each.getAliasName().isPresent()) {
externalTableBinderContexts.removeAll(new CaseInsensitiveString(each.getAliasName().get()));
}
bindWithColumns(each.getColumns(), boundCommonTableExpression);
each.getAliasName().ifPresent(optional -> externalTableBinderContexts.put(new CaseInsensitiveString(optional), createWithTableBinderContext(boundCommonTableExpression)));
}
return new WithSegment(segment.getStartIndex(), segment.getStopIndex(), boundCommonTableExpressions);
}

private static SimpleTableSegmentBinderContext createWithTableBinderContext(final CommonTableExpressionSegment commonTableExpressionSegment) {
return new SimpleTableSegmentBinderContext(commonTableExpressionSegment.getSubquery().getSelect().getProjections().getProjections());
}

private static void bindWithColumns(final Collection<ColumnSegment> columns, final CommonTableExpressionSegment boundCommonTableExpression) {
if (columns.isEmpty()) {
return;
}
Map<String, ColumnProjectionSegment> columnProjections = extractWithSubqueryColumnProjections(boundCommonTableExpression);
columns.forEach(each -> {
ColumnProjectionSegment projectionSegment = columnProjections.get(each.getIdentifier().getValue());
if (null != projectionSegment) {
each.setColumnBoundInfo(createColumnSegmentBoundInfo(each, projectionSegment.getColumn()));
}
});
}

private static Map<String, ColumnProjectionSegment> extractWithSubqueryColumnProjections(final CommonTableExpressionSegment boundCommonTableExpression) {
Map<String, ColumnProjectionSegment> result = new CaseInsensitiveMap<>();
Collection<ProjectionSegment> projections = boundCommonTableExpression.getSubquery().getSelect().getProjections().getProjections();
projections.forEach(each -> extractWithSubqueryColumnProjections(each, result));
return result;
}

private static void extractWithSubqueryColumnProjections(final ProjectionSegment projectionSegment, final Map<String, ColumnProjectionSegment> result) {
if (projectionSegment instanceof ColumnProjectionSegment) {
result.put(getColumnName((ColumnProjectionSegment) projectionSegment), (ColumnProjectionSegment) projectionSegment);
}
if (projectionSegment instanceof ShorthandProjectionSegment) {
((ShorthandProjectionSegment) projectionSegment).getActualProjectionSegments().forEach(eachProjection -> {
if (eachProjection instanceof ColumnProjectionSegment) {
result.put(getColumnName((ColumnProjectionSegment) eachProjection), (ColumnProjectionSegment) eachProjection);
}
});
}
}

private static String getColumnName(final ColumnProjectionSegment columnProjection) {
return columnProjection.getAliasName().orElse(columnProjection.getColumn().getIdentifier().getValue());
}

private static ColumnSegmentBoundInfo createColumnSegmentBoundInfo(final ColumnSegment segment, final ColumnSegment inputColumnSegment) {
IdentifierValue originalDatabase = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundInfo().getOriginalDatabase();
IdentifierValue originalSchema = null == inputColumnSegment ? null : inputColumnSegment.getColumnBoundInfo().getOriginalSchema();
IdentifierValue segmentOriginalTable = segment.getColumnBoundInfo().getOriginalTable();
IdentifierValue originalTable = Strings.isNullOrEmpty(segmentOriginalTable.getValue())
? Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundInfo().getOriginalTable()).orElse(segmentOriginalTable)
: segmentOriginalTable;
IdentifierValue segmentOriginalColumn = segment.getColumnBoundInfo().getOriginalColumn();
IdentifierValue originalColumn = Optional.ofNullable(inputColumnSegment).map(optional -> optional.getColumnBoundInfo().getOriginalColumn()).orElse(segmentOriginalColumn);
return new ColumnSegmentBoundInfo(new TableSegmentBoundInfo(originalDatabase, originalSchema), originalTable, originalColumn);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.HavingSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.WhereSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.projection.ProjectionsSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.with.WithSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.engine.util.SubqueryTableBindUtils;
Expand Down Expand Up @@ -59,6 +60,7 @@ public SelectStatementBinder() {
public SelectStatement bind(final SelectStatement sqlStatement, final SQLStatementBinderContext binderContext) {
SelectStatement result = copy(sqlStatement);
Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts = LinkedHashMultimap.create();
sqlStatement.getWithSegment().ifPresent(optional -> result.setWithSegment(WithSegmentBinder.bind(optional, binderContext, binderContext.getExternalTableBinderContexts())));
Optional<TableSegment> boundTableSegment = sqlStatement.getFrom().map(optional -> TableSegmentBinder.bind(optional, binderContext, tableBinderContexts, outerTableBinderContexts));
boundTableSegment.ifPresent(result::setFrom);
result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), binderContext, boundTableSegment.orElse(null), tableBinderContexts, outerTableBinderContexts));
Expand All @@ -71,7 +73,6 @@ public SelectStatement bind(final SelectStatement sqlStatement, final SQLStateme
sqlStatement.getOrderBy().ifPresent(optional -> result.setOrderBy(
OrderBySegmentBinder.bind(optional, binderContext, currentTableBinderContexts, tableBinderContexts, outerTableBinderContexts)));
sqlStatement.getHaving().ifPresent(optional -> result.setHaving(HavingSegmentBinder.bind(optional, binderContext, currentTableBinderContexts, outerTableBinderContexts)));
// TODO support other segment bind in select statement
return result;
}

Expand All @@ -90,7 +91,6 @@ private SelectStatement copy(final SelectStatement sqlStatement) {
sqlStatement.getWindow().ifPresent(result::setWindow);
sqlStatement.getModelSegment().ifPresent(result::setModelSegment);
sqlStatement.getSubqueryType().ifPresent(result::setSubqueryType);
sqlStatement.getWithSegment().ifPresent(result::setWithSegment);
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
result.getVariableNames().addAll(sqlStatement.getVariableNames());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.infra.binder.engine.segment.from.TableSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.from.context.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.engine.segment.predicate.WhereSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.segment.with.WithSegmentBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinder;
import org.apache.shardingsphere.infra.binder.engine.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
Expand All @@ -38,6 +39,7 @@ public final class UpdateStatementBinder implements SQLStatementBinder<UpdateSta
public UpdateStatement bind(final UpdateStatement sqlStatement, final SQLStatementBinderContext binderContext) {
UpdateStatement result = copy(sqlStatement);
Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts = LinkedHashMultimap.create();
sqlStatement.getWithSegment().ifPresent(optional -> result.setWithSegment(WithSegmentBinder.bind(optional, binderContext, binderContext.getExternalTableBinderContexts())));
result.setTable(TableSegmentBinder.bind(sqlStatement.getTable(), binderContext, tableBinderContexts, LinkedHashMultimap.create()));
sqlStatement.getFrom().ifPresent(optional -> result.setFrom(TableSegmentBinder.bind(optional, binderContext, tableBinderContexts, LinkedHashMultimap.create())));
sqlStatement.getAssignmentSegment().ifPresent(optional -> result.setSetAssignment(AssignmentSegmentBinder.bind(optional, binderContext, tableBinderContexts, LinkedHashMultimap.create())));
Expand All @@ -50,7 +52,6 @@ private UpdateStatement copy(final UpdateStatement sqlStatement) {
UpdateStatement result = sqlStatement.getClass().getDeclaredConstructor().newInstance();
sqlStatement.getOrderBy().ifPresent(result::setOrderBy);
sqlStatement.getLimit().ifPresent(result::setLimit);
sqlStatement.getWithSegment().ifPresent(result::setWithSegment);
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
Expand Down
Loading

0 comments on commit ff550f8

Please sign in to comment.