Skip to content

Commit d6121cb

Browse files
mipo256mp911de
authored andcommitted
R2DBC @Sequence annotation support.
Signed-off-by: mipo256 <[email protected]> See #1955 Original pull request: #2028
1 parent 30dfdce commit d6121cb

File tree

10 files changed

+382
-28
lines changed

10 files changed

+382
-28
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/IdGeneratingEntityCallback.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ public Object onBeforeSave(Object aggregate, MutableAggregateChange<Object> aggr
5656
return aggregate;
5757
}
5858

59-
RelationalPersistentProperty property = entity.getRequiredIdProperty();
59+
RelationalPersistentProperty idProperty = entity.getRequiredIdProperty();
6060
PersistentPropertyAccessor<Object> accessor = entity.getPropertyAccessor(aggregate);
6161

62-
if (!entity.isNew(aggregate) || delegate.hasValue(property, accessor) || !property.hasSequence()) {
62+
if (!entity.isNew(aggregate) || delegate.hasValue(idProperty, accessor) || !idProperty.hasSequence()) {
6363
return aggregate;
6464
}
6565

66-
delegate.generateSequenceValue(property, accessor);
66+
delegate.generateSequenceValue(idProperty, accessor);
6767

6868
return accessor.getBean();
6969
}

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java

+14
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
import org.springframework.data.r2dbc.core.DefaultReactiveDataAccessStrategy;
4040
import org.springframework.data.r2dbc.core.R2dbcEntityTemplate;
4141
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
42+
import org.springframework.data.r2dbc.core.mapping.IdGeneratingBeforeSaveCallback;
4243
import org.springframework.data.r2dbc.dialect.DialectResolver;
4344
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
4445
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
4546
import org.springframework.data.relational.RelationalManagedTypes;
4647
import org.springframework.data.relational.core.mapping.DefaultNamingStrategy;
4748
import org.springframework.data.relational.core.mapping.NamingStrategy;
49+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4850
import org.springframework.data.relational.core.mapping.Table;
4951
import org.springframework.data.util.TypeScanner;
5052
import org.springframework.lang.Nullable;
@@ -182,6 +184,18 @@ public R2dbcMappingContext r2dbcMappingContext(Optional<NamingStrategy> namingSt
182184
return context;
183185
}
184186

