Skip to content

Commit bc6b6a0

Browse files
committed
R2DBC support for @sequence annotation
1 parent fa9e15d commit bc6b6a0

File tree

4 files changed

+318
-9
lines changed

4 files changed

+318
-9
lines changed

Diff for: spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java

+14
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@
3939
import org.springframework.data.r2dbc.core.DefaultReactiveDataAccessStrategy;
4040
import org.springframework.data.r2dbc.core.R2dbcEntityTemplate;
4141
import org.springframework.data.r2dbc.core.ReactiveDataAccessStrategy;
42+
import org.springframework.data.r2dbc.core.mapping.IdGeneratingBeforeSaveCallback;
4243
import org.springframework.data.r2dbc.dialect.DialectResolver;
4344
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
4445
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
4546
import org.springframework.data.relational.RelationalManagedTypes;
4647
import org.springframework.data.relational.core.mapping.DefaultNamingStrategy;
4748
import org.springframework.data.relational.core.mapping.NamingStrategy;
49+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4850
import org.springframework.data.relational.core.mapping.Table;
4951
import org.springframework.data.util.TypeScanner;
5052
import org.springframework.lang.Nullable;
@@ -182,6 +184,18 @@ public R2dbcMappingContext r2dbcMappingContext(Optional<NamingStrategy> namingSt
182184
return context;
183185
}
184186

