Skip to content

Commit

Permalink
Support of customSqlUpdate to support FROM
Browse files Browse the repository at this point in the history
  • Loading branch information
kanha-gupta committed Oct 25, 2023
1 parent 6591722 commit ab34884
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.update;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.SqlSpecialOperator;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.SqlWriter.Frame;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.ImmutableNullableList;
import org.apache.calcite.util.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.dataflow.qual.Pure;

import java.util.Iterator;
import java.util.List;

public class CustomSqlUpdate extends SqlCall {

public static final SqlSpecialOperator OPERATOR;

private SqlNode targetTable;

private SqlNodeList targetColumnList;

private SqlNodeList sourceExpressionList;

@Nullable
private SqlNode condition;

@Nullable
private SqlSelect sourceSelect;

@Nullable
private SqlIdentifier alias;

@Nullable
private SqlNode from;

public CustomSqlUpdate(final SqlParserPos pos, final SqlNode targetTable, final SqlNodeList targetColumnList, final SqlNodeList sourceExpressionList, final @Nullable SqlNode condition,
final @Nullable SqlSelect sourceSelect, final @Nullable SqlIdentifier alias, final @Nullable SqlNode from) {
super(pos);
this.targetTable = targetTable;
this.targetColumnList = targetColumnList;
this.sourceExpressionList = sourceExpressionList;
this.condition = condition;
this.sourceSelect = sourceSelect;
this.from = from;
assert sourceExpressionList.size() == targetColumnList.size();

this.alias = alias;
}

public SqlKind getKind() {
return SqlKind.UPDATE;
}

public SqlOperator getOperator() {
return OPERATOR;
}

public List<@Nullable SqlNode> getOperandList() {
return ImmutableNullableList.of(this.targetTable, this.targetColumnList, this.sourceExpressionList, this.condition, this.alias, this.from);
}

@Override
public void setOperand(final int i, final @Nullable SqlNode operand) {
switch (i) {
case 0:
assert operand instanceof SqlIdentifier;

this.targetTable = operand;
break;
case 1:
this.targetColumnList = (SqlNodeList) operand;
break;
case 2:
this.sourceExpressionList = (SqlNodeList) operand;
break;
case 3:
this.condition = operand;
break;
case 4:
this.sourceExpressionList = (SqlNodeList) operand;
break;
case 5:
this.alias = (SqlIdentifier) operand;
break;
case 6:
this.from = operand;
break;
default:
throw new AssertionError(i);
}

}

public SqlNode getTargetTable() {
return this.targetTable;
}

@Pure
public final @Nullable SqlNode getFrom() {
return this.from;
}

public void setFrom(final @Nullable SqlNode from) {
this.from = from;
}

@Pure
public @Nullable SqlIdentifier getAlias() {
return this.alias;
}

public void setAlias(final SqlIdentifier alias) {
this.alias = alias;
}

public SqlNodeList getTargetColumnList() {
return this.targetColumnList;
}

public SqlNodeList getSourceExpressionList() {
return this.sourceExpressionList;
}

public @Nullable SqlNode getCondition() {
return this.condition;
}

public @Nullable SqlSelect getSourceSelect() {
return this.sourceSelect;
}

public void setSourceSelect(final SqlSelect sourceSelect) {
this.sourceSelect = sourceSelect;
}

@Override
public void unparse(final SqlWriter writer, final int leftPrec, final int rightPrec) {
final Frame frame = writer.startList(SqlWriter.FrameTypeEnum.SELECT, "UPDATE", "");
int opLeft = this.getOperator().getLeftPrec();
int opRight = this.getOperator().getRightPrec();
this.targetTable.unparse(writer, opLeft, opRight);
SqlIdentifier alias = this.alias;
if (alias != null) {
writer.keyword("AS");
alias.unparse(writer, opLeft, opRight);
}
SqlWriter.Frame setFrame = writer.startList(SqlWriter.FrameTypeEnum.UPDATE_SET_LIST, "SET", "");
Iterator var9 = Pair.zip(this.getTargetColumnList(), this.getSourceExpressionList()).iterator();
while (var9.hasNext()) {
Pair<SqlNode, SqlNode> pair = (Pair) var9.next();
writer.sep(",");
SqlIdentifier id = (SqlIdentifier) pair.left;
id.unparse(writer, opLeft, opRight);
writer.keyword("=");
SqlNode sourceExp = (SqlNode) pair.right;
sourceExp.unparse(writer, opLeft, opRight);
}
writer.endList(setFrame);
SqlNode from = this.from;
if (from != null) {
writer.sep("FROM");
from.unparse(writer, opLeft, opRight);
}
SqlNode condition = this.condition;
if (condition != null) {
writer.sep("WHERE");
condition.unparse(writer, opLeft, opRight);
}
writer.endList(frame);
}

static {
OPERATOR = new SqlSpecialOperator("UPDATE", SqlKind.UPDATE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOrderBy;
import org.apache.calcite.sql.SqlUpdate;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit.LimitSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.UpdateStatement;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.UpdateStatementHandler;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.expression.ExpressionConverter;
Expand All @@ -34,6 +34,7 @@
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.limit.PaginationValueSQLConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.orderby.OrderByConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.where.WhereConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.segment.with.WithConverter;
import org.apache.shardingsphere.sqlfederation.optimizer.converter.statement.SQLStatementConverter;

import java.util.List;
Expand All @@ -47,7 +48,7 @@ public final class UpdateStatementConverter implements SQLStatementConverter<Upd

@Override
public SqlNode convert(final UpdateStatement updateStatement) {
SqlUpdate sqlUpdate = convertUpdate(updateStatement);
SqlNode sqlUpdate = convertUpdate(updateStatement);
SqlNodeList orderBy = UpdateStatementHandler.getOrderBySegment(updateStatement).flatMap(OrderByConverter::convert).orElse(SqlNodeList.EMPTY);
Optional<LimitSegment> limit = UpdateStatementHandler.getLimitSegment(updateStatement);
if (limit.isPresent()) {
Expand All @@ -58,16 +59,18 @@ public SqlNode convert(final UpdateStatement updateStatement) {
return orderBy.isEmpty() ? sqlUpdate : new SqlOrderBy(SqlParserPos.ZERO, sqlUpdate, orderBy, null, null);
}

private SqlUpdate convertUpdate(final UpdateStatement updateStatement) {
private SqlNode convertUpdate(final UpdateStatement updateStatement) {
SqlNode table = TableConverter.convert(updateStatement.getTable()).orElseThrow(IllegalStateException::new);
SqlNode from = convertTable(updateStatement.getAssignmentSegment().orElse(null).getFrom());
SqlNode condition = updateStatement.getWhere().flatMap(WhereConverter::convert).orElse(null);
SqlNodeList columns = new SqlNodeList(SqlParserPos.ZERO);
SqlNodeList expressions = new SqlNodeList(SqlParserPos.ZERO);
for (AssignmentSegment each : updateStatement.getAssignmentSegment().orElseThrow(IllegalStateException::new).getAssignments()) {
columns.addAll(convertColumn(each.getColumns()));
expressions.add(convertExpression(each.getValue()));
}
return new SqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null);
CustomSqlUpdate sqlUpdate = new CustomSqlUpdate(SqlParserPos.ZERO, table, columns, expressions, condition, null, null, from);
return UpdateStatementHandler.getWithSegment(updateStatement).flatMap(optional -> WithConverter.convert(optional, sqlUpdate)).orElse(sqlUpdate);
}

private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
Expand All @@ -77,4 +80,8 @@ private List<SqlNode> convertColumn(final List<ColumnSegment> columnSegments) {
private SqlNode convertExpression(final ExpressionSegment expressionSegment) {
return ExpressionConverter.convert(expressionSegment).orElseThrow(IllegalStateException::new);
}

private SqlNode convertTable(final TableSegment tableSegment) {
return TableConverter.convert(tableSegment).orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -995,9 +995,14 @@ public ASTNode visitUpdate(final UpdateContext ctx) {
@Override
public ASTNode visitSetAssignmentsClause(final SetAssignmentsClauseContext ctx) {
Collection<AssignmentSegment> assignments = new LinkedList<>();
TableSegment from;
for (AssignmentContext each : ctx.assignment()) {
assignments.add((AssignmentSegment) visit(each));
}
if (null != ctx.fromClause()) {
from = (TableSegment) visit(ctx.fromClause().tableReferences());
return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments, from);
}
return new SetAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), assignments);
}

Expand All @@ -1017,7 +1022,6 @@ public ASTNode visitAssignment(final AssignmentContext ctx) {
columnSegments.add(column);
ExpressionSegment value = (ExpressionSegment) visit(ctx.assignmentValue());
AssignmentSegment result = new ColumnAssignmentSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), columnSegments, value);
result.getColumns().add(column);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;

import java.util.Collection;

Expand All @@ -35,4 +36,17 @@ public final class SetAssignmentSegment implements SQLSegment {
private final int stopIndex;

private final Collection<AssignmentSegment> assignments;

private TableSegment from;

public SetAssignmentSegment(final int startIndex, final int stopIndex, final Collection<AssignmentSegment> assignments, final TableSegment from) {
this.startIndex = startIndex;
this.stopIndex = stopIndex;
this.assignments = assignments;
this.from = from;
}

public TableSegment getFrom() {
return from;
}
}
2 changes: 2 additions & 0 deletions test/it/optimizer/src/test/resources/converter/update.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,6 @@
<!--<test-cases sql-case-id="update_with_translate_function" expected-sql="UPDATE &quot;translate_tab&quot; SET &quot;char_col&quot; = TRANSLATE(&quot;nchar_col&quot; USING 'CHAR_CS')" db-types="Oracle" />-->
<test-cases sql-case-id="update_with_dot_column_name" expected-sql="UPDATE &quot;employees&quot; SET &quot;salary&quot; = &quot;salary&quot; + 10 WHERE &quot;employee_id&quot; BETWEEN ASYMMETRIC 1 AND 10" db-types="Oracle" sql-case-types="LITERAL" />
<test-cases sql-case-id="update_with_dot_column_name" expected-sql="UPDATE &quot;employees&quot; SET &quot;salary&quot; = &quot;salary&quot; + ? WHERE &quot;employee_id&quot; BETWEEN ASYMMETRIC ? AND ?" db-types="Oracle" sql-case-types="PLACEHOLDER" />
<test-cases sql-case-id="update_with_with_clause" expected-sql="(WITH [cte] ([order_id], [user_id], [status]) AS (SELECT [order_id], [user_id], [status] FROM [t_order]) UPDATE [t_order] SET [status] = 1 FROM [t_order] AS [t] INNER JOIN [cte] AS [c] ON [t].[order_id] = [c].[order_id] WHERE [c].[order_id] = 1)" db-types="SQLServer" sql-case-types="LITERAL" />
<test-cases sql-case-id="update_with_with_clause" expected-sql="(WITH [cte] ([order_id], [user_id], [status]) AS (SELECT [order_id], [user_id], [status] FROM [t_order]) UPDATE [t_order] SET [status] = ? FROM [t_order] AS [t] INNER JOIN [cte] AS [c] ON [t].[order_id] = [c].[order_id] WHERE [c].[order_id] = ?)" db-types="SQLServer" sql-case-types="PLACEHOLDER" />
</sql-node-converter-test-cases>

0 comments on commit ab34884

Please sign in to comment.