17
17
18
18
import io .r2dbc .spi .Row ;
19
19
import io .r2dbc .spi .RowMetadata ;
20
- import org .springframework .dao .OptimisticLockingFailureException ;
21
20
import reactor .core .publisher .Flux ;
22
21
import reactor .core .publisher .Mono ;
23
22
33
32
import org .springframework .beans .factory .BeanFactoryAware ;
34
33
import org .springframework .core .convert .ConversionService ;
35
34
import org .springframework .dao .DataAccessException ;
35
+ import org .springframework .dao .OptimisticLockingFailureException ;
36
36
import org .springframework .dao .TransientDataAccessResourceException ;
37
37
import org .springframework .data .mapping .IdentifierAccessor ;
38
38
import org .springframework .data .mapping .MappingException ;
@@ -377,7 +377,7 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
377
377
378
378
RelationalPersistentEntity <T > persistentEntity = getRequiredEntity (entity );
379
379
380
- setVersionIfNecessary (persistentEntity , entity );
380
+ setVersionIfNecessary (persistentEntity , entity );
381
381
382
382
return this .databaseClient .insert () //
383
383
.into (persistentEntity .getType ()) //
@@ -388,6 +388,7 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
388
388
}
389
389
390
390
private <T > void setVersionIfNecessary (RelationalPersistentEntity <T > persistentEntity , T entity ) {
391
+
391
392
RelationalPersistentProperty versionProperty = persistentEntity .getVersionProperty ();
392
393
if (versionProperty == null ) {
393
394
return ;
@@ -418,45 +419,37 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
418
419
419
420
DatabaseClient .UpdateSpec updateSpec = updateMatchingSpec ;
420
421
if (persistentEntity .hasVersionProperty ()) {
422
+
421
423
updateSpec = updateMatchingSpec .matching (createMatchingVersionCriteria (entity , persistentEntity ));
422
424
incrementVersion (entity , persistentEntity );
423
425
}
424
426
425
427
return updateSpec .fetch () //
426
428
.rowsUpdated () //
427
- .flatMap (rowsUpdated -> rowsUpdated == 0
428
- ? handleMissingUpdate (entity , persistentEntity ) : Mono .just (entity ));
429
+ .flatMap (rowsUpdated -> rowsUpdated == 0 ? handleMissingUpdate (entity , persistentEntity ) : Mono .just (entity ));
429
430
}
430
431
431
432
private <T > Mono <? extends T > handleMissingUpdate (T entity , RelationalPersistentEntity <T > persistentEntity ) {
432
- if (!persistentEntity .hasVersionProperty ()) {
433
- return Mono .error (new TransientDataAccessResourceException (
434
- formatTransientEntityExceptionMessage (entity , persistentEntity )));
435
- }
436
433
437
- return doCount (getByIdQuery (entity , persistentEntity ), entity .getClass (), persistentEntity .getTableName ())
438
- .map (count -> {
439
- if (count == 0 ) {
440
- throw new TransientDataAccessResourceException (
441
- formatTransientEntityExceptionMessage (entity , persistentEntity ));
442
- } else {
443
- throw new OptimisticLockingFailureException (
444
- formatOptimisticLockingExceptionMessage (entity , persistentEntity ));
445
- }
446
- });
434
+ return Mono .error (persistentEntity .hasVersionProperty ()
435
+ ? new OptimisticLockingFailureException (formatOptimisticLockingExceptionMessage (entity , persistentEntity ))
436
+ : new TransientDataAccessResourceException (formatTransientEntityExceptionMessage (entity , persistentEntity )));
447
437
}
448
438
449
439
private <T > String formatOptimisticLockingExceptionMessage (T entity , RelationalPersistentEntity <T > persistentEntity ) {
440
+
450
441
return String .format ("Failed to update table [%s]. Version does not match for row with Id [%s]." ,
451
442
persistentEntity .getTableName (), persistentEntity .getIdentifierAccessor (entity ).getIdentifier ());
452
443
}
453
444
454
445
private <T > String formatTransientEntityExceptionMessage (T entity , RelationalPersistentEntity <T > persistentEntity ) {
446
+
455
447
return String .format ("Failed to update table [%s]. Row with Id [%s] does not exist." ,
456
448
persistentEntity .getTableName (), persistentEntity .getIdentifierAccessor (entity ).getIdentifier ());
457
449
}
458
450
459
451
private <T > void incrementVersion (T entity , RelationalPersistentEntity <T > persistentEntity ) {
452
+
460
453
PersistentPropertyAccessor <?> propertyAccessor = persistentEntity .getPropertyAccessor (entity );
461
454
RelationalPersistentProperty versionProperty = persistentEntity .getVersionProperty ();
462
455
@@ -471,6 +464,7 @@ private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persis
471
464
}
472
465
473
466
private <T > Criteria createMatchingVersionCriteria (T entity , RelationalPersistentEntity <T > persistentEntity ) {
467
+
474
468
PersistentPropertyAccessor <?> propertyAccessor = persistentEntity .getPropertyAccessor (entity );
475
469
RelationalPersistentProperty versionProperty = persistentEntity .getVersionProperty ();
476
470
0 commit comments