Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal refactor of oracle dml statement parameter parse #28462

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedLiteralExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedParameterMarkerExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.expression.DerivedSimpleExpressionSegment;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
Expand All @@ -41,9 +40,11 @@
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.InsertValuesToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.UseDefaultInsertColumnsToken;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.Iterator;
Expand Down Expand Up @@ -143,14 +144,15 @@ private void encryptToken(final InsertValue insertValueToken, final String schem
int columnIndex = useDefaultInsertColumnsToken
.map(optional -> ((UseDefaultInsertColumnsToken) optional).getColumns().indexOf(columnName)).orElseGet(() -> insertStatementContext.getColumnNames().indexOf(columnName));
Object originalValue = insertValueContext.getLiteralValue(columnIndex).orElse(null);
setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, insertValueContext.getValueExpressions().get(columnIndex), columnIndex, originalValue);
ExpressionSegment valueExpression = insertValueContext.getValueExpressions().get(columnIndex);
setCipherColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, originalValue);
int indexDelta = 1;
if (encryptColumn.getAssistedQuery().isPresent()) {
addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueContext, insertValueToken, columnIndex, indexDelta, originalValue);
addAssistedQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
indexDelta++;
}
if (encryptColumn.getLikeQuery().isPresent()) {
addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueContext, insertValueToken, columnIndex, indexDelta, originalValue);
addLikeQueryColumn(schemaName, tableName, encryptColumn, insertValueToken, valueExpression, columnIndex, indexDelta, originalValue);
}
}
}
Expand All @@ -163,31 +165,45 @@ private void setCipherColumn(final String schemaName, final String tableName, fi
}
}

private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object originalValue) {
private void addAssistedQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
Optional<AssistedQueryColumnItem> assistedQueryColumnItem = encryptColumn.getAssistedQuery();
Preconditions.checkState(assistedQueryColumnItem.isPresent());
Object derivedValue = assistedQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
addDerivedColumn(insertValueContext, insertValueToken, columnIndex, indexDelta, derivedValue);
addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, assistedQueryColumnItem.get().getName());
}