187+
/**
188+
* Register a {@link IdGeneratingBeforeSaveCallback} using
189+
* {@link #r2dbcMappingContext(Optional, R2dbcCustomConversions, RelationalManagedTypes)} and
190+
* {@link #databaseClient()}
191+
*/
192+
@Bean
193+
public IdGeneratingBeforeSaveCallback idGeneratingBeforeSaveCallback(
194+
RelationalMappingContext relationalMappingContext, DatabaseClient databaseClient) {
195+
return new IdGeneratingBeforeSaveCallback(relationalMappingContext, getDialect(lookupConnectionFactory()),
196+
databaseClient);
197+
}
198+
185199
/**
186200
* Creates a {@link ReactiveDataAccessStrategy} using the configured
187201
* {@link #r2dbcConverter(R2dbcMappingContext, R2dbcCustomConversions) R2dbcConverter}.

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java

+7-10
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,11 @@ private void writeInternal(Object source, OutboundRow sink, Class<?> userClass)
186186
RelationalPersistentEntity<?> entity = getRequiredPersistentEntity(userClass);
187187
PersistentPropertyAccessor<?> propertyAccessor = entity.getPropertyAccessor(source);
188188

189-
writeProperties(sink, entity, propertyAccessor, entity.isNew(source));
189+
writeProperties(sink, entity, propertyAccessor);
190190
}
191191

192192
private void writeProperties(OutboundRow sink, RelationalPersistentEntity<?> entity,
193-
PersistentPropertyAccessor<?> accessor, boolean isNew) {
193+
PersistentPropertyAccessor<?> accessor) {
194194

195195
for (RelationalPersistentProperty property : entity) {
196196

@@ -213,24 +213,22 @@ private void writeProperties(OutboundRow sink, RelationalPersistentEntity<?> ent
213213
}
214214

215215
if (getConversions().isSimpleType(value.getClass())) {
216-
writeSimpleInternal(sink, value, isNew, property);
216+
writeSimpleInternal(sink, value, property);
217217
} else {
218-
writePropertyInternal(sink, value, isNew, property);
218+
writePropertyInternal(sink, value, property);
219219
}
220220
}
221221
}
222222

223-
private void writeSimpleInternal(OutboundRow sink, Object value, boolean isNew,
224-
RelationalPersistentProperty property) {
223+
private void writeSimpleInternal(OutboundRow sink, Object value, RelationalPersistentProperty property) {
225224

226225
Object result = getPotentiallyConvertedSimpleWrite(value);
227226

228227
sink.put(property.getColumnName(),
229228
Parameter.fromOrEmpty(result, getPotentiallyConvertedSimpleNullType(property.getType())));
230229
}
231230

232-
private void writePropertyInternal(OutboundRow sink, Object value, boolean isNew,
233-
RelationalPersistentProperty property) {
231+
private void writePropertyInternal(OutboundRow sink, Object value, RelationalPersistentProperty property) {
234232

235233
TypeInformation<?> valueType = TypeInformation.of(value.getClass());
236234

@@ -239,7 +237,7 @@ private void writePropertyInternal(OutboundRow sink, Object value, boolean isNew
239237
if (valueType.getActualType() != null && valueType.getRequiredActualType().isCollectionLike()) {
240238

241239
// pass-thru nested collections
242-
writeSimpleInternal(sink, value, isNew, property);
240+
writeSimpleInternal(sink, value, property);
243241
return;
244242
}
245243

@@ -310,7 +308,6 @@ private Class<?> getPotentiallyConvertedSimpleNullType(Class<?> type) {
310308

311309
if (customTarget.isPresent()) {
312310
return customTarget.get();
313-
314311
}
315312

316313
if (type.isEnum()) {

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -521,15 +521,15 @@ private void potentiallyRemoveId(RelationalPersistentEntity<?> persistentEntity,
521521
return;
522522
}
523523

524-
SqlIdentifier columnName = idProperty.getColumnName();
525-
Parameter parameter = outboundRow.get(columnName);
524+
SqlIdentifier idColumnName = idProperty.getColumnName();
525+
Parameter parameter = outboundRow.get(idColumnName);
526526

527-
if (shouldSkipIdValue(parameter, idProperty)) {
528-
outboundRow.remove(columnName);
527+
if (shouldSkipIdValue(parameter)) {
528+
outboundRow.remove(idColumnName);
529529
}
530530
}
531531

532-
private boolean shouldSkipIdValue(@Nullable Parameter value, RelationalPersistentProperty property) {
532+
private boolean shouldSkipIdValue(@Nullable Parameter value) {
533533

534534
if (value == null || value.getValue() == null) {
535535
return true;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2020-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.data.r2dbc.core.mapping;
18+
19+
import org.apache.commons.logging.Log;
20+
import org.apache.commons.logging.LogFactory;
21+
import org.reactivestreams.Publisher;
22+
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
23+
import org.springframework.data.r2dbc.mapping.OutboundRow;
24+
import org.springframework.data.r2dbc.mapping.event.BeforeSaveCallback;
25+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
26+
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
27+
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
28+
import org.springframework.data.relational.core.sql.SqlIdentifier;
29+
import org.springframework.r2dbc.core.DatabaseClient;
30+
import org.springframework.r2dbc.core.Parameter;
31+
import org.springframework.util.Assert;
32+
33+
import reactor.core.publisher.Mono;
34+
35+
/**
36+
* R2DBC Callback for generating ID via the database sequence.
37+
*
38+
* @author Mikhail Polivakha
39+
*/
40+
public class IdGeneratingBeforeSaveCallback implements BeforeSaveCallback<Object> {
41+
42+
private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class);
43+
44+
private final RelationalMappingContext relationalMappingContext;
45+
private final R2dbcDialect dialect;
46+
47+
private final DatabaseClient databaseClient;
48+
49+
public IdGeneratingBeforeSaveCallback(RelationalMappingContext relationalMappingContext, R2dbcDialect dialect,
50+
DatabaseClient databaseClient) {
51+
this.relationalMappingContext = relationalMappingContext;
52+
this.dialect = dialect;
53+
this.databaseClient = databaseClient;
54+
}
55+
56+
@Override
57+
public Publisher<Object> onBeforeSave(Object entity, OutboundRow row, SqlIdentifier table) {
58+
Assert.notNull(entity, "The aggregate cannot be null at this point");
59+
60+
RelationalPersistentEntity<?> persistentEntity = relationalMappingContext.getPersistentEntity(entity.getClass());
61+
62+
if (!persistentEntity.hasIdProperty() || //
63+
!persistentEntity.getIdProperty().hasSequence() || //
64+
!persistentEntity.isNew(entity) //
65+
) {
66+
return Mono.just(entity);
67+
}
68+
69+
RelationalPersistentProperty property = persistentEntity.getIdProperty();
70+
SqlIdentifier idSequence = property.getSequence();
71+
72+
if (dialect.getIdGeneration().sequencesSupported()) {
73+
return fetchIdFromSeq(entity, row, persistentEntity, idSequence);
74+
} else {
75+
illegalSequenceUsageWarning(entity);
76+
}
77+
78+
return Mono.just(entity);
79+
}
80+
81+
private Mono<Object> fetchIdFromSeq(Object entity, OutboundRow row, RelationalPersistentEntity<?> persistentEntity,
82+
SqlIdentifier idSequence) {
83+
String sequenceQuery = dialect.getIdGeneration().createSequenceQuery(idSequence);
84+
85+
return databaseClient //
86+
.sql(sequenceQuery) //
87+
.map((r, rowMetadata) -> r.get(0)) //
88+
.one() //
89+
.map(fetchedId -> { //
90+
row.put( //
91+
persistentEntity.getIdColumn().toSql(dialect.getIdentifierProcessing()), //
92+
Parameter.from(fetchedId) //
93+
);
94+
return entity;
95+
});
96+
}
97+
98+
private static void illegalSequenceUsageWarning(Object entity) {
99+
LOG.warn("""
100+
It seems you're trying to insert an aggregate of type '%s' annotated with @Sequence, but the problem is RDBMS you're
101+
working with does not support sequences as such. Falling back to identity columns
102+
""".stripIndent().formatted(entity.getClass().getName()));
103+
}
104+
}

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public BoundAssignments getMappedObject(BindMarkers markers, Update update, Tabl
8383
* @param entity related {@link RelationalPersistentEntity}, can be {@literal null}.
8484
* @return the mapped {@link BoundAssignments}.
8585
*/
86-
public BoundAssignments getMappedObject(BindMarkers markers, Map<SqlIdentifier, ? extends Object> assignments,
86+
public BoundAssignments getMappedObject(BindMarkers markers, Map<SqlIdentifier, ?> assignments,
8787
Table table, @Nullable RelationalPersistentEntity<?> entity) {
8888

8989
Assert.notNull(markers, "BindMarkers must not be null");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2020-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.data.r2dbc.core.mapping;
18+
19+
import static org.assertj.core.api.Assertions.assertThat;
20+
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.when;
23+
24+
import java.util.function.BiFunction;
25+
26+
import org.junit.jupiter.api.Test;
27+
import org.mockito.Mockito;
28+
import org.reactivestreams.Publisher;
29+
import org.springframework.data.annotation.Id;
30+
import org.springframework.data.r2dbc.dialect.MySqlDialect;
31+
import org.springframework.data.r2dbc.dialect.PostgresDialect;
32+
import org.springframework.data.r2dbc.mapping.OutboundRow;
33+
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
34+
import org.springframework.data.relational.core.mapping.Sequence;
35+
import org.springframework.data.relational.core.sql.SqlIdentifier;
36+
import org.springframework.r2dbc.core.DatabaseClient;
37+
import org.springframework.r2dbc.core.Parameter;
38+
39+
import reactor.core.publisher.Mono;
40+
import reactor.test.StepVerifier;
41+
42+
/**
43+
* Unit tests for {@link IdGeneratingBeforeSaveCallback}.
44+
*
45+
* @author Mikhail Polivakha
46+
*/
47+
class IdGeneratingBeforeSaveCallbackTest {
48+
49+
@Test
50+
void testIdGenerationIsNotSupported() {
51+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
52+
r2dbcMappingContext.getPersistentEntity(SimpleEntity.class);
53+
MySqlDialect dialect = MySqlDialect.INSTANCE;
54+
DatabaseClient databaseClient = mock(DatabaseClient.class);
55+
56+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
57+
databaseClient);
58+
59+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
60+
SimpleEntity entity = new SimpleEntity();
61+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
62+
63+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
64+
assertThat(row).hasSize(1); // id is not added
65+
}
66+
67+
@Test
68+
void testEntityIsNotAnnotatedWithSequence() {
69+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
70+
r2dbcMappingContext.getPersistentEntity(SimpleEntity.class);
71+
PostgresDialect dialect = PostgresDialect.INSTANCE;
72+
DatabaseClient databaseClient = mock(DatabaseClient.class);
73+
74+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
75+
databaseClient);
76+
77+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
78+
SimpleEntity entity = new SimpleEntity();
79+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
80+
81+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
82+
assertThat(row).hasSize(1); // id is not added
83+
}
84+
85+
@Test
86+
void testIdGeneratedFromSequenceHappyPath() {
87+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
88+
r2dbcMappingContext.getPersistentEntity(WithSequence.class);
89+
PostgresDialect dialect = PostgresDialect.INSTANCE;
90+
DatabaseClient databaseClient = mock(DatabaseClient.class, RETURNS_DEEP_STUBS);
91+
long generatedId = 1L;
92+
93+
when(databaseClient.sql(Mockito.anyString()).map(Mockito.any(BiFunction.class)).one()).thenReturn(
94+
Mono.just(generatedId));
95+
96+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
97+
databaseClient);
98+
99+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
100+
WithSequence entity = new WithSequence();
101+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
102+
103+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
104+
assertThat(row).hasSize(2)
105+
.containsEntry(SqlIdentifier.unquoted("id"), Parameter.from(generatedId));
106+
}
107+
108+
static class SimpleEntity {
109+
110+
@Id
111+
private Long id;
112+
113+
private String name;
114+
}
115+
116+
static class WithSequence {
117+
118+
@Id
119+
@Sequence(sequence = "seq_name")
120+
private Long id;
121+
122+
private String name;
123+
}
124+
}

0 commit comments

Comments
 (0)