Skip to content

Commit

Permalink
Refactor BaseDMLE2EIT and insert select statement parse logic (#28457)
Browse files Browse the repository at this point in the history
  • Loading branch information
strongduanmu authored Sep 18, 2023
1 parent 9b4939b commit 3cbd841
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext)

@Override
public UseDefaultInsertColumnsToken generateSQLToken(final InsertStatementContext insertStatementContext) {
String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
String tableName = Optional.ofNullable(insertStatementContext.getSqlStatement().getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
Optional<UseDefaultInsertColumnsToken> previousSQLToken = findInsertColumnsToken();
if (previousSQLToken.isPresent()) {
processPreviousSQLToken(previousSQLToken.get(), insertStatementContext, tableName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public final class GeneratedKeyContextEngine {
* @return generate key context
*/
public Optional<GeneratedKeyContext> createGenerateKeyContext(final List<String> insertColumnNames, final List<List<ExpressionSegment>> valueExpressions, final List<Object> params) {
String tableName = insertStatement.getTable().getTableName().getIdentifier().getValue();
String tableName = Optional.ofNullable(insertStatement.getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
return findGenerateKeyColumn(tableName).map(optional -> containsGenerateKey(insertColumnNames, optional)
? findGeneratedKey(insertColumnNames, valueExpressions, params, optional)
: new GeneratedKeyContext(optional, true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List<
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType());
ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
columnNames = containsInsertColumns() ? insertColumnNames : schema.getVisibleColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue().toLowerCase());
columnNames = containsInsertColumns() ? insertColumnNames
: Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, getAllValueExpressions(sqlStatement), params).orElse(null);
}

Expand Down Expand Up @@ -166,7 +167,7 @@ public List<List<Object>> getGroupedParameters() {
for (InsertValueContext each : insertValueContexts) {
result.add(each.getParameters());
}
if (null != insertSelectContext) {
if (null != insertSelectContext && !insertSelectContext.getParameters().isEmpty()) {
result.add(insertSelectContext.getParameters());
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;

/**
* Select statement binder.
Expand All @@ -54,7 +55,7 @@ private InsertStatement bind(final InsertStatement sqlStatement, final ShardingS
SQLStatementBinderContext statementBinderContext = new SQLStatementBinderContext(metaData, defaultDatabaseName, sqlStatement.getDatabaseType(), sqlStatement.getVariableNames());
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
Map<String, TableSegmentBinderContext> tableBinderContexts = new LinkedHashMap<>();
result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(), statementBinderContext, tableBinderContexts));
Optional.ofNullable(sqlStatement.getTable()).ifPresent(optional -> result.setTable(SimpleTableSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts)));
if (sqlStatement.getInsertColumns().isPresent() && !sqlStatement.getInsertColumns().get().getColumns().isEmpty()) {
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, tableBinderContexts));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
Expand Down Expand Up @@ -161,15 +162,16 @@ void assertGetGroupedParametersWithOnDuplicateParameters() {
void assertInsertSelect() {
InsertStatement insertStatement = new MySQLInsertStatement();
SelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.addParameterMarkerSegments(Collections.singleton(new ParameterMarkerExpressionSegment(0, 0, 0, ParameterMarkerType.QUESTION)));
selectStatement.setProjections(new ProjectionsSegment(0, 0));
SubquerySegment insertSelect = new SubquerySegment(0, 0, selectStatement);
insertStatement.setInsertSelect(insertSelect);
insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("tbl"))));
InsertStatementContext actual = createInsertStatementContext(Collections.singletonList("param"), insertStatement);
actual.setUpParameters(Collections.singletonList("param"));
assertThat(actual.getInsertSelectContext().getParameterCount(), is(0));
assertThat(actual.getInsertSelectContext().getParameterCount(), is(1));
assertThat(actual.getGroupedParameters().size(), is(1));
assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.emptyList()));
assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.singletonList("param")));
}

private void setUpInsertValues(final InsertStatement insertStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,18 @@ public SQLRewriteContext(final ShardingSphereDatabase database, final SQLStateme
if (!hintValueContext.isSkipSQLRewrite()) {
addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators());
}
parameterBuilder = sqlStatementContext instanceof InsertStatementContext && null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
? new GroupedParameterBuilder(
((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
parameterBuilder = containsInsertValues(sqlStatementContext)
? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters())
: new StandardParameterBuilder(params);
}

private boolean containsInsertValues(final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof InsertStatementContext)) {
return false;
}
return null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext();
}

/**
* Add SQL token generators.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() {
void assertRewriteWithGroupedParameterBuilderForBroadcast() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand All @@ -107,7 +109,9 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() {
void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand All @@ -127,7 +131,9 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() {
void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() {
InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS);
when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false);
when(statementContext.getInsertSelectContext()).thenReturn(null);
when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1)));
when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList());
SQLRewriteContext sqlRewriteContext =
new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext());
RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MatchAgainstExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.RowExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
Expand Down Expand Up @@ -1361,6 +1361,7 @@ public ASTNode visitInsert(final InsertContext ctx) {
@Override
public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) {
MySQLInsertStatement result = new MySQLInsertStatement();
result.setInsertSelect(createInsertSelectSegment(ctx));
if (null != ctx.LP_()) {
if (null != ctx.fields()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), createInsertColumns(ctx.fields())));
Expand All @@ -1370,12 +1371,12 @@ public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) {
} else {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
}
result.setInsertSelect(createInsertSelectSegment(ctx));
return result;
}

private SubquerySegment createInsertSelectSegment(final InsertSelectClauseContext ctx) {
MySQLSelectStatement selectStatement = (MySQLSelectStatement) visit(ctx.select());
selectStatement.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
return new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) {
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
OpenGaussInsertStatement result = new OpenGaussInsertStatement();
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
} else {
Expand All @@ -759,13 +766,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) {
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ private Collection<InsertValuesSegment> createInsertValuesSegments(final Assignm
@Override
public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
OracleInsertStatement result = new OracleInsertStatement();
result.setInsertSelect(new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), (OracleSelectStatement) visit(ctx.selectSubquery())));
result.setMultiTableInsertType(null != ctx.conditionalInsertClause() && null != ctx.conditionalInsertClause().FIRST() ? MultiTableInsertType.FIRST : MultiTableInsertType.ALL);
List<MultiTableElementContext> multiTableElementContexts = ctx.multiTableElement();
if (null != multiTableElementContexts && !multiTableElementContexts.isEmpty()) {
Expand All @@ -336,9 +337,6 @@ public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) {
} else {
result.setMultiTableConditionalIntoSegment((MultiTableConditionalIntoSegment) visit(ctx.conditionalInsertClause()));
}
OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.selectSubquery());
SubquerySegment subquerySegment = new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), subquery);
result.setInsertSelect(subquerySegment);
result.addParameterMarkerSegments(getParameterMarkerSegments());
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) {
@Override
public ASTNode visitInsertRest(final InsertRestContext ctx) {
PostgreSQLInsertStatement result = new PostgreSQLInsertStatement();
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
if (null == ctx.insertColumnList()) {
result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList()));
} else {
Expand All @@ -764,13 +771,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) {
InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue());
result.setInsertColumns(insertColumnsSegment);
}
ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause();
if (null == valuesClause) {
PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select());
result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement));
} else {
result.getValues().addAll(createInsertValuesSegments(valuesClause));
}
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ void tearDown() {
}

protected final void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final int actualUpdateCount) throws SQLException {
assertThat("Only support single table for DML.", containerComposer.getDataSet().getMetaDataList().size(), is(1));
assertThat(actualUpdateCount, is(containerComposer.getDataSet().getUpdateCount()));
DataSetMetaData expectedDataSetMetaData = containerComposer.getDataSet().getMetaDataList().get(0);
for (DataSetMetaData each : containerComposer.getDataSet().getMetaDataList()) {
assertDataSet(testParam, containerComposer, each);
}
}

private void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final DataSetMetaData expectedDataSetMetaData) throws SQLException {
for (String each : InlineExpressionParserFactory.newInstance().splitAndEvaluate(expectedDataSetMetaData.getDataNodes())) {
DataNode dataNode = new DataNode(each);
DataSource dataSource = containerComposer.getActualDataSourceMap().get(dataNode.getDataSourceName());
Expand Down
Loading

0 comments on commit 3cbd841

Please sign in to comment.