37
37
import org .springframework .data .mapping .context .MappingContext ;
38
38
import org .springframework .data .projection .ProjectionInformation ;
39
39
import org .springframework .data .projection .SpelAwareProxyProjectionFactory ;
40
- import org .springframework .data .r2dbc .mapping .R2dbcMappingContext ;
41
40
import org .springframework .data .r2dbc .query .Criteria ;
42
41
import org .springframework .data .r2dbc .query .Query ;
43
42
import org .springframework .data .r2dbc .query .Update ;
44
43
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
45
44
import org .springframework .data .relational .core .mapping .RelationalPersistentProperty ;
45
+ import org .springframework .data .relational .core .sql .Expression ;
46
46
import org .springframework .data .relational .core .sql .Functions ;
47
+ import org .springframework .data .relational .core .sql .SqlIdentifier ;
48
+ import org .springframework .data .relational .core .sql .Table ;
47
49
import org .springframework .data .util .ProxyUtils ;
48
50
import org .springframework .util .Assert ;
49
51
@@ -74,13 +76,7 @@ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAw
74
76
* @param databaseClient
75
77
*/
76
78
public R2dbcEntityTemplate (DatabaseClient databaseClient ) {
77
-
78
- Assert .notNull (databaseClient , "DatabaseClient must not be null" );
79
-
80
- this .databaseClient = databaseClient ;
81
- this .dataAccessStrategy = getDataAccessStrategy (databaseClient );
82
- this .mappingContext = getMappingContext (this .dataAccessStrategy );
83
- this .projectionFactory = new SpelAwareProxyProjectionFactory ();
79
+ this (databaseClient , getDataAccessStrategy (databaseClient ));
84
80
}
85
81
86
82
/**
@@ -174,7 +170,7 @@ public Mono<Long> count(Query query, Class<?> entityClass) throws DataAccessExce
174
170
return doCount (query , entityClass , getTableName (entityClass ));
175
171
}
176
172
177
- Mono <Long > doCount (Query query , Class <?> entityClass , String tableName ) {
173
+ Mono <Long > doCount (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
178
174
179
175
RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
180
176
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
@@ -211,16 +207,18 @@ public Mono<Boolean> exists(Query query, Class<?> entityClass) throws DataAccess
211
207
return doExists (query , entityClass , getTableName (entityClass ));
212
208
}
213
209
214
- Mono <Boolean > doExists (Query query , Class <?> entityClass , String tableName ) {
210
+ Mono <Boolean > doExists (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
215
211
216
212
RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
217
213
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
218
214
219
- String columnName = entity .hasIdProperty () ? entity .getRequiredIdProperty ().getColumnName () : "*" ;
215
+ SqlIdentifier columnName = entity .hasIdProperty () ? entity .getRequiredIdProperty ().getColumnName ()
216
+ : SqlIdentifier .unquoted ("*" );
220
217
221
218
StatementMapper .SelectSpec selectSpec = statementMapper //
222
219
.createSelect (tableName ) //
223
- .withProjection (columnName );
220
+ .withProjection (columnName ) //
221
+ .limit (1 );
224
222
225
223
Optional <Criteria > criteria = query .getCriteria ();
226
224
if (criteria .isPresent ()) {
@@ -248,14 +246,13 @@ public <T> Flux<T> select(Query query, Class<T> entityClass) throws DataAccessEx
248
246
return doSelect (query , entityClass , getTableName (entityClass ), entityClass ).all ();
249
247
}
250
248
251
- <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , String tableName , Class <T > returnType ) {
249
+ <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , SqlIdentifier tableName , Class <T > returnType ) {
252
250
253
- RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
254
251
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
255
252
256
253
StatementMapper .SelectSpec selectSpec = statementMapper //
257
254
.createSelect (tableName ) //
258
- .withProjection (getSelectProjection (query , returnType ));
255
+ .doWithTable (( table , spec ) -> spec . withProjection (getSelectProjection (table , query , returnType ) ));
259
256
260
257
if (query .getLimit () > 0 ) {
261
258
selectSpec = selectSpec .limit (query .getLimit ());
@@ -310,7 +307,7 @@ public Mono<Integer> update(Query query, Update update, Class<?> entityClass) th
310
307
return doUpdate (query , update , entityClass , getTableName (entityClass ));
311
308
}
312
309
313
- Mono <Integer > doUpdate (Query query , Update update , Class <?> entityClass , String tableName ) {
310
+ Mono <Integer > doUpdate (Query query , Update update , Class <?> entityClass , SqlIdentifier tableName ) {
314
311
315
312
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
316
313
@@ -339,7 +336,7 @@ public Mono<Integer> delete(Query query, Class<?> entityClass) throws DataAccess
339
336
return doDelete (query , entityClass , getTableName (entityClass ));
340
337
}
341
338
342
- Mono <Integer > doDelete (Query query , Class <?> entityClass , String tableName ) {
339
+ Mono <Integer > doDelete (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
343
340
344
341
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
345
342
@@ -371,7 +368,7 @@ public <T> Mono<T> insert(T entity) throws DataAccessException {
371
368
return doInsert (entity , getRequiredEntity (entity ).getTableName ());
372
369
}
373
370
374
- <T > Mono <T > doInsert (T entity , String tableName ) {
371
+ <T > Mono <T > doInsert (T entity , SqlIdentifier tableName ) {
375
372
376
373
RelationalPersistentEntity <T > persistentEntity = getRequiredEntity (entity );
377
374
@@ -434,7 +431,7 @@ private <T> Query getByIdQuery(T entity, RelationalPersistentEntity<?> persisten
434
431
return Query .query (Criteria .where (persistentEntity .getRequiredIdProperty ().getName ()).is (id ));
435
432
}
436
433
437
- String getTableName (Class <?> entityClass ) {
434
+ SqlIdentifier getTableName (Class <?> entityClass ) {
438
435
return getRequiredEntity (entityClass ).getTableName ();
439
436
}
440
437
@@ -447,7 +444,7 @@ private <T> RelationalPersistentEntity<T> getRequiredEntity(T entity) {
447
444
return (RelationalPersistentEntity ) getRequiredEntity (entityType );
448
445
}
449
446
450
- private <T > List <String > getSelectProjection (Query query , Class <T > returnType ) {
447
+ private <T > List <Expression > getSelectProjection (Table table , Query query , Class <T > returnType ) {
451
448
452
449
if (query .getColumns ().isEmpty ()) {
453
450
@@ -456,19 +453,21 @@ private <T> List<String> getSelectProjection(Query query, Class<T> returnType) {
456
453
ProjectionInformation projectionInformation = projectionFactory .getProjectionInformation (returnType );
457
454
458
455
if (projectionInformation .isClosed ()) {
459
- return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName )
456
+ return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName ). map ( table :: column )
460
457
.collect (Collectors .toList ());
461
458
}
462
459
}
463
460
464
- return Collections .singletonList ("*" );
461
+ return Collections .singletonList (table . asterisk () );
465
462
}
466
463
467
- return query .getColumns ();
464
+ return query .getColumns (). stream (). map ( table :: column ). collect ( Collectors . toList ()) ;
468
465
}
469
466
470
467
private static ReactiveDataAccessStrategy getDataAccessStrategy (DatabaseClient databaseClient ) {
471
468
469
+ Assert .notNull (databaseClient , "DatabaseClient must not be null" );
470
+
472
471
if (databaseClient instanceof DefaultDatabaseClient ) {
473
472
474
473
DefaultDatabaseClient client = (DefaultDatabaseClient ) databaseClient ;
@@ -478,14 +477,4 @@ private static ReactiveDataAccessStrategy getDataAccessStrategy(DatabaseClient d
478
477
throw new IllegalStateException ("Cannot obtain ReactiveDataAccessStrategy" );
479
478
}
480
479
481
- private static MappingContext <? extends RelationalPersistentEntity <?>, ? extends RelationalPersistentProperty > getMappingContext (
482
- ReactiveDataAccessStrategy strategy ) {
483
-
484
- if (strategy instanceof DefaultReactiveDataAccessStrategy ) {
485
- DefaultReactiveDataAccessStrategy strategy1 = (DefaultReactiveDataAccessStrategy ) strategy ;
486
- return strategy1 .getMappingContext ();
487
- }
488
- return new R2dbcMappingContext ();
489
- }
490
-
491
480
}
0 commit comments