Skip to content

Commit b44c9f0

Browse files
committed
Insert action includeId property should only be true when id value is not null and not primitive zero.
+ Moved this logic from SqlParametersFactory to WritingContext and updatede SqlParametersFactory#getInsert to use the includeId parameter.
1 parent f77545c commit b44c9f0

File tree

7 files changed

+145
-40
lines changed

7 files changed

+145
-40
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/JdbcAggregateChangeExecutionContext.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ <T> void executeInsertRoot(DbAction.InsertRoot<T> insert) {
8989
T rootEntity = RelationalEntityVersionUtils.setVersionNumberOnEntity( //
9090
insert.getEntity(), initialVersion, persistentEntity, converter);
9191

92-
id = accessStrategy.insert(rootEntity, insert.getEntityType(), Identifier.empty(), false);
92+
id = accessStrategy.insert(rootEntity, insert.getEntityType(), Identifier.empty(), insert.isIncludeId());
9393

9494
setNewVersion(initialVersion);
9595
} else {
96-
id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), Identifier.empty(), false);
96+
id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), Identifier.empty(), insert.isIncludeId());
9797
}
9898

9999
add(new DbActionExecutionResult(insert, id));
@@ -102,7 +102,7 @@ <T> void executeInsertRoot(DbAction.InsertRoot<T> insert) {
102102
<T> void executeInsert(DbAction.Insert<T> insert) {
103103

104104
Identifier parentKeys = getParentKeys(insert, converter);
105-
Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), parentKeys, false);
105+
Object id = accessStrategy.insert(insert.getEntity(), insert.getEntityType(), parentKeys, insert.isIncludeId());
106106
add(new DbActionExecutionResult(insert, id));
107107
}
108108

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategy.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, Relation
9999
@Override
100100
public <T> Object insert(T instance, Class<T> domainType, Identifier identifier, boolean includeId) {
101101

102-
SqlIdentifierParameterSource parameterSource = sqlParametersFactory.getInsert(instance, domainType, identifier);
102+
SqlIdentifierParameterSource parameterSource = sqlParametersFactory.getInsert(instance, domainType, identifier, includeId);
103103

104104
String insertSql = sql(domainType).getInsert(parameterSource.getIdentifiers());
105105

@@ -111,7 +111,7 @@ public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domai
111111

112112
Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");
113113
SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream()
114-
.map(insertSubject -> sqlParametersFactory.getInsert(insertSubject.getInstance(), domainType, insertSubject.getIdentifier()))
114+
.map(insertSubject -> sqlParametersFactory.getInsert(insertSubject.getInstance(), domainType, insertSubject.getIdentifier(), includeId))
115115
.toArray(SqlIdentifierParameterSource[]::new);
116116

117117
String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/SqlParametersFactory.java

+3-26
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@ public SqlParametersFactory(RelationalMappingContext context, JdbcConverter conv
3232
this.dialect = dialect;
3333
}
3434

35-
<T> SqlIdentifierParameterSource getInsert(T instance, Class<T> domainType, Identifier identifier) {
35+
<T> SqlIdentifierParameterSource getInsert(T instance, Class<T> domainType, Identifier identifier, boolean includeId) {
3636
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
3737
SqlIdentifierParameterSource parameterSource = getParameterSource(instance, persistentEntity, "",
3838
PersistentProperty::isIdProperty, dialect.getIdentifierProcessing());
3939

4040
identifier.forEach((name, value, type) -> addConvertedPropertyValue(parameterSource, name, value, type));
4141

42-
Object idValue = getIdValueOrNull(instance, persistentEntity);
43-
if (idValue != null) {
42+
if (includeId) {
4443

4544
RelationalPersistentProperty idProperty = persistentEntity.getRequiredIdProperty();
45+
Object idValue = persistentEntity.getIdentifierAccessor(instance).getRequiredIdentifier();
4646
addConvertedPropertyValue(parameterSource, idProperty, idValue, idProperty.getColumnName());
4747
}
4848
return parameterSource;
@@ -98,29 +98,6 @@ static Predicate<RelationalPersistentProperty> includeAll() {
9898
return it -> false;
9999
}
100100
}
101-
102-
/**
103-
* Returns the id value if its not a primitive zero. Returns {@literal null} if the id value is null or a primitive
104-
* zero.
105-
*/
106-
@Nullable
107-
@SuppressWarnings("unchecked")
108-
private <S, ID> ID getIdValueOrNull(S instance, RelationalPersistentEntity<S> persistentEntity) {
109-
110-
ID idValue = (ID) persistentEntity.getIdentifierAccessor(instance).getIdentifier();
111-
112-
return isIdPropertyNullOrScalarZero(idValue, persistentEntity) ? null : idValue;
113-
}
114-
115-
private static <S, ID> boolean isIdPropertyNullOrScalarZero(@Nullable ID idValue,
116-
RelationalPersistentEntity<S> persistentEntity) {
117-
118-
RelationalPersistentProperty idProperty = persistentEntity.getIdProperty();
119-
return idValue == null //
120-
|| idProperty == null //
121-
|| (idProperty.getType() == int.class && idValue.equals(0)) //
122-
|| (idProperty.getType() == long.class && idValue.equals(0L));
123-
}
124101

125102
private void addConvertedPropertyValue(SqlIdentifierParameterSource parameterSource,
126103
RelationalPersistentProperty property, @Nullable Object value, SqlIdentifier name) {

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/DefaultDataAccessStrategyUnitTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public void before() {
7171

7272
relationResolver.setDelegate(accessStrategy);
7373

74-
when(sqlParametersFactory.getInsert(any(), any(), any()))
74+
when(sqlParametersFactory.getInsert(any(), any(), any(), anyBoolean()))
7575
.thenReturn(new SqlIdentifierParameterSource(dialect.getIdentifierProcessing()));
7676
when(insertStrategyFactory.insertStrategy(anyBoolean(), any())).thenReturn(mock(InsertStrategy.class));
7777
when(insertStrategyFactory.batchInsertStrategy(anyBoolean(), any())).thenReturn(mock(BatchInsertStrategy.class));

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/SqlParametersFactoryTest.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void identifiersGetAddedAsParameters() {
7474
DummyEntity instance = new DummyEntity(id);
7575
long reference = 23L;
7676
SqlIdentifierParameterSource sqlParameterSource = sqlParametersFactory.getInsert(instance, DummyEntity.class,
77-
Identifier.of(SqlIdentifier.unquoted("reference"), reference, Long.class));
77+
Identifier.of(SqlIdentifier.unquoted("reference"), reference, Long.class), true);
7878

7979
assertThat(sqlParameterSource.getParameterNames()).hasSize(2);
8080
assertThat(sqlParameterSource.getValue("id")).isEqualTo(id);
@@ -87,7 +87,7 @@ void additionalIdentifierForIdDoesNotLeadToDuplicateParameters() {
8787
long id = 4711L;
8888
DummyEntity instance = new DummyEntity(id);
8989
SqlIdentifierParameterSource sqlParameterSource = sqlParametersFactory.getInsert(instance, DummyEntity.class,
90-
Identifier.of(SqlIdentifier.unquoted("id"), 23L, Long.class));
90+
Identifier.of(SqlIdentifier.unquoted("id"), 23L, Long.class), true);
9191

9292
assertThat(sqlParameterSource.getParameterNames()).hasSize(1);
9393
assertThat(sqlParameterSource.getValue("id")).isEqualTo(id);
@@ -101,7 +101,7 @@ void considersConfiguredWriteConverter() {
101101

102102
long id = 4711L;
103103
SqlIdentifierParameterSource sqlParameterSource = sqlParametersFactory.getInsert(new EntityWithBoolean(id, true),
104-
EntityWithBoolean.class, Identifier.empty());
104+
EntityWithBoolean.class, Identifier.empty(), true);
105105

106106
assertThat(sqlParameterSource.getValue("id")).isEqualTo(id);
107107
assertThat(sqlParameterSource.getValue("flag")).isEqualTo("T");
@@ -118,7 +118,7 @@ void considersConfiguredWriteConverterForIdValueObjects_onWrite() {
118118
entity.value = value;
119119

120120
SqlIdentifierParameterSource sqlParameterSource = sqlParametersFactory.getInsert(entity, WithValueObjectId.class,
121-
Identifier.empty());
121+
Identifier.empty(), true);
122122
assertThat(sqlParameterSource.getValue("id")).isEqualTo(rawId);
123123
assertThat(sqlParameterSource.getValue("value")).isEqualTo(value);
124124
}

spring-data-relational/src/main/java/org/springframework/data/relational/core/conversion/WritingContext.java

+17-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.springframework.data.util.Pair;
2929
import org.springframework.lang.Nullable;
3030
import org.springframework.util.Assert;
31-
import org.springframework.util.comparator.BooleanComparator;
3231

3332
/**
3433
* Holds context information for the current save operation.
@@ -55,8 +54,7 @@ class WritingContext {
5554
this.root = root;
5655
this.entity = aggregateChange.getEntity();
5756
this.entityType = aggregateChange.getEntityType();
58-
this.rootIncludeId = context.getRequiredPersistentEntity(aggregateChange.getEntityType())
59-
.getIdentifierAccessor(root).getIdentifier() != null;
57+
this.rootIncludeId = idPropertyIsSet(root, context.getRequiredPersistentEntity(aggregateChange.getEntityType()));
6058
this.paths = context.findPersistentPropertyPaths(entityType, (p) -> p.isEntity() && !p.isEmbedded());
6159
}
6260

@@ -146,7 +144,7 @@ private List<? extends DbAction<?>> insertAll(PersistentPropertyPath<RelationalP
146144
} else {
147145
instance = node.getValue();
148146
}
149-
boolean includeId = persistentEntity.getIdentifierAccessor(instance).getIdentifier() != null;
147+
boolean includeId = idPropertyIsSet(instance, persistentEntity);
150148
DbAction.Insert<Object> insert = new DbAction.Insert<>(instance, path, parentAction, qualifiers, includeId);
151149
inserts.add(insert);
152150
previousActions.put(node, insert);
@@ -304,5 +302,20 @@ private List<PathNode> createNodes(PersistentPropertyPath<RelationalPersistentPr
304302

305303
return nodes;
306304
}
305+
306+
/**
307+
* Returns whether the id property is set. This is the case only when the value of the
308+
* id property is not null and when it is a primitive type, not zero.
309+
*/
310+
private static <S> boolean idPropertyIsSet(Object instance,
311+
RelationalPersistentEntity<S> persistentEntity) {
312+
313+
Object idValue = persistentEntity.getIdentifierAccessor(instance).getIdentifier();
314+
RelationalPersistentProperty idProperty = persistentEntity.getIdProperty();
315+
return idValue != null //
316+
&& idProperty != null //
317+
&& (idProperty.getType() != int.class || !idValue.equals(0)) //
318+
&& (idProperty.getType() != long.class || !idValue.equals(0L));
319+
}
307320

308321
}

spring-data-relational/src/test/java/org/springframework/data/relational/core/conversion/RelationalEntityWriterUnitTests.java

+115
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
4444
import org.springframework.lang.Nullable;
4545

46+
import lombok.Data;
4647
import lombok.RequiredArgsConstructor;
4748

4849
/**
@@ -99,6 +100,48 @@ public void newEntityGetsConvertedToOneInsert() {
99100
);
100101
}
101102

103+
@Test
104+
void newEntityWithPrimitiveLongId_insertDoesNotIncludeId_whenIdValueIsZero() {
105+
PrimitiveLongIdEntity entity = new PrimitiveLongIdEntity();
106+
107+
MutableAggregateChange<PrimitiveLongIdEntity> aggregateChange = //
108+
new DefaultAggregateChange<>(AggregateChange.Kind.SAVE, PrimitiveLongIdEntity.class, entity);
109+
110+
converter.write(entity, aggregateChange);
111+
112+
assertThat(extractActions(aggregateChange)) //
113+
.extracting(DbAction::getClass, //
114+
DbAction::getEntityType, //
115+
DbActionTestSupport::extractPath, //
116+
DbActionTestSupport::actualEntityType, //
117+
DbActionTestSupport::isWithDependsOn, //
118+
DbActionTestSupport::insertIncludeId) //
119+
.containsExactly( //
120+
tuple(InsertRoot.class, PrimitiveLongIdEntity.class, "", PrimitiveLongIdEntity.class, false, false) //
121+
);
122+
}
123+
124+
@Test
125+
void newEntityWithPrimitiveIntId_insertDoesNotIncludeId_whenIdValueIsZero() {
126+
PrimitiveIntIdEntity entity = new PrimitiveIntIdEntity();
127+
128+
MutableAggregateChange<PrimitiveIntIdEntity> aggregateChange = //
129+
new DefaultAggregateChange<>(AggregateChange.Kind.SAVE, PrimitiveIntIdEntity.class, entity);
130+
131+
converter.write(entity, aggregateChange);
132+
133+
assertThat(extractActions(aggregateChange)) //
134+
.extracting(DbAction::getClass, //
135+
DbAction::getEntityType, //
136+
DbActionTestSupport::extractPath, //
137+
DbActionTestSupport::actualEntityType, //
138+
DbActionTestSupport::isWithDependsOn, //
139+
DbActionTestSupport::insertIncludeId) //
140+
.containsExactly( //
141+
tuple(InsertRoot.class, PrimitiveIntIdEntity.class, "", PrimitiveIntIdEntity.class, false, false) //
142+
);
143+
}
144+
102145
@Test // DATAJDBC-111
103146
public void newEntityGetsConvertedToOneInsertByEmbeddedEntities() {
104147

@@ -146,6 +189,32 @@ public void newEntityWithReferenceGetsConvertedToTwoInserts() {
146189
);
147190
}
148191

192+
@Test
193+
void newEntityWithReference_whenReferenceHasPrimitiveId_insertDoesNotIncludeId_whenIdValueIsZero() {
194+
195+
EntityWithReferencesToPrimitiveIdEntity entity = new EntityWithReferencesToPrimitiveIdEntity(null);
196+
entity.primitiveLongIdEntity = new PrimitiveLongIdEntity();
197+
entity.primitiveIntIdEntity = new PrimitiveIntIdEntity();
198+
199+
MutableAggregateChange<EntityWithReferencesToPrimitiveIdEntity> aggregateChange = //
200+
new DefaultAggregateChange<>(AggregateChange.Kind.SAVE, EntityWithReferencesToPrimitiveIdEntity.class, entity);
201+
202+
converter.write(entity, aggregateChange);
203+
204+
assertThat(extractActions(aggregateChange)) //
205+
.extracting(DbAction::getClass, //
206+
DbAction::getEntityType, //
207+
DbActionTestSupport::extractPath, //
208+
DbActionTestSupport::actualEntityType, //
209+
DbActionTestSupport::isWithDependsOn, //
210+
DbActionTestSupport::insertIncludeId) //
211+
.containsExactlyInAnyOrder( //
212+
tuple(InsertRoot.class, EntityWithReferencesToPrimitiveIdEntity.class, "", EntityWithReferencesToPrimitiveIdEntity.class, false, false), //
213+
tuple(Insert.class, PrimitiveLongIdEntity.class, "primitiveLongIdEntity", PrimitiveLongIdEntity.class, true, false), //
214+
tuple(Insert.class, PrimitiveIntIdEntity.class, "primitiveIntIdEntity", PrimitiveIntIdEntity.class, true, false) //
215+
);
216+
}
217+
149218
@Test // DATAJDBC-112
150219
public void existingEntityGetsConvertedToDeletePlusUpdate() {
151220

@@ -726,6 +795,32 @@ void newEntityWithCollectionWhereSomeElementsHaveIdSet_producesABatchInsertEachF
726795
);
727796
}
728797

798+
@Test
799+
void newEntityWithCollection_whenElementHasPrimitiveId_batchInsertDoesNotIncludeId_whenIdValueIsZero() {
800+
801+
EntityWithReferencesToPrimitiveIdEntity entity = new EntityWithReferencesToPrimitiveIdEntity(null);
802+
entity.primitiveLongIdEntities.add(new PrimitiveLongIdEntity());
803+
entity.primitiveIntIdEntities.add(new PrimitiveIntIdEntity());
804+
805+
MutableAggregateChange<EntityWithReferencesToPrimitiveIdEntity> aggregateChange = //
806+
new DefaultAggregateChange<>(AggregateChange.Kind.SAVE, EntityWithReferencesToPrimitiveIdEntity.class, entity);
807+
808+
converter.write(entity, aggregateChange);
809+
810+
List<DbAction<?>> actions = extractActions(aggregateChange);
811+
assertThat(actions).extracting(DbAction::getClass, //
812+
DbAction::getEntityType, //
813+
DbActionTestSupport::extractPath, //
814+
DbActionTestSupport::actualEntityType, //
815+
DbActionTestSupport::isWithDependsOn, //
816+
DbActionTestSupport::insertIncludeId) //
817+
.containsExactlyInAnyOrder( //
818+
tuple(InsertRoot.class, EntityWithReferencesToPrimitiveIdEntity.class, "", EntityWithReferencesToPrimitiveIdEntity.class, false, false), //
819+
tuple(Insert.class, PrimitiveLongIdEntity.class, "primitiveLongIdEntities", PrimitiveLongIdEntity.class, true, false), //
820+
tuple(Insert.class, PrimitiveIntIdEntity.class, "primitiveIntIdEntities", PrimitiveIntIdEntity.class, true, false) //
821+
);
822+
}
823+
729824
private List<DbAction<?>> extractActions(MutableAggregateChange<?> aggregateChange) {
730825

731826
List<DbAction<?>> actions = new ArrayList<>();
@@ -798,6 +893,26 @@ static PersistentPropertyPath<RelationalPersistentProperty> toPath(String path,
798893
return persistentPropertyPaths.filter(p -> p.toDotPath().equals(path)).stream().findFirst().orElse(null);
799894
}
800895

896+
@RequiredArgsConstructor
897+
@Data
898+
static class EntityWithReferencesToPrimitiveIdEntity {
899+
@Id final Long id;
900+
PrimitiveLongIdEntity primitiveLongIdEntity;
901+
List<PrimitiveLongIdEntity> primitiveLongIdEntities = new ArrayList<>();
902+
PrimitiveIntIdEntity primitiveIntIdEntity;
903+
List<PrimitiveIntIdEntity> primitiveIntIdEntities = new ArrayList<>();
904+
}
905+
906+
@Data
907+
static class PrimitiveLongIdEntity {
908+
@Id long id;
909+
}
910+
911+
@Data
912+
static class PrimitiveIntIdEntity {
913+
@Id int id;
914+
}
915+
801916
@RequiredArgsConstructor
802917
static class SingleReferenceEntity {
803918

0 commit comments

Comments
 (0)