43
43
import org .springframework .data .r2dbc .query .Update ;
44
44
import org .springframework .data .relational .core .mapping .RelationalPersistentEntity ;
45
45
import org .springframework .data .relational .core .mapping .RelationalPersistentProperty ;
46
+ import org .springframework .data .relational .core .sql .Expression ;
46
47
import org .springframework .data .relational .core .sql .Functions ;
48
+ import org .springframework .data .relational .core .sql .SqlIdentifier ;
49
+ import org .springframework .data .relational .core .sql .Table ;
47
50
import org .springframework .data .util .ProxyUtils ;
48
51
import org .springframework .util .Assert ;
49
52
@@ -174,7 +177,7 @@ public Mono<Long> count(Query query, Class<?> entityClass) throws DataAccessExce
174
177
return doCount (query , entityClass , getTableName (entityClass ));
175
178
}
176
179
177
- Mono <Long > doCount (Query query , Class <?> entityClass , String tableName ) {
180
+ Mono <Long > doCount (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
178
181
179
182
RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
180
183
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
@@ -211,12 +214,13 @@ public Mono<Boolean> exists(Query query, Class<?> entityClass) throws DataAccess
211
214
return doExists (query , entityClass , getTableName (entityClass ));
212
215
}
213
216
214
- Mono <Boolean > doExists (Query query , Class <?> entityClass , String tableName ) {
217
+ Mono <Boolean > doExists (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
215
218
216
219
RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
217
220
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
218
221
219
- String columnName = entity .hasIdProperty () ? entity .getRequiredIdProperty ().getColumnName () : "*" ;
222
+ SqlIdentifier columnName = entity .hasIdProperty () ? entity .getRequiredIdProperty ().getColumnName ()
223
+ : SqlIdentifier .unquoted ("*" );
220
224
221
225
StatementMapper .SelectSpec selectSpec = statementMapper //
222
226
.createSelect (tableName ) //
@@ -248,14 +252,13 @@ public <T> Flux<T> select(Query query, Class<T> entityClass) throws DataAccessEx
248
252
return doSelect (query , entityClass , getTableName (entityClass ), entityClass ).all ();
249
253
}
250
254
251
- <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , String tableName , Class <T > returnType ) {
255
+ <T > RowsFetchSpec <T > doSelect (Query query , Class <?> entityClass , SqlIdentifier tableName , Class <T > returnType ) {
252
256
253
- RelationalPersistentEntity <?> entity = getRequiredEntity (entityClass );
254
257
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
255
258
256
259
StatementMapper .SelectSpec selectSpec = statementMapper //
257
260
.createSelect (tableName ) //
258
- .withProjection (getSelectProjection (query , returnType ));
261
+ .doWithTable (( table , spec ) -> spec . withProjection (getSelectProjection (table , query , returnType ) ));
259
262
260
263
if (query .getLimit () > 0 ) {
261
264
selectSpec = selectSpec .limit (query .getLimit ());
@@ -310,7 +313,7 @@ public Mono<Integer> update(Query query, Update update, Class<?> entityClass) th
310
313
return doUpdate (query , update , entityClass , getTableName (entityClass ));
311
314
}
312
315
313
- Mono <Integer > doUpdate (Query query , Update update , Class <?> entityClass , String tableName ) {
316
+ Mono <Integer > doUpdate (Query query , Update update , Class <?> entityClass , SqlIdentifier tableName ) {
314
317
315
318
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
316
319
@@ -339,7 +342,7 @@ public Mono<Integer> delete(Query query, Class<?> entityClass) throws DataAccess
339
342
return doDelete (query , entityClass , getTableName (entityClass ));
340
343
}
341
344
342
- Mono <Integer > doDelete (Query query , Class <?> entityClass , String tableName ) {
345
+ Mono <Integer > doDelete (Query query , Class <?> entityClass , SqlIdentifier tableName ) {
343
346
344
347
StatementMapper statementMapper = dataAccessStrategy .getStatementMapper ().forType (entityClass );
345
348
@@ -371,7 +374,7 @@ public <T> Mono<T> insert(T entity) throws DataAccessException {
371
374
return doInsert (entity , getRequiredEntity (entity ).getTableName ());
372
375
}
373
376
374
- <T > Mono <T > doInsert (T entity , String tableName ) {
377
+ <T > Mono <T > doInsert (T entity , SqlIdentifier tableName ) {
375
378
376
379
RelationalPersistentEntity <T > persistentEntity = getRequiredEntity (entity );
377
380
@@ -434,7 +437,7 @@ private <T> Query getByIdQuery(T entity, RelationalPersistentEntity<?> persisten
434
437
return Query .query (Criteria .where (persistentEntity .getRequiredIdProperty ().getName ()).is (id ));
435
438
}
436
439
437
- String getTableName (Class <?> entityClass ) {
440
+ SqlIdentifier getTableName (Class <?> entityClass ) {
438
441
return getRequiredEntity (entityClass ).getTableName ();
439
442
}
440
443
@@ -447,7 +450,7 @@ private <T> RelationalPersistentEntity<T> getRequiredEntity(T entity) {
447
450
return (RelationalPersistentEntity ) getRequiredEntity (entityType );
448
451
}
449
452
450
- private <T > List <String > getSelectProjection (Query query , Class <T > returnType ) {
453
+ private <T > List <Expression > getSelectProjection (Table table , Query query , Class <T > returnType ) {
451
454
452
455
if (query .getColumns ().isEmpty ()) {
453
456
@@ -456,15 +459,15 @@ private <T> List<String> getSelectProjection(Query query, Class<T> returnType) {
456
459
ProjectionInformation projectionInformation = projectionFactory .getProjectionInformation (returnType );
457
460
458
461
if (projectionInformation .isClosed ()) {
459
- return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName )
462
+ return projectionInformation .getInputProperties ().stream ().map (FeatureDescriptor ::getName ). map ( table :: column )
460
463
.collect (Collectors .toList ());
461
464
}
462
465
}
463
466
464
- return Collections .singletonList ("*" );
467
+ return Collections .singletonList (table . asterisk () );
465
468
}
466
469
467
- return query .getColumns ();
470
+ return query .getColumns (). stream (). map ( table :: column ). collect ( Collectors . toList ()) ;
468
471
}
469
472
470
473
private static ReactiveDataAccessStrategy getDataAccessStrategy (DatabaseClient databaseClient ) {
0 commit comments