Skip to content

Commit

Permalink
Added Sequence generation support
Browse files Browse the repository at this point in the history
  • Loading branch information
mipo256 committed Oct 26, 2024
1 parent 52fdadd commit 8ddf054
Show file tree
Hide file tree
Showing 37 changed files with 603 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ private <S> Object setIdAndCascadingProperties(DbAction.WithEntity<S> action, @N
PersistentPropertyPathAccessor<S> propertyAccessor = converter.getPropertyAccessor(persistentEntity,
originalEntity);

if (IdValueSource.GENERATED.equals(action.getIdValueSource())) {
if (IdValueSource.isGeneratedByDatabased(action.getIdValueSource())) {
propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), generatedId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.dao.OptimisticLockingFailureException;
Expand All @@ -37,6 +44,7 @@
import org.springframework.data.relational.core.query.Query;
import org.springframework.data.relational.core.sql.LockMode;
import org.springframework.data.relational.core.sql.SqlIdentifier;
import org.springframework.data.util.Pair;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
Expand All @@ -60,6 +68,7 @@
* @author Radim Tlusty
* @author Chirag Tailor
* @author Diego Krupitza
* @author Mikhail Polivakha
* @since 1.1
*/
public class DefaultDataAccessStrategy implements DataAccessStrategy {
Expand Down Expand Up @@ -102,31 +111,35 @@ public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, Relation
@Override
public <T> Object insert(T instance, Class<T> domainType, Identifier identifier, IdValueSource idValueSource) {

SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forInsert(instance, domainType, identifier,
idValueSource);
RelationalPersistentEntity<?> persistentEntity = context.getRequiredPersistentEntity(domainType);

Optional<Long> idFromSequence = getIdFromSequenceIfAnyDefined(idValueSource, persistentEntity);

SqlIdentifierParameterSource parameterSource = idFromSequence
.map(it -> sqlParametersFactory.forInsert(instance, domainType, identifier, it))
.orElseGet(() -> sqlParametersFactory.forInsert(instance, domainType, identifier, idValueSource));

String insertSql = sql(domainType).getInsert(parameterSource.getIdentifiers());

return insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql,
parameterSource);
}
Object idAfterExecute = insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType))
.execute(insertSql, parameterSource);

return idFromSequence.map(it -> (Object) it).orElse(idAfterExecute);
}

@Override
public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) {

Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");
SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream()
.map(insertSubject -> sqlParametersFactory.forInsert(insertSubject.getInstance(), domainType,
insertSubject.getIdentifier(), idValueSource))
.toArray(SqlIdentifierParameterSource[]::new);

String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
if (IdValueSource.SEQUENCE.equals(idValueSource)) {
return executeBatchInsertWithSequenceAsIdSource(insertSubjects, domainType, idValueSource);
} else {
return executeBatchInsert(insertSubjects, domainType, idValueSource);
}
}

return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql,
sqlParameterSources);
}

@Override
@Override
public <S> boolean update(S instance, Class<S> domainType) {

SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forUpdate(instance, domainType);
Expand Down Expand Up @@ -446,4 +459,70 @@ private Class<?> getBaseType(PersistentPropertyPath<RelationalPersistentProperty
return baseProperty.getOwner().getType();
}

private <T> Object[] executeBatchInsert(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) {
SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects
.stream()
.map(insertSubject -> sqlParametersFactory.forInsert(
insertSubject.getInstance(), domainType,
insertSubject.getIdentifier(), idValueSource)
)
.toArray(SqlIdentifierParameterSource[]::new);

String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());

return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType))
.execute(insertSql, sqlParameterSources);
}

private <T> Object[] executeBatchInsertWithSequenceAsIdSource(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) {
List<Pair<Long, SqlIdentifierParameterSource>> sqlParameterSources = createBatchParameterSourcesWithSequence(insertSubjects, domainType,
context.getPersistentEntity(domainType).getIdTargetSequence()
);

String insertSql = sql(domainType).getInsert(sqlParameterSources.get(0).getSecond().getIdentifiers());

insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType))
.execute(insertSql, sqlParameterSources.stream()
.map(Pair::getSecond)
.toArray(SqlIdentifierParameterSource[]::new));

return sqlParameterSources.stream().map(Pair::getFirst).toArray(Object[]::new);
}