187+
/**
188+
* Register a {@link IdGeneratingBeforeSaveCallback} using
189+
* {@link #r2dbcMappingContext(Optional, R2dbcCustomConversions, RelationalManagedTypes)} and
190+
* {@link #databaseClient()}
191+
*/
192+
@Bean
193+
public IdGeneratingBeforeSaveCallback idGeneratingBeforeSaveCallback(
194+
RelationalMappingContext relationalMappingContext, DatabaseClient databaseClient) {
195+
return new IdGeneratingBeforeSaveCallback(relationalMappingContext, getDialect(lookupConnectionFactory()),
196+
databaseClient);
197+
}
198+
185199
/**
186200
* Creates a {@link ReactiveDataAccessStrategy} using the configured
187201
* {@link #r2dbcConverter(R2dbcMappingContext, R2dbcCustomConversions) R2dbcConverter}.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright 2020-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.data.r2dbc.core.mapping;
18+
19+
import org.apache.commons.logging.Log;
20+
import org.apache.commons.logging.LogFactory;
21+
import org.jetbrains.annotations.NotNull;
22+
import org.reactivestreams.Publisher;
23+
import org.springframework.data.r2dbc.dialect.R2dbcDialect;
24+
import org.springframework.data.r2dbc.mapping.OutboundRow;
25+
import org.springframework.data.r2dbc.mapping.event.BeforeSaveCallback;
26+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
27+
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
28+
import org.springframework.data.relational.core.sql.SqlIdentifier;
29+
import org.springframework.r2dbc.core.DatabaseClient;
30+
import org.springframework.r2dbc.core.Parameter;
31+
import org.springframework.util.Assert;
32+
33+
import reactor.core.publisher.Mono;
34+
35+
/**
36+
* R2DBC Callback for generating ID via the database sequence.
37+
*
38+
* @author Mikhail Polivakha
39+
*/
40+
public class IdGeneratingBeforeSaveCallback implements BeforeSaveCallback<Object> {
41+
42+
private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class);
43+
44+
private final RelationalMappingContext relationalMappingContext;
45+
private final R2dbcDialect dialect;
46+
47+
private final DatabaseClient databaseClient;
48+
49+
public IdGeneratingBeforeSaveCallback(RelationalMappingContext relationalMappingContext, R2dbcDialect dialect,
50+
DatabaseClient databaseClient) {
51+
this.relationalMappingContext = relationalMappingContext;
52+
this.dialect = dialect;
53+
this.databaseClient = databaseClient;
54+
}
55+
56+
@Override
57+
public Publisher<Object> onBeforeSave(Object entity, OutboundRow row, SqlIdentifier table) {
58+
Assert.notNull(entity, "The aggregate cannot be null at this point");
59+
60+
RelationalPersistentEntity<?> persistentEntity = relationalMappingContext.getPersistentEntity(entity.getClass());
61+
62+
if (!persistentEntity.hasIdProperty() || //
63+
!persistentEntity.getIdProperty().hasSequence() || //
64+
!persistentEntity.isNew(entity) //
65+
) {
66+
return Mono.just(entity);
67+
}
68+
69+
SqlIdentifier idSequence = persistentEntity.getIdProperty().getSequence();
70+
71+
if (dialect.getIdGeneration().sequencesSupported()) {
72+
return fetchIdFromSeq(entity, row, persistentEntity, idSequence);
73+
} else {
74+
illegalSequenceUsageWarning(entity);
75+
}
76+
77+
return Mono.just(entity);
78+
}
79+
80+
@NotNull
81+
private Mono<Object> fetchIdFromSeq(Object entity, OutboundRow row, RelationalPersistentEntity<?> persistentEntity,
82+
SqlIdentifier idSequence) {
83+
String sequenceQuery = dialect.getIdGeneration().createSequenceQuery(idSequence);
84+
85+
return databaseClient //
86+
.sql(sequenceQuery) //
87+
.map((r, rowMetadata) -> r.get(0)) //
88+
.one() //
89+
.doOnNext(fetchedId -> { //
90+
row.put(persistentEntity.getIdColumn().toSql(dialect.getIdentifierProcessing()), //
91+
Parameter.from(fetchedId) //
92+
);
93+
}).map(o -> entity);
94+
}
95+
96+
private static void illegalSequenceUsageWarning(Object entity) {
97+
LOG.warn("""
98+
It seems you're trying to insert an aggregate of type '%s' annotated with @Sequence, but the problem is RDBMS you're
99+
working with does not support sequences as such. Falling back to identity columns
100+
""".stripIndent().formatted(entity.getClass().getName())
101+
);
102+
}
103+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2020-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.data.r2dbc.core.mapping;
18+
19+
import static org.assertj.core.api.Assertions.assertThat;
20+
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.when;
23+
24+
import java.util.function.BiFunction;
25+
26+
import org.junit.jupiter.api.Test;
27+
import org.mockito.Mockito;
28+
import org.reactivestreams.Publisher;
29+
import org.springframework.data.annotation.Id;
30+
import org.springframework.data.r2dbc.dialect.MySqlDialect;
31+
import org.springframework.data.r2dbc.dialect.PostgresDialect;
32+
import org.springframework.data.r2dbc.mapping.OutboundRow;
33+
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
34+
import org.springframework.data.relational.core.mapping.Sequence;
35+
import org.springframework.data.relational.core.sql.SqlIdentifier;
36+
import org.springframework.r2dbc.core.DatabaseClient;
37+
import org.springframework.r2dbc.core.Parameter;
38+
39+
import reactor.core.publisher.Mono;
40+
import reactor.test.StepVerifier;
41+
42+
/**
43+
* Unit tests for {@link IdGeneratingBeforeSaveCallback}.
44+
*
45+
* @author Mikhail Polivakha
46+
*/
47+
class IdGeneratingBeforeSaveCallbackTest {
48+
49+
@Test
50+
void testIdGenerationIsNotSupported() {
51+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
52+
r2dbcMappingContext.getPersistentEntity(SimpleEntity.class);
53+
MySqlDialect dialect = MySqlDialect.INSTANCE;
54+
DatabaseClient databaseClient = mock(DatabaseClient.class);
55+
56+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
57+
databaseClient);
58+
59+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
60+
SimpleEntity entity = new SimpleEntity();
61+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
62+
63+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
64+
assertThat(row).hasSize(1); // id is not added
65+
}
66+
67+
@Test
68+
void testEntityIsNotAnnotatedWithSequence() {
69+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
70+
r2dbcMappingContext.getPersistentEntity(SimpleEntity.class);
71+
PostgresDialect dialect = PostgresDialect.INSTANCE;
72+
DatabaseClient databaseClient = mock(DatabaseClient.class);
73+
74+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
75+
databaseClient);
76+
77+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
78+
SimpleEntity entity = new SimpleEntity();
79+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
80+
81+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
82+
assertThat(row).hasSize(1); // id is not added
83+
}
84+
85+
@Test
86+
void testIdGeneratedFromSequenceHappyPath() {
87+
R2dbcMappingContext r2dbcMappingContext = new R2dbcMappingContext();
88+
r2dbcMappingContext.getPersistentEntity(WithSequence.class);
89+
PostgresDialect dialect = PostgresDialect.INSTANCE;
90+
DatabaseClient databaseClient = mock(DatabaseClient.class, RETURNS_DEEP_STUBS);
91+
long generatedId = 1L;
92+
93+
when(databaseClient.sql(Mockito.anyString()).map(Mockito.any(BiFunction.class)).one()).thenReturn(
94+
Mono.just(generatedId));
95+
96+
IdGeneratingBeforeSaveCallback callback = new IdGeneratingBeforeSaveCallback(r2dbcMappingContext, dialect,
97+
databaseClient);
98+
99+
OutboundRow row = new OutboundRow("name", Parameter.from("my_name"));
100+
WithSequence entity = new WithSequence();
101+
Publisher<Object> publisher = callback.onBeforeSave(entity, row, SqlIdentifier.unquoted("simple_entity"));
102+
103+
StepVerifier.create(publisher).expectNext(entity).expectComplete().verify();
104+
assertThat(row).hasSize(2)
105+
.containsEntry(SqlIdentifier.unquoted("id"), Parameter.from(generatedId));
106+
}
107+
108+
static class SimpleEntity {
109+
110+
@Id
111+
private Long id;
112+
113+
private String name;
114+
}
115+
116+
static class WithSequence {
117+
118+
@Id
119+
@Sequence(sequence = "seq_name")
120+
private Long id;
121+
122+
private String name;
123+
}
124+
}

