Skip to content

Commit

Permalink
Minimal refactor of oracle dml statement parameter parse (#28462)
Browse files Browse the repository at this point in the history
* Minimal refactor of oracle dml statement parameter parse

* fix build error

* fix sql parser test case
  • Loading branch information
strongduanmu authored Sep 19, 2023
1 parent ea0e3cf commit bbc2b19
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedSimpleExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -41,9 +40,11 @@
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValuesToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.Iterator;
Expand Down Expand Up @@ -143,14 +144,15 @@ private void encryptToken(final InsertValue insertValueToken, final String schem
int columnIndex = useDefaultInsertColumnsToken
.map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, insertValueContext.getValueExpressions().get(columnIndex), columnIndex, originalValue);
ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
int indexDelta = 1;
if (encryptColumn.getAssistedQuery().isPresent()) {
addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueContext, insertValueToken, columnIndex, indexDelta, originalValue);
addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
indexDelta++;
}
if (encryptColumn.getLikeQuery().isPresent()) {
addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueContext, insertValueToken, columnIndex, indexDelta, originalValue);
addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
}
}
}
Expand All @@ -163,31 +165,45 @@ private void setCipherColumn(final String schemaName, final String tableName, fi
}
}

private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object originalValue) {
private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
Preconditions.checkState(assistedQueryColumnItem.isPresent());
Object derivedValue = assistedQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
addDerivedColumn(insertValueContext, insertValueToken, columnIndex, indexDelta, derivedValue);
addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
}

private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object originalValue) {
private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
Preconditions.checkState(likeQueryColumnItem.isPresent());
Object derivedValue = likeQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
addDerivedColumn(insertValueContext, insertValueToken, columnIndex, indexDelta, derivedValue);
addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
}

private void addDerivedColumn(final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object derivedValue) {
DerivedSimpleExpressionSegment derivedExpressionSegment = isAddLiteralExpressionSegment(insertValueContext, columnIndex)
? new DerivedLiteralExpressionSegment(derivedValue)
: new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpressionSegment);
private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
final String derivedColumnName) {
ExpressionSegment derivedExpression;
if (valueExpression instanceof LiteralExpressionSegment) {
derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
} else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
} else if (valueExpression instanceof ColumnSegment) {
derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
} else {
derivedExpression = valueExpression;
}
insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
}

private boolean isAddLiteralExpressionSegment(final InsertValueContext insertValueContext, final int columnIndex) {
return insertValueContext.getParameters().isEmpty() || insertValueContext.getValueExpressions().get(columnIndex) instanceof LiteralExpressionSegment;
private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
originalColumn.getOwner().ifPresent(result::setOwner);
result.setColumnBoundedInfo(originalColumn.getColumnBoundedInfo());
result.setOtherUsingColumnBoundedInfo(originalColumn.getOtherUsingColumnBoundedInfo());
result.setVariable(originalColumn.isVariable());
return result;
}