private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn,
final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object originalValue) {
private void addLikeQueryColumn(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final InsertValue insertValueToken,
final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object originalValue) {
Optional<LikeQueryColumnItem> likeQueryColumnItem = encryptColumn.getLikeQuery();
Preconditions.checkState(likeQueryColumnItem.isPresent());
Object derivedValue = likeQueryColumnItem.get().encrypt(databaseName, schemaName, tableName, encryptColumn.getName(), originalValue);
addDerivedColumn(insertValueContext, insertValueToken, columnIndex, indexDelta, derivedValue);
addDerivedColumn(insertValueToken, valueExpression, columnIndex, indexDelta, derivedValue, likeQueryColumnItem.get().getName());
}

private void addDerivedColumn(final InsertValueContext insertValueContext, final InsertValue insertValueToken, final int columnIndex, final int indexDelta, final Object derivedValue) {
DerivedSimpleExpressionSegment derivedExpressionSegment = isAddLiteralExpressionSegment(insertValueContext, columnIndex)
? new DerivedLiteralExpressionSegment(derivedValue)
: new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpressionSegment);
private void addDerivedColumn(final InsertValue insertValueToken, final ExpressionSegment valueExpression, final int columnIndex, final int indexDelta, final Object derivedValue,
final String derivedColumnName) {
ExpressionSegment derivedExpression;
if (valueExpression instanceof LiteralExpressionSegment) {
derivedExpression = new DerivedLiteralExpressionSegment(derivedValue);
} else if (valueExpression instanceof ParameterMarkerExpressionSegment) {
derivedExpression = new DerivedParameterMarkerExpressionSegment(getParameterIndexCount(insertValueToken));
} else if (valueExpression instanceof ColumnSegment) {
derivedExpression = createColumnSegment((ColumnSegment) valueExpression, derivedColumnName);
} else {
derivedExpression = valueExpression;
}
insertValueToken.getValues().add(columnIndex + indexDelta, derivedExpression);
}

private boolean isAddLiteralExpressionSegment(final InsertValueContext insertValueContext, final int columnIndex) {
return insertValueContext.getParameters().isEmpty() || insertValueContext.getValueExpressions().get(columnIndex) instanceof LiteralExpressionSegment;
private ColumnSegment createColumnSegment(final ColumnSegment originalColumn, final String columnName) {
ColumnSegment result = new ColumnSegment(originalColumn.getStartIndex(), originalColumn.getStopIndex(), new IdentifierValue(columnName, originalColumn.getIdentifier().getQuoteCharacter()));
result.setNestedObjectAttributes(originalColumn.getNestedObjectAttributes());
originalColumn.getOwner().ifPresent(result::setOwner);
result.setColumnBoundedInfo(originalColumn.getColumnBoundedInfo());
result.setOtherUsingColumnBoundedInfo(originalColumn.getOtherUsingColumnBoundedInfo());
result.setVariable(originalColumn.isVariable());
return result;
}

private int getParameterIndexCount(final InsertValue insertValueToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BitValueLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BooleanLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.BooleanPrimaryContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CastFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CaseExpressionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CaseWhenContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CastFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.CharFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ColumnNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ColumnNamesContext;
Expand All @@ -45,8 +45,8 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.DateTimeLiteralsContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.DatetimeExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExprContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExtractFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExprListContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExtractFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FeatureFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FirstOrLastValueFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.FormatFunctionContext;
Expand Down Expand Up @@ -79,8 +79,8 @@
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TableNamesContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ToDateFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TrimFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TranslateFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TrimFunctionContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.TypeNameContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.UnreservedWordContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ViewNameContext;
Expand Down Expand Up @@ -180,11 +180,13 @@
@Getter
public abstract class OracleStatementVisitor extends OracleStatementBaseVisitor<ASTNode> {

private final Collection<ParameterMarkerSegment> parameterMarkerSegments = new LinkedList<>();
private final Collection<ParameterMarkerSegment> globalParameterMarkerSegments = new LinkedList<>();

private final Collection<ParameterMarkerSegment> statementParameterMarkerSegments = new LinkedList<>();

@Override
public final ASTNode visitParameterMarker(final ParameterMarkerContext ctx) {
return new ParameterMarkerValue(parameterMarkerSegments.size(), ParameterMarkerType.QUESTION);
return new ParameterMarkerValue(globalParameterMarkerSegments.size(), ParameterMarkerType.QUESTION);
}

@Override
Expand Down Expand Up @@ -537,7 +539,8 @@ private ASTNode createExpressionSegment(final ASTNode astNode, final ParserRuleC
ParameterMarkerValue parameterMarker = (ParameterMarkerValue) astNode;
ParameterMarkerExpressionSegment segment = new ParameterMarkerExpressionSegment(context.start.getStartIndex(), context.stop.getStopIndex(),
parameterMarker.getValue(), parameterMarker.getType());
parameterMarkerSegments.add(segment);
globalParameterMarkerSegments.add(segment);
statementParameterMarkerSegments.add(segment);
return segment;
}
if (astNode instanceof SubquerySegment) {
Expand All @@ -559,7 +562,8 @@ public final ASTNode visitSimpleExpr(final SimpleExprContext ctx) {
if (null != ctx.parameterMarker()) {
ParameterMarkerValue parameterMarker = (ParameterMarkerValue) visit(ctx.parameterMarker());
ParameterMarkerExpressionSegment segment = new ParameterMarkerExpressionSegment(startIndex, stopIndex, parameterMarker.getValue(), parameterMarker.getType());
parameterMarkerSegments.add(segment);
globalParameterMarkerSegments.add(segment);
statementParameterMarkerSegments.add(segment);
return segment;
}
if (null != ctx.literals()) {
Expand Down Expand Up @@ -1172,4 +1176,15 @@ public final ASTNode visitDataTypeLength(final DataTypeLengthContext ctx) {
protected String getOriginalText(final ParserRuleContext ctx) {
return ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
}

/**
* Pop all statement parameter marker segments.
*
* @return all statement parameter marker segments
*/
protected Collection<ParameterMarkerSegment> popAllStatementParameterMarkerSegments() {
Collection<ParameterMarkerSegment> result = new LinkedList<>(statementParameterMarkerSegments);
statementParameterMarkerSegments.clear();
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.sql.parser.api.ASTNode;
import org.apache.shardingsphere.sql.parser.api.visitor.statement.type.DALStatementVisitor;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.AlterResourceCostContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExecuteContext;
import org.apache.shardingsphere.sql.parser.autogen.OracleStatementParser.ExplainContext;
import org.apache.shardingsphere.sql.parser.oracle.visitor.statement.OracleStatementVisitor;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
Expand All @@ -40,7 +41,8 @@ public ASTNode visitAlterResourceCost(final AlterResourceCostContext ctx) {
public ASTNode visitExplain(final ExplainContext ctx) {
OracleExplainStatement result = new OracleExplainStatement();
OracleDMLStatementVisitor visitor = new OracleDMLStatementVisitor();
visitor.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
getGlobalParameterMarkerSegments().addAll(visitor.getGlobalParameterMarkerSegments());
getStatementParameterMarkerSegments().addAll(visitor.getStatementParameterMarkerSegments());
if (null != ctx.insert()) {
result.setStatement((SQLStatement) visitor.visit(ctx.insert()));
} else if (null != ctx.delete()) {
Expand All @@ -50,7 +52,7 @@ public ASTNode visitExplain(final ExplainContext ctx) {
} else if (null != ctx.select()) {
result.setStatement((SQLStatement) visitor.visit(ctx.select()));
}
result.addParameterMarkerSegments(getParameterMarkerSegments());
result.addParameterMarkerSegments(ctx.getParent() instanceof ExecuteContext ? getGlobalParameterMarkerSegments() : popAllStatementParameterMarkerSegments());
return result;
}
}
Loading