Diff for: spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/PostgresR2dbcRepositoryIntegrationTests.java

+77-9
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
*/
1616
package org.springframework.data.r2dbc.repository;
1717

18-
import io.r2dbc.postgresql.codec.Json;
19-
import io.r2dbc.spi.ConnectionFactory;
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
20+
import java.util.Collections;
21+
import java.util.Map;
22+
23+
import javax.sql.DataSource;
24+
2025
import org.junit.jupiter.api.Test;
2126
import org.junit.jupiter.api.extension.ExtendWith;
2227
import org.junit.jupiter.api.extension.RegisterExtension;
@@ -32,22 +37,20 @@
3237
import org.springframework.data.r2dbc.repository.support.R2dbcRepositoryFactory;
3338
import org.springframework.data.r2dbc.testing.ExternalDatabase;
3439
import org.springframework.data.r2dbc.testing.PostgresTestSupport;
40+
import org.springframework.data.relational.core.mapping.Sequence;
3541
import org.springframework.data.relational.core.mapping.Table;
3642
import org.springframework.data.repository.reactive.ReactiveCrudRepository;
3743
import org.springframework.jdbc.core.JdbcTemplate;
3844
import org.springframework.r2dbc.core.DatabaseClient;
3945
import org.springframework.test.context.ContextConfiguration;
4046
import org.springframework.test.context.junit.jupiter.SpringExtension;
47+
48+
import io.r2dbc.postgresql.codec.Json;
49+
import io.r2dbc.spi.ConnectionFactory;
4150
import reactor.core.publisher.Flux;
4251
import reactor.core.publisher.Mono;
4352
import reactor.test.StepVerifier;
4453