private int getParameterIndexCount(final InsertValue insertValueToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BitValueLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BooleanLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BooleanPrimaryContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CastFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CaseExpressionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CaseWhenContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CastFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CharFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ColumnNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ColumnNamesContext;
Expand All @@ -45,8 +45,8 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.DateTimeLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.DatetimeExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExtractFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExprListContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExtractFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FeatureFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FirstOrLastValueFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FormatFunctionContext;
Expand Down Expand Up @@ -79,8 +79,8 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableNamesContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ToDateFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TrimFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TranslateFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TrimFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TypeNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UnreservedWordContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ViewNameContext;
Expand Down Expand Up @@ -180,11 +180,13 @@
@Getter
public abstract class OracleStatementVisitor extends OracleStatementBaseVisitor<ASTNode> {

private final Collection<ParameterMarkerSegment> parameterMarkerSegments = new LinkedList<>();
private final Collection<ParameterMarkerSegment> globalParameterMarkerSegments = new LinkedList<>();

private final Collection<ParameterMarkerSegment> statementParameterMarkerSegments = new LinkedList<>();

@Override
public final ASTNode visitParameterMarker(final ParameterMarkerContext ctx) {
return new ParameterMarkerValue(parameterMarkerSegments.size(), ParameterMarkerType.QUESTION);
return new ParameterMarkerValue(globalParameterMarkerSegments.size(), ParameterMarkerType.QUESTION);
}

@Override
Expand Down Expand Up @@ -537,7 +539,8 @@ private ASTNode createExpressionSegment(final ASTNode astNode, final ParserRuleC
ParameterMarkerValue parameterMarker = (ParameterMarkerValue) astNode;
ParameterMarkerExpressionSegment segment = new ParameterMarkerExpressionSegment(context.start.getStartIndex(), context.stop.getStopIndex(),
parameterMarker.getValue(), parameterMarker.getType());
parameterMarkerSegments.add(segment);
globalParameterMarkerSegments.add(segment);
statementParameterMarkerSegments.add(segment);
return segment;
}
if (astNode instanceof SubquerySegment) {
Expand All @@ -559,7 +562,8 @@ public final ASTNode visitSimpleExpr(final SimpleExprContext ctx) {
if (null != ctx.parameterMarker()) {
ParameterMarkerValue parameterMarker = (ParameterMarkerValue) visit(ctx.parameterMarker());
ParameterMarkerExpressionSegment segment = new ParameterMarkerExpressionSegment(startIndex, stopIndex, parameterMarker.getValue(), parameterMarker.getType());
parameterMarkerSegments.add(segment);
globalParameterMarkerSegments.add(segment);
statementParameterMarkerSegments.add(segment);
return segment;
}
if (null != ctx.literals()) {
Expand Down Expand Up @@ -1172,4 +1176,15 @@ public final ASTNode visitDataTypeLength(final DataTypeLengthContext ctx) {
protected String getOriginalText(final ParserRuleContext ctx) {
return ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
}

/**
* Pop all statement parameter marker segments.
*
* @return all statement parameter marker segments
*/
protected Collection<ParameterMarkerSegment> popAllStatementParameterMarkerSegments() {
Collection<ParameterMarkerSegment> result = new LinkedList<>(statementParameterMarkerSegments);
statementParameterMarkerSegments.clear();
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.sql.parser.api.ASTNode;
import org.apache.shardingsphere.sql.parser.api.visitor.statement.type.DALStatementVisitor;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.AlterResourceCostContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExecuteContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExplainContext;
import org.apache.shardingsphere.sql.parser.oracle.visitor.statement.OracleStatementVisitor;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
Expand All @@ -40,7 +41,8 @@ public ASTNode visitAlterResourceCost(final AlterResourceCostContext ctx) {
public ASTNode visitExplain(final ExplainContext ctx) {
OracleExplainStatement result = new OracleExplainStatement();
OracleDMLStatementVisitor visitor = new OracleDMLStatementVisitor();
visitor.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
getGlobalParameterMarkerSegments().addAll(visitor.getGlobalParameterMarkerSegments());
getStatementParameterMarkerSegments().addAll(visitor.getStatementParameterMarkerSegments());
if (null != ctx.insert()) {
result.setStatement((SQLStatement) visitor.visit(ctx.insert()));
} else if (null != ctx.delete()) {
Expand All @@ -50,7 +52,7 @@ public ASTNode visitExplain(final ExplainContext ctx) {
} else if (null != ctx.select()) {
result.setStatement((SQLStatement) visitor.visit(ctx.select()));
}
result.addParameterMarkerSegments(getParameterMarkerSegments());
result.addParameterMarkerSegments(ctx.getParent() instanceof ExecuteContext ? getGlobalParameterMarkerSegments() : popAllStatementParameterMarkerSegments());
return result;
}
}
Loading

0 comments on commit bbc2b19

Please sign in to comment.