private <T> List<Pair<Long, SqlIdentifierParameterSource>> createBatchParameterSourcesWithSequence(List<InsertSubject<T>> insertSubjects, Class<T> domainType, Optional<String> idTargetSequence) {
List<Pair<Long, SqlIdentifierParameterSource>> sqlParameterSources;
int subjectsSize = insertSubjects.size();

List<Long> generatedIds = getMultipleIdsFromSequence(idTargetSequence.get(), subjectsSize);

sqlParameterSources = IntStream
.range(0, subjectsSize)
.mapToObj(index -> {
InsertSubject<T> subject = insertSubjects.get(index);
Long generatedId = generatedIds.get(index);
return Pair.of(generatedId, sqlParametersFactory.forInsert(
subject.getInstance(), domainType,
subject.getIdentifier(), generatedId
));
})
.collect(Collectors.toList());
return sqlParameterSources;
}

private Optional<Long> getIdFromSequenceIfAnyDefined(IdValueSource idValueSource, RelationalPersistentEntity<?> persistentEntity) {
if (IdValueSource.SEQUENCE.equals(idValueSource) && persistentEntity.getIdTargetSequence().isPresent()) {
String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(persistentEntity.getIdTargetSequence().get());
return Optional.of(operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1)));
}
return Optional.empty();
}

private List<Long> getMultipleIdsFromSequence(String sequenceName, Integer requiredIds) {
String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(sequenceName);

return IntStream.range(0, requiredIds)
.mapToObj(operand -> operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1)))
.collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
*
* @author Chirag Tailor
* @author Jens Schauder
* @author Mikhail Polivakha
* @since 2.4
*/
public class InsertStrategyFactory {
Expand Down Expand Up @@ -102,4 +103,7 @@ public Object[] execute(String sql, SqlParameterSource[] sqlParameterSources) {
}
}

