17
17
18
18
import io .r2dbc .spi .Connection ;
19
19
import io .r2dbc .spi .Statement ;
20
+ import kotlin .internal .LowPriorityInOverloadResolution ;
20
21
import lombok .Getter ;
21
22
import lombok .RequiredArgsConstructor ;
22
23
29
30
import java .util .function .BiConsumer ;
30
31
import java .util .function .BiFunction ;
31
32
import java .util .function .Consumer ;
33
+ import java .util .function .Function ;
32
34
33
35
import org .springframework .dao .InvalidDataAccessApiUsageException ;
34
36
import org .springframework .data .r2dbc .dialect .BindMarker ;
@@ -362,6 +364,7 @@ private static Condition toCondition(BindMarkers bindMarkers, Column column, Set
362
364
}
363
365
}
364
366
367
+
365
368
/**
366
369
* Value object holding value and {@code NULL} bindings.
367
370
*
@@ -412,13 +415,75 @@ void apply(Statement to) {
412
415
}
413
416
}
414
417
418
+
419
+ static abstract class PreparedOperationSupport <T > implements PreparedOperation <T > {
420
+
421
+ private Function <String , String > sqlFilter = s -> s ;
422
+ private Function <Binding , Binding > bindingFilter = b -> b ;
423
+
424
+
425
+ abstract protected String createBaseSql ();
426
+
427
+ protected abstract Binding getBaseBinding ();
428
+
429
+ /*
430
+ * (non-Javadoc)
431
+ * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery()
432
+ */
433
+ @ Override
434
+ public String toQuery () {
435
+
436
+ return sqlFilter .apply (createBaseSql ());
437
+ }
438
+ /*
439
+ * (non-Javadoc)
440
+ * @see org.springframework.data.r2dbc.function.PreparedOperation#bind(io.r2dbc.spi.Statement)
441
+ */
442
+ protected Statement bind (Statement to ) {
443
+
444
+ bindingFilter .apply (getBaseBinding ()).apply (to );
445
+ return to ;
446
+ }
447
+
448
+
449
+ @ Override
450
+ public Statement createBoundStatement (Connection connection ) {
451
+
452
+ // TODO add back logging
453
+ // if (logger.isDebugEnabled()) {
454
+ // logger.debug("Executing SQL statement [" + sql + "]");
455
+ // }
456
+
457
+ return bind (connection .createStatement (toQuery ()));
458
+ }
459
+
460
+ @ Override
461
+ public void addSqlFilter (Function <String , String > filter ) {
462
+
463
+ Assert .notNull (filter , "Filter must not be null." );
464
+
465
+ sqlFilter = filter ;
466
+
467
+ }
468
+
469
+ @ Override
470
+ public void addBindingFilter (Function <Binding , Binding > filter ) {
471
+
472
+ Assert .notNull (filter , "Filter must not be null." );
473
+
474
+ bindingFilter = filter ;
475
+ }
476
+
477
+ }
478
+
479
+
415
480
/**
416
481
* Default implementation of {@link PreparedOperation}.
417
482
*
418
483
* @param <T>
419
484
*/
420
485
@ RequiredArgsConstructor
421
- static class DefaultPreparedOperation <T > implements PreparedOperation <T > {
486
+ static class DefaultPreparedOperation <T > extends PreparedOperationSupport <T > {
422
487
423
488
private final T source ;
424
489
private final RenderContext renderContext ;
@@ -433,13 +498,8 @@ public T getSource() {
433
498
return this .source ;
434
499
}
435
500
436
- /*
437
- * (non-Javadoc)
438
- * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery()
439
- */
440
501
@ Override
441
- public String toQuery () {
442
-
502
+ protected String createBaseSql () {
443
503
SqlRenderer sqlRenderer = SqlRenderer .create (renderContext );
444
504
445
505
if (this .source instanceof Select ) {
@@ -461,25 +521,9 @@ public String toQuery() {
461
521
throw new IllegalStateException ("Cannot render " + this .getSource ());
462
522
}
463
523
464
- /*
465
- * (non-Javadoc)
466
- * @see org.springframework.data.r2dbc.function.PreparedOperation#bind(io.r2dbc.spi.Statement)
467
- */
468
- protected Statement bind (Statement to ) {
469
-
470
- binding .apply (to );
471
- return to ;
472
- }
473
-
474
524
@ Override
475
- public Statement createBoundStatement (Connection connection ) {
476
-
477
- // TODO add back logging
478
- // if (logger.isDebugEnabled()) {
479
- // logger.debug("Executing SQL statement [" + sql + "]");
480
- // }
481
-
482
- return bind (connection .createStatement (toQuery ()));
525
+ protected Binding getBaseBinding () {
526
+ return binding ;
483
527
}
484
528
}
485
529
}
0 commit comments