45-
import javax.sql.DataSource;
46-
import java.util.Collections;
47-
import java.util.Map;
48-
49-
import static org.assertj.core.api.Assertions.*;
50-
5154
/**
5255
* Integration tests for {@link LegoSetRepository} using {@link R2dbcRepositoryFactory} against Postgres.
5356
*
@@ -62,12 +65,14 @@ public class PostgresR2dbcRepositoryIntegrationTests extends AbstractR2dbcReposi
6265

6366
@Autowired WithJsonRepository withJsonRepository;
6467

68+
@Autowired WithIdFromSequenceRepository withIdFromSequenceRepository;
69+
6570
@Autowired WithHStoreRepository hstoreRepositoryWith;
6671

6772
@Configuration
6873
@EnableR2dbcRepositories(considerNestedRepositories = true,
6974
includeFilters = @Filter(
70-
classes = { PostgresLegoSetRepository.class, WithJsonRepository.class, WithHStoreRepository.class },
75+
classes = { PostgresLegoSetRepository.class, WithJsonRepository.class, WithHStoreRepository.class, WithIdFromSequenceRepository.class },
7176
type = FilterType.ASSIGNABLE_TYPE))
7277
static class IntegrationTestConfiguration extends AbstractR2dbcConfiguration {
7378

@@ -151,6 +156,51 @@ void shouldSaveAndLoadJson() {
151156
}).verifyComplete();
152157
}
153158

159+
@Test
160+
void shouldInsertWithAutoGeneratedId() {
161+
162+
JdbcTemplate template = new JdbcTemplate(createDataSource());
163+
164+
template.execute("DROP TABLE IF EXISTS with_id_from_sequence");
165+
template.execute("CREATE SEQUENCE IF NOT EXISTS target_sequence START WITH 15");
166+
template.execute("CREATE TABLE with_id_from_sequence(\n" //
167+
+ " id BIGINT PRIMARY KEY,\n" //
168+
+ " name TEXT NOT NULL" //
169+
+ ");");
170+
171+
WithIdFromSequence entity = new WithIdFromSequence(null, "Jordane");
172+
withIdFromSequenceRepository.save(entity).as(StepVerifier::create).expectNextCount(1).verifyComplete();
173+
174+
withIdFromSequenceRepository.findAll().as(StepVerifier::create).consumeNextWith(actual -> {
175+
176+
assertThat(actual.id).isNotNull().isEqualTo(15);
177+
assertThat(actual.name).isEqualTo("Jordane");
178+
}).verifyComplete();
179+
}
180+
181+
@Test
182+
void shouldUpdateNoIdGenerationHappens() {
183+
184+
JdbcTemplate template = new JdbcTemplate(createDataSource());
185+
186+
template.execute("DROP TABLE IF EXISTS with_id_from_sequence");
187+
template.execute("CREATE SEQUENCE IF NOT EXISTS target_sequence");
188+
template.execute("CREATE TABLE with_id_from_sequence(\n" //
189+
+ " id BIGINT PRIMARY KEY,\n" //
190+
+ " name TEXT NOT NULL" //
191+
+ ");");
192+
template.execute("INSERT INTO with_id_from_sequence VALUES(4, 'Alex');");
193+
194+
WithIdFromSequence entity = new WithIdFromSequence(4L, "NewName");
195+
withIdFromSequenceRepository.save(entity).as(StepVerifier::create).expectNextCount(1).verifyComplete();
196+
197+
withJsonRepository.findAll().as(StepVerifier::create).consumeNextWith(actual -> {
198+
199+
assertThat(actual.jsonValue).isNotNull().isEqualTo(4);
200+
assertThat(actual.jsonValue.asString()).isEqualTo("NewName");
201+
}).verifyComplete();
202+
}
203+
154204
@Test // gh-492
155205
void shouldSaveAndLoadHStore() {
156206

@@ -188,6 +238,24 @@ interface WithJsonRepository extends ReactiveCrudRepository<WithJson, Long> {
188238

189239
}
190240

241+
static class WithIdFromSequence {
242+
243+
@Id
244+
@Sequence(sequence = "target_sequence")
245+
Long id;
246+
247+
String name;
248+
249+
public WithIdFromSequence(Long id, String name) {
250+
this.id = id;
251+
this.name = name;
252+
}
253+
}
254+
255+
interface WithIdFromSequenceRepository extends ReactiveCrudRepository<WithIdFromSequence, Long> {
256+
257+
}
258+
191259
@Table("with_hstore")
192260
static class WithHStore {
193261

0 commit comments

Comments
 (0)