public Dialect getDialect() {
return this.dialect;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Set;

import org.springframework.data.relational.core.sql.SqlIdentifier;
import org.springframework.data.util.Pair;
import org.springframework.jdbc.core.namedparam.AbstractSqlParameterSource;

/**
Expand All @@ -35,9 +36,11 @@
*/
class SqlIdentifierParameterSource extends AbstractSqlParameterSource {

private final Set<SqlIdentifier> identifiers = new HashSet<>();
private final Set<SqlIdentifier> sqlIdentifiers = new HashSet<>();
private final Map<String, Object> namesToValues = new HashMap<>();

private Pair<SqlIdentifier, Object> idToValue;

@Override
public boolean hasValue(String paramName) {
return namesToValues.containsKey(paramName);
Expand All @@ -54,17 +57,31 @@ public String[] getParameterNames() {
}

Set<SqlIdentifier> getIdentifiers() {
return Collections.unmodifiableSet(identifiers);
return Collections.unmodifiableSet(sqlIdentifiers);
}

void addValue(SqlIdentifier name, Object value) {
addValue(name, value, Integer.MIN_VALUE);
}

void addValue(SqlIdentifier identifier, Object value, int sqlType) {
void addValue(SqlIdentifier sqlIdentifier, Object value, int sqlType) {

sqlIdentifiers.add(sqlIdentifier);
String name = prepareSqlIdentifierName(sqlIdentifier);
namesToValues.put(name, value);
registerSqlType(name, sqlType);
}

/**
* Adds an Id of the record
* @param sqlIdentifier
* @param value
* @param sqlType
*/
void addId(SqlIdentifier sqlIdentifier, Object value, int sqlType) {

identifiers.add(identifier);
String name = BindParameterNameSanitizer.sanitize(identifier.getReference());
sqlIdentifiers.add(sqlIdentifier);
String name = prepareSqlIdentifierName(sqlIdentifier);
namesToValues.put(name, value);
registerSqlType(name, sqlType);
}
Expand All @@ -73,11 +90,15 @@ void addAll(SqlIdentifierParameterSource others) {

for (SqlIdentifier identifier : others.getIdentifiers()) {

String name = BindParameterNameSanitizer.sanitize( identifier.getReference());
String name = prepareSqlIdentifierName(identifier);
addValue(identifier, others.getValue(name), others.getSqlType(name));
}
}

private static String prepareSqlIdentifierName(SqlIdentifier sqlIdentifier) {
return BindParameterNameSanitizer.sanitize(sqlIdentifier.getReference());
}

int size() {
return namesToValues.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,28 @@
*/
package org.springframework.data.jdbc.core.convert;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

import org.springframework.data.jdbc.core.mapping.JdbcValue;
import org.springframework.data.jdbc.support.JdbcUtil;
import org.springframework.data.mapping.PersistentProperty;
import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.relational.core.conversion.IdValueSource;
import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
import org.springframework.data.relational.core.sql.SqlIdentifier;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand All @@ -46,13 +53,15 @@
public class SqlParametersFactory {
private final RelationalMappingContext context;
private final JdbcConverter converter;
private final Dialect dialect;

/**
* @since 3.1
*/
public SqlParametersFactory(RelationalMappingContext context, JdbcConverter converter) {
private final NamedParameterJdbcOperations operations;

public SqlParametersFactory(RelationalMappingContext context, JdbcConverter converter, Dialect dialect, NamedParameterJdbcOperations operations) {
this.context = context;
this.converter = converter;
this.dialect = dialect;
this.operations = operations;
}

/**
Expand All @@ -70,18 +79,38 @@ public SqlParametersFactory(RelationalMappingContext context, JdbcConverter conv
<T> SqlIdentifierParameterSource forInsert(T instance, Class<T> domainType, Identifier identifier,
IdValueSource idValueSource) {

RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);

Object idValue = null;

if (IdValueSource.PROVIDED.equals(idValueSource)) {
idValue = persistentEntity.getIdentifierAccessor(instance).getRequiredIdentifier();
}
return forInsert(instance, domainType, identifier, idValue);
}

/**
* Creates the parameters for a SQL insert operation. That method is different from its sibling
* {@link #forInsert(Object, Class, Identifier, IdValueSource) forInsert method} in the sense, that
* this method is invoked when we actually know the id to be added to the {@link SqlParameterSource paarameter source}.
* It might be null, meaning, that we know for sure the id should be coming from the database, or
* it could be not null, meaning, that we've got the id from some source (user provided by himself,
* or we have queried the sequence for instance)
*/
<T> SqlIdentifierParameterSource forInsert(T instance, Class<T> domainType, Identifier identifier,
@Nullable Object id) {

RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
SqlIdentifierParameterSource parameterSource = getParameterSource(instance, persistentEntity, "",
PersistentProperty::isIdProperty);

identifier.forEach((name, value, type) -> addConvertedPropertyValue(parameterSource, name, value, type));

if (IdValueSource.PROVIDED.equals(idValueSource)) {

RelationalPersistentProperty idProperty = persistentEntity.getRequiredIdProperty();
Object idValue = persistentEntity.getIdentifierAccessor(instance).getRequiredIdentifier();
addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName());
}
RelationalPersistentProperty idProperty = persistentEntity.getIdProperty();
Optional
.ofNullable(id)
.filter(it -> idProperty != null)
.ifPresent(it -> addConvertedPropertyValue(parameterSource, idProperty, it, idProperty.getColumnName()));
return parameterSource;
}

Expand Down Expand Up @@ -178,6 +207,13 @@ private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSou
converter.getTargetSqlType(property));
}

private void addConvertedIdPropertyValue(SqlIdentifierParameterSource parameterSource,
RelationalPersistentProperty property, @Nullable Object value, SqlIdentifier name) {

addConvertedValue(parameterSource, value, name, converter.getColumnType(property),
converter.getTargetSqlType(property));
}

private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSource, SqlIdentifier name, Object value,
Class<?> javaType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static DataAccessStrategy createCombinedAccessStrategy(RelationalMappingC
NamespaceStrategy namespaceStrategy, Dialect dialect) {

SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(context, converter, dialect);
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter);
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter, dialect, operations);
InsertStrategyFactory insertStrategyFactory = new InsertStrategyFactory(operations, dialect);

DataAccessStrategy defaultDataAccessStrategy = new DataAccessStrategyFactory( //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public DataAccessStrategy dataAccessStrategyBean(NamedParameterJdbcOperations op

SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(context, jdbcConverter, dialect);
DataAccessStrategyFactory factory = new DataAccessStrategyFactory(sqlGeneratorSource, jdbcConverter, operations,
new SqlParametersFactory(context, jdbcConverter),
new SqlParametersFactory(context, jdbcConverter, dialect, operations),
new InsertStrategyFactory(operations, dialect));

return factory.create();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public void afterPropertiesSet() {

SqlGeneratorSource sqlGeneratorSource = new SqlGeneratorSource(this.mappingContext, this.converter,
this.dialect);
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(this.mappingContext, this.converter);
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(this.mappingContext, this.converter, this.dialect, this.operations);
InsertStrategyFactory insertStrategyFactory = new InsertStrategyFactory(this.operations, this.dialect);

DataAccessStrategyFactory factory = new DataAccessStrategyFactory(sqlGeneratorSource, this.converter,
Expand Down
Loading

0 comments on commit 8ddf054

Please sign in to comment.