37
37
import java .util .LinkedHashSet ;
38
38
import java .util .List ;
39
39
import java .util .Map ;
40
- import java .util .Set ;
41
40
import java .util .concurrent .atomic .AtomicBoolean ;
42
41
import java .util .function .BiFunction ;
43
42
import java .util .function .Function ;
49
48
import org .reactivestreams .Publisher ;
50
49
51
50
import org .springframework .dao .DataAccessException ;
51
+ import org .springframework .dao .InvalidDataAccessApiUsageException ;
52
52
import org .springframework .data .domain .Pageable ;
53
53
import org .springframework .data .domain .Sort ;
54
54
import org .springframework .data .r2dbc .UncategorizedR2dbcException ;
57
57
import org .springframework .data .r2dbc .function .connectionfactory .ConnectionProxy ;
58
58
import org .springframework .data .r2dbc .function .convert .ColumnMapRowMapper ;
59
59
import org .springframework .data .r2dbc .support .R2dbcExceptionTranslator ;
60
+ import org .springframework .data .relational .core .sql .Insert ;
60
61
import org .springframework .jdbc .core .SqlProvider ;
61
62
import org .springframework .lang .Nullable ;
62
63
import org .springframework .util .Assert ;
@@ -337,9 +338,17 @@ <T> FetchSpec<T> exchange(String sql, BiFunction<Row, RowMetadata, T> mappingFun
337
338
logger .debug ("Executing SQL statement [" + sql + "]" );
338
339
}
339
340
341
+ if (sqlSupplier instanceof PreparedOperation <?>) {
342
+ return ((PreparedOperation <?>) sqlSupplier ).bind (it .createStatement (sql ));
343
+ }
344
+
340
345
BindableOperation operation = namedParameters .expand (sql , dataAccessStrategy .getBindMarkersFactory (),
341
346
new MapBindParameterSource (byName ));
342
347
348
+ if (logger .isTraceEnabled ()) {
349
+ logger .trace ("Expanded SQL [" + operation .toQuery () + "]" );
350
+ }
351
+
343
352
Statement statement = it .createStatement (operation .toQuery ());
344
353
345
354
byName .forEach ((name , o ) -> {
@@ -367,6 +376,7 @@ <T> FetchSpec<T> exchange(String sql, BiFunction<Row, RowMetadata, T> mappingFun
367
376
368
377
public ExecuteSpecSupport bind (int index , Object value ) {
369
378
379
+ assertNotPreparedOperation ();
370
380
Assert .notNull (value , () -> String .format ("Value at index %d must not be null. Use bindNull(…) instead." , index ));
371
381
372
382
Map <Integer , SettableValue > byIndex = new LinkedHashMap <>(this .byIndex );
@@ -377,6 +387,8 @@ public ExecuteSpecSupport bind(int index, Object value) {
377
387
378
388
public ExecuteSpecSupport bindNull (int index , Class <?> type ) {
379
389
390
+ assertNotPreparedOperation ();
391
+
380
392
Map <Integer , SettableValue > byIndex = new LinkedHashMap <>(this .byIndex );
381
393
byIndex .put (index , SettableValue .empty (type ));
382
394
@@ -385,6 +397,8 @@ public ExecuteSpecSupport bindNull(int index, Class<?> type) {
385
397
386
398
public ExecuteSpecSupport bind (String name , Object value ) {
387
399
400
+ assertNotPreparedOperation ();
401
+
388
402
Assert .hasText (name , "Parameter name must not be null or empty!" );
389
403
Assert .notNull (value ,
390
404
() -> String .format ("Value for parameter %s must not be null. Use bindNull(…) instead." , name ));
@@ -397,6 +411,7 @@ public ExecuteSpecSupport bind(String name, Object value) {
397
411
398
412
public ExecuteSpecSupport bindNull (String name , Class <?> type ) {
399
413
414
+ assertNotPreparedOperation ();
400
415
Assert .hasText (name , "Parameter name must not be null or empty!" );
401
416
402
417
Map <String , SettableValue > byName = new LinkedHashMap <>(this .byName );
@@ -405,6 +420,12 @@ public ExecuteSpecSupport bindNull(String name, Class<?> type) {
405
420
return createInstance (this .byIndex , byName , this .sqlSupplier );
406
421
}
407
422
423
+ private void assertNotPreparedOperation () {
424
+ if (sqlSupplier instanceof PreparedOperation <?>) {
425
+ throw new InvalidDataAccessApiUsageException ("Cannot add bindings to a PreparedOperation" );
426
+ }
427
+ }
428
+
408
429
protected ExecuteSpecSupport createInstance (Map <Integer , SettableValue > byIndex , Map <String , SettableValue > byName ,
409
430
Supplier <String > sqlSupplier ) {
410
431
return new ExecuteSpecSupport (byIndex , byName , sqlSupplier );
@@ -882,20 +903,19 @@ private <R> FetchSpec<R> exchange(BiFunction<Row, RowMetadata, R> mappingFunctio
882
903
throw new IllegalStateException ("Insert fields is empty!" );
883
904
}
884
905
885
- BindableOperation bindableInsert = dataAccessStrategy .insertAndReturnGeneratedKeys (table , byName .keySet ());
906
+ PreparedOperation <Insert > operation = dataAccessStrategy .getStatements ().insert (table , Collections .emptyList (),
907
+ it -> {
908
+ byName .forEach (it ::bind );
909
+ });
886
910
887
- String sql = bindableInsert .toQuery ();
911
+ String sql = operation .toQuery ();
888
912
Function <Connection , Statement > insertFunction = it -> {
889
913
890
914
if (logger .isDebugEnabled ()) {
891
915
logger .debug ("Executing SQL statement [" + sql + "]" );
892
916
}
893
917
894
- Statement statement = it .createStatement (sql ).returnGeneratedValues ();
895
-
896
- byName .forEach ((k , v ) -> bindableInsert .bind (statement , k , v ));
897
-
898
- return statement ;
918
+ return operation .bind (it .createStatement (sql ));
899
919
};
900
920
901
921
Function <Connection , Flux <Result >> resultFunction = it -> Flux .from (insertFunction .apply (it ).execute ());
@@ -999,34 +1019,25 @@ private <MR> FetchSpec<MR> exchange(Object toInsert, BiFunction<Row, RowMetadata
999
1019
1000
1020
OutboundRow outboundRow = dataAccessStrategy .getOutboundRow (toInsert );
1001
1021
1002
- Set <String > columns = new LinkedHashSet <>();
1003
-
1004
- outboundRow .forEach ((k , v ) -> {
1005
-
1006
- if (v .hasValue ()) {
1007
- columns .add (k );
1008
- }
1009
- });
1022
+ PreparedOperation <Insert > operation = dataAccessStrategy .getStatements ().insert (table , Collections .emptyList (),
1023
+ it -> {
1024
+ outboundRow .forEach ((k , v ) -> {
1010
1025
1011
- BindableOperation bindableInsert = dataAccessStrategy .insertAndReturnGeneratedKeys (table , columns );
1026
+ if (v .hasValue ()) {
1027
+ it .bind (k , v );
1028
+ }
1029
+ });
1030
+ });
1012
1031
1013
- String sql = bindableInsert .toQuery ();
1032
+ String sql = operation .toQuery ();
1014
1033
1015
1034
Function <Connection , Statement > insertFunction = it -> {
1016
1035
1017
1036
if (logger .isDebugEnabled ()) {
1018
1037
logger .debug ("Executing SQL statement [" + sql + "]" );
1019
1038
}
1020
1039
1021
- Statement statement = it .createStatement (sql ).returnGeneratedValues ();
1022
-
1023
- outboundRow .forEach ((k , v ) -> {
1024
- if (v .hasValue ()) {
1025
- bindableInsert .bind (statement , k , v );
1026
- }
1027
- });
1028
-
1029
- return statement ;
1040
+ return operation .bind (it .createStatement (sql ));
1030
1041
};
1031
1042
1032
1043
Function <Connection , Flux <Result >> resultFunction = it -> Flux .from (insertFunction .apply (it ).execute ());
0 commit comments