From 3cbd8410aaa362a790a7dfb1bb479dd02e66ed4b Mon Sep 17 00:00:00 2001 From: Zhengqiang Duan Date: Mon, 18 Sep 2023 18:07:38 +0800 Subject: [PATCH] Refactor BaseDMLE2EIT and insert select statement parse logic (#28457) --- .../EncryptInsertDefaultColumnsTokenGenerator.java | 2 +- .../keygen/engine/GeneratedKeyContextEngine.java | 2 +- .../statement/dml/InsertStatementContext.java | 5 +++-- .../statement/dml/InsertStatementBinder.java | 3 ++- .../statement/dml/InsertStatementContextTest.java | 6 ++++-- .../infra/rewrite/context/SQLRewriteContext.java | 12 +++++++++--- .../rewrite/engine/RouteSQLRewriteEngineTest.java | 6 ++++++ .../visitor/statement/MySQLStatementVisitor.java | 5 +++-- .../statement/OpenGaussStatementVisitor.java | 14 +++++++------- .../statement/type/OracleDMLStatementVisitor.java | 4 +--- .../statement/PostgreSQLStatementVisitor.java | 14 +++++++------- .../test/e2e/engine/type/dml/BaseDMLE2EIT.java | 8 ++++++-- .../statement/dml/impl/InsertStatementAssert.java | 2 ++ .../parser/src/main/resources/case/dml/insert.xml | 10 +++++----- .../parser/src/main/resources/case/dml/replace.xml | 8 ++++---- 15 files changed, 61 insertions(+), 40 deletions(-) diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java index eccaed000070a..c8e8e8e4ea447 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/insert/EncryptInsertDefaultColumnsTokenGenerator.java @@ -59,7 +59,7 @@ public boolean isGenerateSQLToken(final SQLStatementContext sqlStatementContext) @Override public UseDefaultInsertColumnsToken generateSQLToken(final InsertStatementContext insertStatementContext) { - String tableName = insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue(); + String tableName = Optional.ofNullable(insertStatementContext.getSqlStatement().getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""); Optional previousSQLToken = findInsertColumnsToken(); if (previousSQLToken.isPresent()) { processPreviousSQLToken(previousSQLToken.get(), insertStatementContext, tableName); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java index edcb1764881cf..fa7921f4cdc29 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java @@ -52,7 +52,7 @@ public final class GeneratedKeyContextEngine { * @return generate key context */ public Optional createGenerateKeyContext(final List insertColumnNames, final List> valueExpressions, final List params) { - String tableName = insertStatement.getTable().getTableName().getIdentifier().getValue(); + String tableName = Optional.ofNullable(insertStatement.getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""); return findGenerateKeyColumn(tableName).map(optional -> containsGenerateKey(insertColumnNames, optional) ? findGeneratedKey(insertColumnNames, valueExpressions, params, optional) : new GeneratedKeyContext(optional, true)); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java index fb5a29bb906ae..2099a80a8f22d 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java @@ -94,7 +94,8 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List< onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null); tablesContext = new TablesContext(getAllSimpleTableSegments(), getDatabaseType()); ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName); - columnNames = containsInsertColumns() ? insertColumnNames : schema.getVisibleColumnNames(sqlStatement.getTable().getTableName().getIdentifier().getValue().toLowerCase()); + columnNames = containsInsertColumns() ? insertColumnNames + : Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList); generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, getAllValueExpressions(sqlStatement), params).orElse(null); } @@ -166,7 +167,7 @@ public List> getGroupedParameters() { for (InsertValueContext each : insertValueContexts) { result.add(each.getParameters()); } - if (null != insertSelectContext) { + if (null != insertSelectContext && !insertSelectContext.getParameters().isEmpty()) { result.add(insertSelectContext.getParameters()); } return result; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java index 8d7f8eedf0a3c..efdce0773da4b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/InsertStatementBinder.java @@ -36,6 +36,7 @@ import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.Map; +import java.util.Optional; /** * Select statement binder. @@ -54,7 +55,7 @@ private InsertStatement bind(final InsertStatement sqlStatement, final ShardingS SQLStatementBinderContext statementBinderContext = new SQLStatementBinderContext(metaData, defaultDatabaseName, sqlStatement.getDatabaseType(), sqlStatement.getVariableNames()); statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts); Map tableBinderContexts = new LinkedHashMap<>(); - result.setTable(SimpleTableSegmentBinder.bind(sqlStatement.getTable(), statementBinderContext, tableBinderContexts)); + Optional.ofNullable(sqlStatement.getTable()).ifPresent(optional -> result.setTable(SimpleTableSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts))); if (sqlStatement.getInsertColumns().isPresent() && !sqlStatement.getInsertColumns().get().getColumns().isEmpty()) { result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(), statementBinderContext, tableBinderContexts)); } else { diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java index 71787a68a5b63..0d34002d2ec96 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContextTest.java @@ -24,6 +24,7 @@ import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; +import org.apache.shardingsphere.sql.parser.sql.common.enums.ParameterMarkerType; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment; @@ -161,15 +162,16 @@ void assertGetGroupedParametersWithOnDuplicateParameters() { void assertInsertSelect() { InsertStatement insertStatement = new MySQLInsertStatement(); SelectStatement selectStatement = new MySQLSelectStatement(); + selectStatement.addParameterMarkerSegments(Collections.singleton(new ParameterMarkerExpressionSegment(0, 0, 0, ParameterMarkerType.QUESTION))); selectStatement.setProjections(new ProjectionsSegment(0, 0)); SubquerySegment insertSelect = new SubquerySegment(0, 0, selectStatement); insertStatement.setInsertSelect(insertSelect); insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("tbl")))); InsertStatementContext actual = createInsertStatementContext(Collections.singletonList("param"), insertStatement); actual.setUpParameters(Collections.singletonList("param")); - assertThat(actual.getInsertSelectContext().getParameterCount(), is(0)); + assertThat(actual.getInsertSelectContext().getParameterCount(), is(1)); assertThat(actual.getGroupedParameters().size(), is(1)); - assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.emptyList())); + assertThat(actual.getGroupedParameters().iterator().next(), is(Collections.singletonList("param"))); } private void setUpInsertValues(final InsertStatement insertStatement) { diff --git a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java index 3ea33d8eddeab..7625e227500a1 100644 --- a/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java +++ b/infra/rewrite/src/main/java/org/apache/shardingsphere/infra/rewrite/context/SQLRewriteContext.java @@ -69,12 +69,18 @@ public SQLRewriteContext(final ShardingSphereDatabase database, final SQLStateme if (!hintValueContext.isSkipSQLRewrite()) { addSQLTokenGenerators(new DefaultTokenGeneratorBuilder(sqlStatementContext).getSQLTokenGenerators()); } - parameterBuilder = sqlStatementContext instanceof InsertStatementContext && null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext() - ? new GroupedParameterBuilder( - ((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters()) + parameterBuilder = containsInsertValues(sqlStatementContext) + ? new GroupedParameterBuilder(((InsertStatementContext) sqlStatementContext).getGroupedParameters(), ((InsertStatementContext) sqlStatementContext).getOnDuplicateKeyUpdateParameters()) : new StandardParameterBuilder(params); } + private boolean containsInsertValues(final SQLStatementContext sqlStatementContext) { + if (!(sqlStatementContext instanceof InsertStatementContext)) { + return false; + } + return null == ((InsertStatementContext) sqlStatementContext).getInsertSelectContext(); + } + /** * Add SQL token generators. * diff --git a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java index d1b6b6be82244..9cb4985b2be8d 100644 --- a/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java +++ b/infra/rewrite/src/test/java/org/apache/shardingsphere/infra/rewrite/engine/RouteSQLRewriteEngineTest.java @@ -89,7 +89,9 @@ void assertRewriteWithStandardParameterBuilderWhenNeedAggregateRewrite() { void assertRewriteWithGroupedParameterBuilderForBroadcast() { InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS); when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false); + when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); + when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); @@ -107,7 +109,9 @@ void assertRewriteWithGroupedParameterBuilderForBroadcast() { void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS); when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false); + when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); + when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); @@ -127,7 +131,9 @@ void assertRewriteWithGroupedParameterBuilderForRouteWithSameDataNode() { void assertRewriteWithGroupedParameterBuilderForRouteWithEmptyDataNode() { InsertStatementContext statementContext = mock(InsertStatementContext.class, RETURNS_DEEP_STUBS); when(((TableAvailable) statementContext).getTablesContext().getDatabaseName().isPresent()).thenReturn(false); + when(statementContext.getInsertSelectContext()).thenReturn(null); when(statementContext.getGroupedParameters()).thenReturn(Collections.singletonList(Collections.singletonList(1))); + when(statementContext.getOnDuplicateKeyUpdateParameters()).thenReturn(Collections.emptyList()); SQLRewriteContext sqlRewriteContext = new SQLRewriteContext(mockDatabase(), statementContext, "INSERT INTO tbl VALUES (?)", Collections.singletonList(1), mock(ConnectionContext.class), new HintValueContext()); RouteUnit routeUnit = new RouteUnit(new RouteMapper("ds", "ds_0"), Collections.singletonList(new RouteMapper("tbl", "tbl_0"))); diff --git a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java index fccea89e6fe3f..ae5ce5b7a28d8 100644 --- a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java +++ b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java @@ -177,9 +177,9 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ListExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.MatchAgainstExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.NotExpression; -import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.RowExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.UnaryOperationExpression; +import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ValuesExpression; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.complex.CommonExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment; @@ -1361,6 +1361,7 @@ public ASTNode visitInsert(final InsertContext ctx) { @Override public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) { MySQLInsertStatement result = new MySQLInsertStatement(); + result.setInsertSelect(createInsertSelectSegment(ctx)); if (null != ctx.LP_()) { if (null != ctx.fields()) { result.setInsertColumns(new InsertColumnsSegment(ctx.LP_().getSymbol().getStartIndex(), ctx.RP_().getSymbol().getStopIndex(), createInsertColumns(ctx.fields()))); @@ -1370,12 +1371,12 @@ public ASTNode visitInsertSelectClause(final InsertSelectClauseContext ctx) { } else { result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList())); } - result.setInsertSelect(createInsertSelectSegment(ctx)); return result; } private SubquerySegment createInsertSelectSegment(final InsertSelectClauseContext ctx) { MySQLSelectStatement selectStatement = (MySQLSelectStatement) visit(ctx.select()); + selectStatement.getParameterMarkerSegments().addAll(getParameterMarkerSegments()); return new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement); } diff --git a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java index a06f67b24ec10..b185b4c3a9856 100644 --- a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java +++ b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java @@ -751,6 +751,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) { @Override public ASTNode visitInsertRest(final InsertRestContext ctx) { OpenGaussInsertStatement result = new OpenGaussInsertStatement(); + ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause(); + if (null == valuesClause) { + OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select()); + result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement)); + } else { + result.getValues().addAll(createInsertValuesSegments(valuesClause)); + } if (null == ctx.insertColumnList()) { result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList())); } else { @@ -759,13 +766,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) { InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue()); result.setInsertColumns(insertColumnsSegment); } - ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause(); - if (null == valuesClause) { - OpenGaussSelectStatement selectStatement = (OpenGaussSelectStatement) visit(ctx.select()); - result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement)); - } else { - result.getValues().addAll(createInsertValuesSegments(valuesClause)); - } return result; } diff --git a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java index b35aa4e345c81..39c6223ae87b0 100644 --- a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java +++ b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java @@ -326,6 +326,7 @@ private Collection createInsertValuesSegments(final Assignm @Override public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) { OracleInsertStatement result = new OracleInsertStatement(); + result.setInsertSelect(new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), (OracleSelectStatement) visit(ctx.selectSubquery()))); result.setMultiTableInsertType(null != ctx.conditionalInsertClause() && null != ctx.conditionalInsertClause().FIRST() ? MultiTableInsertType.FIRST : MultiTableInsertType.ALL); List multiTableElementContexts = ctx.multiTableElement(); if (null != multiTableElementContexts && !multiTableElementContexts.isEmpty()) { @@ -336,9 +337,6 @@ public ASTNode visitInsertMultiTable(final InsertMultiTableContext ctx) { } else { result.setMultiTableConditionalIntoSegment((MultiTableConditionalIntoSegment) visit(ctx.conditionalInsertClause())); } - OracleSelectStatement subquery = (OracleSelectStatement) visit(ctx.selectSubquery()); - SubquerySegment subquerySegment = new SubquerySegment(ctx.selectSubquery().start.getStartIndex(), ctx.selectSubquery().stop.getStopIndex(), subquery); - result.setInsertSelect(subquerySegment); result.addParameterMarkerSegments(getParameterMarkerSegments()); return result; } diff --git a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java index 4cdde01be62ff..57e5005b2f56b 100644 --- a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java +++ b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java @@ -756,6 +756,13 @@ public ASTNode visitQualifiedName(final QualifiedNameContext ctx) { @Override public ASTNode visitInsertRest(final InsertRestContext ctx) { PostgreSQLInsertStatement result = new PostgreSQLInsertStatement(); + ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause(); + if (null == valuesClause) { + PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select()); + result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement)); + } else { + result.getValues().addAll(createInsertValuesSegments(valuesClause)); + } if (null == ctx.insertColumnList()) { result.setInsertColumns(new InsertColumnsSegment(ctx.start.getStartIndex() - 1, ctx.start.getStartIndex() - 1, Collections.emptyList())); } else { @@ -764,13 +771,6 @@ public ASTNode visitInsertRest(final InsertRestContext ctx) { InsertColumnsSegment insertColumnsSegment = new InsertColumnsSegment(insertColumns.start.getStartIndex() - 1, insertColumns.stop.getStopIndex() + 1, columns.getValue()); result.setInsertColumns(insertColumnsSegment); } - ValuesClauseContext valuesClause = ctx.select().selectNoParens().selectClauseN().simpleSelect().valuesClause(); - if (null == valuesClause) { - PostgreSQLSelectStatement selectStatement = (PostgreSQLSelectStatement) visit(ctx.select()); - result.setInsertSelect(new SubquerySegment(ctx.select().start.getStartIndex(), ctx.select().stop.getStopIndex(), selectStatement)); - } else { - result.getValues().addAll(createInsertValuesSegments(valuesClause)); - } return result; } diff --git a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java index c1dc45a7633df..e1a71a5cf6db2 100644 --- a/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java +++ b/test/e2e/sql/src/test/java/org/apache/shardingsphere/test/e2e/engine/type/dml/BaseDMLE2EIT.java @@ -86,9 +86,13 @@ void tearDown() { } protected final void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final int actualUpdateCount) throws SQLException { - assertThat("Only support single table for DML.", containerComposer.getDataSet().getMetaDataList().size(), is(1)); assertThat(actualUpdateCount, is(containerComposer.getDataSet().getUpdateCount())); - DataSetMetaData expectedDataSetMetaData = containerComposer.getDataSet().getMetaDataList().get(0); + for (DataSetMetaData each : containerComposer.getDataSet().getMetaDataList()) { + assertDataSet(testParam, containerComposer, each); + } + } + + private void assertDataSet(final AssertionTestParameter testParam, final SingleE2EContainerComposer containerComposer, final DataSetMetaData expectedDataSetMetaData) throws SQLException { for (String each : InlineExpressionParserFactory.newInstance().splitAndEvaluate(expectedDataSetMetaData.getDataNodes())) { DataNode dataNode = new DataNode(each); DataSource dataSource = containerComposer.getActualDataSourceMap().get(dataNode.getDataSourceName()); diff --git a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java index 2cdb42d37a182..1df0c82214a76 100644 --- a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java +++ b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/InsertStatementAssert.java @@ -36,6 +36,7 @@ import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.MultiTableInsertIntoClauseAssert; import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.insert.OnDuplicateKeyColumnsAssert; import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.output.OutputClauseAssert; +import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.parameter.ParameterMarkerAssert; import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.returning.ReturningClauseAssert; import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.set.SetClauseAssert; import org.apache.shardingsphere.test.it.sql.parser.internal.asserts.segment.table.TableAssert; @@ -119,6 +120,7 @@ private static void assertInsertSelectClause(final SQLCaseAssertContext assertCo assertFalse(actual.getInsertSelect().isPresent(), assertContext.getText("Actual insert select segment should not exist.")); } else { assertTrue(actual.getInsertSelect().isPresent(), assertContext.getText("Actual insert select segment should exist.")); + ParameterMarkerAssert.assertCount(assertContext, actual.getInsertSelect().get().getSelect().getParameterCount(), expected.getSelectTestCase().getParameters().size()); SelectStatementAssert.assertIs(assertContext, actual.getInsertSelect().get().getSelect(), expected.getSelectTestCase()); } } diff --git a/test/it/parser/src/main/resources/case/dml/insert.xml b/test/it/parser/src/main/resources/case/dml/insert.xml index a7b8ca3c729b2..34c2b6b610784 100644 --- a/test/it/parser/src/main/resources/case/dml/insert.xml +++ b/test/it/parser/src/main/resources/case/dml/insert.xml @@ -1493,7 +1493,7 @@ - @@ -1522,7 +1522,7 @@ - @@ -1557,7 +1557,7 @@ - @@ -1593,7 +1593,7 @@ - @@ -1627,7 +1627,7 @@ - diff --git a/test/it/parser/src/main/resources/case/dml/replace.xml b/test/it/parser/src/main/resources/case/dml/replace.xml index e3dccb901e72a..415a09c21c3a9 100644 --- a/test/it/parser/src/main/resources/case/dml/replace.xml +++ b/test/it/parser/src/main/resources/case/dml/replace.xml @@ -867,7 +867,7 @@ - @@ -896,7 +896,7 @@
- @@ -931,7 +931,7 @@ - @@ -967,7 +967,7 @@ -