Skip to content

Commit

Permalink
Replace value expressions with insertValueContexts
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingZC committed Dec 1, 2023
1 parent 28814e9 commit 73b1ec0
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ private ShardingCondition createShardingCondition(final String tableName, final
List<Integer> parameterMarkerIndexes = each instanceof ParameterMarkerExpressionSegment
? Collections.singletonList(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex())
: Collections.emptyList();
result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(getShardingValue((SimpleExpressionSegment) each, params)),
Object shardingValue = getShardingValue((SimpleExpressionSegment) each, params);
result.getValues().add(new ListShardingConditionValue<>(shardingColumn.get(), tableName, Collections.singletonList(shardingValue),
parameterMarkerIndexes));
} else if (each instanceof CommonExpressionSegment) {
generateShardingCondition((CommonExpressionSegment) each, result, shardingColumn.get(), tableName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
Expand Down Expand Up @@ -47,14 +48,14 @@ public final class GeneratedKeyContextEngine {
* Create generate key context.
*
* @param insertColumnNames insert column names
* @param valueExpressions value expressions
* @param insertValueContexts value expressions
* @param params SQL parameters
* @return generate key context
*/
public Optional<GeneratedKeyContext> createGenerateKeyContext(final List<String> insertColumnNames, final List<List<ExpressionSegment>> valueExpressions, final List<Object> params) {
public Optional<GeneratedKeyContext> createGenerateKeyContext(final List<String> insertColumnNames, final List<InsertValueContext> insertValueContexts, final List<Object> params) {
String tableName = Optional.ofNullable(insertStatement.getTable()).map(optional -> optional.getTableName().getIdentifier().getValue()).orElse("");
return findGenerateKeyColumn(tableName).map(optional -> containsGenerateKey(insertColumnNames, optional)
? findGeneratedKey(insertColumnNames, valueExpressions, params, optional)
? findGeneratedKey(insertColumnNames, insertValueContexts, params, optional)
: new GeneratedKeyContext(optional, true));
}

Expand Down Expand Up @@ -89,19 +90,20 @@ private int getValueCountForPerGroup() {
return 0;
}

private GeneratedKeyContext findGeneratedKey(final List<String> insertColumnNames, final List<List<ExpressionSegment>> valueExpressions,
private GeneratedKeyContext findGeneratedKey(final List<String> insertColumnNames, final List<InsertValueContext> insertValueContexts,
final List<Object> params, final String generateKeyColumnName) {
GeneratedKeyContext result = new GeneratedKeyContext(generateKeyColumnName, false);
for (ExpressionSegment each : findGenerateKeyExpressions(insertColumnNames, valueExpressions, generateKeyColumnName)) {
if (each instanceof ParameterMarkerExpressionSegment) {
for (InsertValueContext each : insertValueContexts) {
ExpressionSegment expression = each.getValueExpressions().get(findGenerateKeyIndex(insertColumnNames, generateKeyColumnName.toLowerCase()));
if (expression instanceof ParameterMarkerExpressionSegment) {
if (params.isEmpty()) {
continue;
}
if (null != params.get(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex())) {
result.getGeneratedValues().add((Comparable<?>) params.get(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex()));
if (null != params.get(((ParameterMarkerExpressionSegment) expression).getParameterMarkerIndex())) {
result.getGeneratedValues().add((Comparable<?>) params.get(((ParameterMarkerExpressionSegment) expression).getParameterMarkerIndex()));
}
} else if (each instanceof LiteralExpressionSegment) {
result.getGeneratedValues().add((Comparable<?>) ((LiteralExpressionSegment) each).getLiterals());
} else if (expression instanceof LiteralExpressionSegment) {
result.getGeneratedValues().add((Comparable<?>) ((LiteralExpressionSegment) expression).getLiterals());
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public InsertStatementContext(final ShardingSphereMetaData metaData, final List<
ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
columnNames = containsInsertColumns() ? insertColumnNames
: Optional.ofNullable(sqlStatement.getTable()).map(optional -> schema.getVisibleColumnNames(optional.getTableName().getIdentifier().getValue())).orElseGet(Collections::emptyList);
generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, getAllValueExpressions(sqlStatement), params).orElse(null);
generatedKeyContext = new GeneratedKeyContextEngine(sqlStatement, schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
}

private ShardingSphereSchema getSchema(final ShardingSphereMetaData metaData, final String defaultDatabaseName) {
Expand Down Expand Up @@ -276,6 +276,6 @@ public void setUpParameters(final List<Object> params) {
insertSelectContext = getInsertSelectContext(metaData, params, parametersOffset, defaultDatabaseName).orElse(null);
onDuplicateKeyUpdateValueContext = getOnDuplicateKeyUpdateValueContext(params, parametersOffset).orElse(null);
ShardingSphereSchema schema = getSchema(metaData, defaultDatabaseName);
generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, valueExpressions, params).orElse(null);
generatedKeyContext = new GeneratedKeyContextEngine(getSqlStatement(), schema).createGenerateKeyContext(insertColumnNames, insertValueContexts, params).orElse(null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.engine;

import org.apache.shardingsphere.infra.binder.context.segment.insert.keygen.GeneratedKeyContext;
import org.apache.shardingsphere.infra.binder.context.segment.insert.values.InsertValueContext;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereColumn;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
Expand Down Expand Up @@ -124,9 +125,10 @@ private void assertCreateGenerateKeyContextWhenCreateWithGenerateKeyColumnConfig
insertStatement.setTable(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("tbl"))));
insertStatement.setInsertColumns(new InsertColumnsSegment(0, 0, Collections.singletonList(new ColumnSegment(0, 0, new IdentifierValue("id")))));
List<ExpressionSegment> expressionSegments = Collections.singletonList(new LiteralExpressionSegment(0, 0, 1));
InsertValueContext insertValueContext = new InsertValueContext(expressionSegments, Collections.emptyList(), 0);
insertStatement.getValues().add(new InsertValuesSegment(0, 0, expressionSegments));
Optional<GeneratedKeyContext> actual = new GeneratedKeyContextEngine(insertStatement, schema)
.createGenerateKeyContext(Collections.singletonList("id"), Collections.singletonList(expressionSegments), Collections.singletonList(1));
.createGenerateKeyContext(Collections.singletonList("id"), Collections.singletonList(insertValueContext), Collections.singletonList(1));
assertTrue(actual.isPresent());
assertThat(actual.get().getGeneratedValues().size(), is(1));
}
Expand Down Expand Up @@ -163,9 +165,10 @@ private void assertCreateGenerateKeyContextWhenFind(final InsertStatement insert
insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new LiteralExpressionSegment(1, 2, 100))));
insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new LiteralExpressionSegment(1, 2, "value"))));
insertStatement.getValues().add(new InsertValuesSegment(0, 0, Collections.singletonList(new CommonExpressionSegment(1, 2, "ignored value"))));
List<List<ExpressionSegment>> valueExpressions = insertStatement.getValues().stream().map(InsertValuesSegment::getValues).collect(Collectors.toList());
List<InsertValueContext> insertValueContexts = insertStatement.getValues().stream()
.map(each -> new InsertValueContext(each.getValues(), Collections.emptyList(), 0)).collect(Collectors.toList());
Optional<GeneratedKeyContext> actual = new GeneratedKeyContextEngine(insertStatement, schema)
.createGenerateKeyContext(Collections.singletonList("id"), valueExpressions, Collections.singletonList(1));
.createGenerateKeyContext(Collections.singletonList("id"), insertValueContexts, Collections.singletonList(1));
assertTrue(actual.isPresent());
assertThat(actual.get().getGeneratedValues().size(), is(3));
Iterator<Comparable<?>> generatedValuesIterator = actual.get().getGeneratedValues().iterator();
Expand Down

0 comments on commit 73b1ec0

Please sign in to comment.