40
40
import org .springframework .core .ReactiveAdapterRegistry ;
41
41
import org .springframework .core .ResolvableType ;
42
42
import org .springframework .core .task .SyncTaskExecutor ;
43
+ import org .springframework .core .task .TaskDecorator ;
43
44
import org .springframework .core .task .TaskExecutor ;
45
+ import org .springframework .core .task .support .ContextPropagatingTaskDecorator ;
44
46
import org .springframework .http .MediaType ;
45
47
import org .springframework .http .codec .ServerSentEvent ;
46
48
import org .springframework .http .server .ServerHttpResponse ;
@@ -91,18 +93,25 @@ class ReactiveTypeHandler {
91
93
92
94
private final ContentNegotiationManager contentNegotiationManager ;
93
95
96
+ private final ContextSnapshotFactory contextSnapshotFactory ;
97
+
94
98
95
99
public ReactiveTypeHandler () {
96
- this (ReactiveAdapterRegistry .getSharedInstance (), new SyncTaskExecutor (), new ContentNegotiationManager ());
100
+ this (ReactiveAdapterRegistry .getSharedInstance (), new SyncTaskExecutor (), new ContentNegotiationManager (), null );
97
101
}
98
102
99
- ReactiveTypeHandler (ReactiveAdapterRegistry registry , TaskExecutor executor , ContentNegotiationManager manager ) {
103
+ ReactiveTypeHandler (
104
+ ReactiveAdapterRegistry registry , TaskExecutor executor , ContentNegotiationManager manager ,
105
+ @ Nullable ContextSnapshotFactory contextSnapshotFactory ) {
106
+
100
107
Assert .notNull (registry , "ReactiveAdapterRegistry is required" );
101
108
Assert .notNull (executor , "TaskExecutor is required" );
102
109
Assert .notNull (manager , "ContentNegotiationManager is required" );
103
110
this .adapterRegistry = registry ;
104
111
this .taskExecutor = executor ;
105
112
this .contentNegotiationManager = manager ;
113
+ this .contextSnapshotFactory = (contextSnapshotFactory != null ?
114
+ contextSnapshotFactory : ContextSnapshotFactory .builder ().build ());
106
115
}
107
116
108
117
@@ -129,8 +138,10 @@ public ResponseBodyEmitter handleValue(Object returnValue, MethodParameter retur
129
138
ReactiveAdapter adapter = this .adapterRegistry .getAdapter (clazz );
130
139
Assert .state (adapter != null , () -> "Unexpected return value type: " + clazz );
131
140
141
+ TaskDecorator taskDecorator = null ;
132
142
if (isContextPropagationPresent ) {
133
- returnValue = ContextSnapshotHelper .writeReactorContext (returnValue );
143
+ returnValue = ContextSnapshotHelper .writeReactorContext (returnValue , this .contextSnapshotFactory );
144
+ taskDecorator = ContextSnapshotHelper .getTaskDecorator (this .contextSnapshotFactory );
134
145
}
135
146
136
147
ResolvableType elementType = ResolvableType .forMethodParameter (returnType ).getGeneric ();
@@ -143,7 +154,7 @@ public ResponseBodyEmitter handleValue(Object returnValue, MethodParameter retur
143
154
if (mediaTypes .stream ().anyMatch (MediaType .TEXT_EVENT_STREAM ::includes ) ||
144
155
ServerSentEvent .class .isAssignableFrom (elementClass )) {
145
156
SseEmitter emitter = new SseEmitter (STREAMING_TIMEOUT_VALUE );
146
- new SseEmitterSubscriber (emitter , this .taskExecutor ).connect (adapter , returnValue );
157
+ new SseEmitterSubscriber (emitter , this .taskExecutor , taskDecorator ).connect (adapter , returnValue );
147
158
return emitter ;
148
159
}
149
160
if (CharSequence .class .isAssignableFrom (elementClass )) {
@@ -247,9 +258,14 @@ private abstract static class AbstractEmitterSubscriber implements Subscriber<Ob
247
258
248
259
private volatile boolean done ;
249
260
250
- protected AbstractEmitterSubscriber (ResponseBodyEmitter emitter , TaskExecutor executor ) {
261
+ private final Runnable sendTask ;
262
+
263
+ protected AbstractEmitterSubscriber (
264
+ ResponseBodyEmitter emitter , TaskExecutor executor , @ Nullable TaskDecorator taskDecorator ) {
265
+
251
266
this .emitter = emitter ;
252
267
this .taskExecutor = executor ;
268
+ this .sendTask = (taskDecorator != null ? taskDecorator .decorate (this ) : this );
253
269
}
254
270
255
271
public void connect (ReactiveAdapter adapter , Object returnValue ) {
@@ -302,7 +318,7 @@ private void trySchedule() {
302
318
303
319
private void schedule () {
304
320
try {
305
- this .taskExecutor .execute (this );
321
+ this .taskExecutor .execute (this . sendTask );
306
322
}
307
323
catch (Throwable ex ) {
308
324
try {
@@ -380,8 +396,8 @@ private void terminate() {
380
396
381
397
private static class SseEmitterSubscriber extends AbstractEmitterSubscriber {
382
398
383
- SseEmitterSubscriber (SseEmitter sseEmitter , TaskExecutor executor ) {
384
- super (sseEmitter , executor );
399
+ SseEmitterSubscriber (SseEmitter sseEmitter , TaskExecutor executor , @ Nullable TaskDecorator taskDecorator ) {
400
+ super (sseEmitter , executor , taskDecorator );
385
401
}
386
402
387
403
@ Override
@@ -423,8 +439,10 @@ private SseEmitter.SseEventBuilder adapt(ServerSentEvent<?> sse) {
423
439
424
440
private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber {
425
441
426
- JsonEmitterSubscriber (ResponseBodyEmitter emitter , TaskExecutor executor ) {
427
- super (emitter , executor );
442
+ JsonEmitterSubscriber (
443
+ ResponseBodyEmitter emitter , TaskExecutor executor ) {
444
+
445
+ super (emitter , executor , null );
428
446
}
429
447
430
448
@ Override
@@ -438,7 +456,7 @@ protected void send(Object element) throws IOException {
438
456
private static class TextEmitterSubscriber extends AbstractEmitterSubscriber {
439
457
440
458
TextEmitterSubscriber (ResponseBodyEmitter emitter , TaskExecutor executor ) {
441
- super (emitter , executor );
459
+ super (emitter , executor , null );
442
460
}
443
461
444
462
@ Override
@@ -518,22 +536,24 @@ public ResolvableType getReturnType() {
518
536
519
537
private static class ContextSnapshotHelper {
520
538
521
- private static final ContextSnapshotFactory factory = ContextSnapshotFactory .builder ().build ();
522
-
523
539
@ SuppressWarnings ("ReactiveStreamsUnusedPublisher" )
524
- public static Object writeReactorContext (Object returnValue ) {
540
+ public static Object writeReactorContext (Object returnValue , ContextSnapshotFactory snapshotFactory ) {
525
541
if (Mono .class .isAssignableFrom (returnValue .getClass ())) {
526
- ContextSnapshot snapshot = factory .captureAll ();
542
+ ContextSnapshot snapshot = snapshotFactory .captureAll ();
527
543
return ((Mono <?>) returnValue ).contextWrite (snapshot ::updateContext );
528
544
}
529
545
else if (Flux .class .isAssignableFrom (returnValue .getClass ())) {
530
- ContextSnapshot snapshot = factory .captureAll ();
546
+ ContextSnapshot snapshot = snapshotFactory .captureAll ();
531
547
return ((Flux <?>) returnValue ).contextWrite (snapshot ::updateContext );
532
548
}
533
549
else {
534
550
return returnValue ;
535
551
}
536
552
}
553
+
554
+ public static TaskDecorator getTaskDecorator (ContextSnapshotFactory snapshotFactory ) {
555
+ return new ContextPropagatingTaskDecorator (snapshotFactory );
556
+ }
537
557
}
538
558
539
559
}
0 commit comments