diff --git a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java index c6b2b889b0bdf..d0ae2515d2d6b 100644 --- a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java +++ b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/engine/InsertClauseShardingConditionEngine.java @@ -137,7 +137,8 @@ private ShardingCondition createShardingCondition(final String tableName, final List parameterMarkerIndexes = each instanceof ParameterMarkerExpressionSegment ? Collections.singletonList(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex()) : Collections.emptyList(); - result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(getShardingValue((SimpleExpressionSegment) each, params)), + Object shardingValue = getShardingValue((SimpleExpressionSegment) each, params); + result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(shardingValue), parameterMarkerIndexes)); } else if (each instanceof CommonExpressionSegment) { generateShardingCondition((CommonExpressionSegment) each, result, shardingColumn.get(), tableName); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java index fa7921f4cdc29..356b713910dc0 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngine.java @@ -19,6 +19,7 @@ import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext; +import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment; @@ -47,14 +48,14 @@ public final class GeneratedKeyContextEngine { * Create generate key context. * * @param insertColumnNames insert column names - * @param valueExpressions value expressions + * @param insertValueContexts value expressions * @param params SQL parameters * @return generate key context */ - public Optional createGenerateKeyContext(final List insertColumnNames, final List> valueExpressions, final List params) { + public Optional createGenerateKeyContext(final List insertColumnNames, final List insertValueContexts, final List params) { String tableName = Optional.ofNullable(insertStatement.getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse(""); return findGenerateKeyColumn(tableName).map(optional -> containsGenerateKey(insertColumnNames, optional) - ? findGeneratedKey(insertColumnNames, valueExpressions, params, optional) + ? findGeneratedKey(insertColumnNames, insertValueContexts, params, optional) : new GeneratedKeyContext(optional, true)); } @@ -89,19 +90,20 @@ private int getValueCountForPerGroup() { return 0; } - private GeneratedKeyContext findGeneratedKey(final List insertColumnNames, final List> valueExpressions, + private GeneratedKeyContext findGeneratedKey(final List insertColumnNames, final List insertValueContexts, final List params, final String generateKeyColumnName) { GeneratedKeyContext result = new GeneratedKeyContext(generateKeyColumnName, false); - for (ExpressionSegment each : findGenerateKeyExpressions(insertColumnNames, valueExpressions, generateKeyColumnName)) { - if (each instanceof ParameterMarkerExpressionSegment) { + for (InsertValueContext each : insertValueContexts) { + ExpressionSegment expression = each.getValueExpressions().get(findGenerateKeyIndex(insertColumnNames, generateKeyColumnName.toLowerCase())); + if (expression instanceof ParameterMarkerExpressionSegment) { if (params.isEmpty()) { continue; } - if (null != params.get(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex())) { - result.getGeneratedValues().add((Comparable) params.get(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex())); + if (null != params.get(((ParameterMarkerExpressionSegment) expression).getParameterMarkerIndex())) { + result.getGeneratedValues().add((Comparable) params.get(((ParameterMarkerExpressionSegment) expression).getParameterMarkerIndex())); } - } else if (each instanceof LiteralExpressionSegment) { - result.getGeneratedValues().add((Comparable) ((LiteralExpressionSegment) each).getLiterals()); + } else if (expression instanceof LiteralExpressionSegment) { + result.getGeneratedValues().add((Comparable) ((LiteralExpressionSegment) expression).getLiterals()); } } return result; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java index 2099a80a8f22d..90e3e08e55aad 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/InsertStatementContext.java @@ -96,7 +96,7 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List< ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName); columnNames = containsInsertColumns() ? insertColumnNames : Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList); - generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, getAllValueExpressions(sqlStatement), params).orElse(null); + generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null); } private ShardingSphereSchema getSchema(final ShardingSphereMetaData metaData, final String defaultDatabaseName) { @@ -276,6 +276,6 @@ public void setUpParameters(final List params) { insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null); onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null); ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName); - generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, valueExpressions, params).orElse(null); + generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null); } } diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngineTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngineTest.java index 8e072bcb791e8..7b2c47bff73c4 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngineTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/insert/keygen/engine/GeneratedKeyContextEngineTest.java @@ -18,6 +18,7 @@ package org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.engine; import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext; +import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema; import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable; @@ -124,9 +125,10 @@ private void assertCreateGenerateKeyContextWhenCreateWithGenerateKeyColumnConfig insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("tbl")))); insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("id"))))); List expressionSegments = Collections.singletonList(new LiteralExpressionSegment(0, 0, 1)); + InsertValueContext insertValueContext = new InsertValueContext(expressionSegments, Collections.emptyList(), 0); insertStatement.getValues().add(new InsertValuesSegment(0, 0, expressionSegments)); Optional actual = new GeneratedKeyContextEngine(insertStatement, schema) - .createGenerateKeyContext(Collections.singletonList("id"), Collections.singletonList(expressionSegments), Collections.singletonList(1)); + .createGenerateKeyContext(Collections.singletonList("id"), Collections.singletonList(insertValueContext), Collections.singletonList(1)); assertTrue(actual.isPresent()); assertThat(actual.get().getGeneratedValues().size(), is(1)); } @@ -163,9 +165,10 @@ private void assertCreateGenerateKeyContextWhenFind(final InsertStatement insert insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new LiteralExpressionSegment(1, 2, 100)))); insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new LiteralExpressionSegment(1, 2, "value")))); insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new CommonExpressionSegment(1, 2, "ignored value")))); - List> valueExpressions = insertStatement.getValues().stream().map(InsertValuesSegment::getValues).collect(Collectors.toList()); + List insertValueContexts = insertStatement.getValues().stream() + .map(each -> new InsertValueContext(each.getValues(), Collections.emptyList(), 0)).collect(Collectors.toList()); Optional actual = new GeneratedKeyContextEngine(insertStatement, schema) - .createGenerateKeyContext(Collections.singletonList("id"), valueExpressions, Collections.singletonList(1)); + .createGenerateKeyContext(Collections.singletonList("id"), insertValueContexts, Collections.singletonList(1)); assertTrue(actual.isPresent()); assertThat(actual.get().getGeneratedValues().size(), is(3)); Iterator> generatedValuesIterator = actual.get().getGeneratedValues().iterator();