Skip to content

Commit 6c2f602

Browse files
committed
Propagate context to send for SSE Flux
Closes gh-32813
1 parent c1250b1 commit 6c2f602

File tree

3 files changed

+135
-20
lines changed

3 files changed

+135
-20
lines changed

spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandler.java

+36-16
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@
4040
import org.springframework.core.ReactiveAdapterRegistry;
4141
import org.springframework.core.ResolvableType;
4242
import org.springframework.core.task.SyncTaskExecutor;
43+
import org.springframework.core.task.TaskDecorator;
4344
import org.springframework.core.task.TaskExecutor;
45+
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
4446
import org.springframework.http.MediaType;
4547
import org.springframework.http.codec.ServerSentEvent;
4648
import org.springframework.http.server.ServerHttpResponse;
@@ -91,18 +93,25 @@ class ReactiveTypeHandler {
9193

9294
private final ContentNegotiationManager contentNegotiationManager;
9395

96+
private final ContextSnapshotFactory contextSnapshotFactory;
97+
9498

9599
public ReactiveTypeHandler() {
96-
this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager());
100+
this(ReactiveAdapterRegistry.getSharedInstance(), new SyncTaskExecutor(), new ContentNegotiationManager(), null);
97101
}
98102

99-
ReactiveTypeHandler(ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager) {
103+
ReactiveTypeHandler(
104+
ReactiveAdapterRegistry registry, TaskExecutor executor, ContentNegotiationManager manager,
105+
@Nullable ContextSnapshotFactory contextSnapshotFactory) {
106+
100107
Assert.notNull(registry, "ReactiveAdapterRegistry is required");
101108
Assert.notNull(executor, "TaskExecutor is required");
102109
Assert.notNull(manager, "ContentNegotiationManager is required");
103110
this.adapterRegistry = registry;
104111
this.taskExecutor = executor;
105112
this.contentNegotiationManager = manager;
113+
this.contextSnapshotFactory = (contextSnapshotFactory != null ?
114+
contextSnapshotFactory : ContextSnapshotFactory.builder().build());
106115
}
107116

108117

@@ -129,8 +138,10 @@ public ResponseBodyEmitter handleValue(Object returnValue, MethodParameter retur
129138
ReactiveAdapter adapter = this.adapterRegistry.getAdapter(clazz);
130139
Assert.state(adapter != null, () -> "Unexpected return value type: " + clazz);
131140

141+
TaskDecorator taskDecorator = null;
132142
if (isContextPropagationPresent) {
133-
returnValue = ContextSnapshotHelper.writeReactorContext(returnValue);
143+
returnValue = ContextSnapshotHelper.writeReactorContext(returnValue, this.contextSnapshotFactory);
144+
taskDecorator = ContextSnapshotHelper.getTaskDecorator(this.contextSnapshotFactory);
134145
}
135146

136147
ResolvableType elementType = ResolvableType.forMethodParameter(returnType).getGeneric();
@@ -143,7 +154,7 @@ public ResponseBodyEmitter handleValue(Object returnValue, MethodParameter retur
143154
if (mediaTypes.stream().anyMatch(MediaType.TEXT_EVENT_STREAM::includes) ||
144155
ServerSentEvent.class.isAssignableFrom(elementClass)) {
145156
SseEmitter emitter = new SseEmitter(STREAMING_TIMEOUT_VALUE);
146-
new SseEmitterSubscriber(emitter, this.taskExecutor).connect(adapter, returnValue);
157+
new SseEmitterSubscriber(emitter, this.taskExecutor, taskDecorator).connect(adapter, returnValue);
147158
return emitter;
148159
}
149160
if (CharSequence.class.isAssignableFrom(elementClass)) {
@@ -247,9 +258,14 @@ private abstract static class AbstractEmitterSubscriber implements Subscriber<Ob
247258

248259
private volatile boolean done;
249260

250-
protected AbstractEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
261+
private final Runnable sendTask;
262+
263+
protected AbstractEmitterSubscriber(
264+
ResponseBodyEmitter emitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) {
265+
251266
this.emitter = emitter;
252267
this.taskExecutor = executor;
268+
this.sendTask = (taskDecorator != null ? taskDecorator.decorate(this) : this);
253269
}
254270

255271
public void connect(ReactiveAdapter adapter, Object returnValue) {
@@ -302,7 +318,7 @@ private void trySchedule() {
302318

303319
private void schedule() {
304320
try {
305-
this.taskExecutor.execute(this);
321+
this.taskExecutor.execute(this.sendTask);
306322
}
307323
catch (Throwable ex) {
308324
try {
@@ -380,8 +396,8 @@ private void terminate() {
380396

381397
private static class SseEmitterSubscriber extends AbstractEmitterSubscriber {
382398

383-
SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor) {
384-
super(sseEmitter, executor);
399+
SseEmitterSubscriber(SseEmitter sseEmitter, TaskExecutor executor, @Nullable TaskDecorator taskDecorator) {
400+
super(sseEmitter, executor, taskDecorator);
385401
}
386402

387403
@Override
@@ -423,8 +439,10 @@ private SseEmitter.SseEventBuilder adapt(ServerSentEvent<?> sse) {
423439

424440
private static class JsonEmitterSubscriber extends AbstractEmitterSubscriber {
425441

426-
JsonEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
427-
super(emitter, executor);
442+
JsonEmitterSubscriber(
443+
ResponseBodyEmitter emitter, TaskExecutor executor) {
444+
445+
super(emitter, executor, null);
428446
}
429447

430448
@Override
@@ -438,7 +456,7 @@ protected void send(Object element) throws IOException {
438456
private static class TextEmitterSubscriber extends AbstractEmitterSubscriber {
439457

440458
TextEmitterSubscriber(ResponseBodyEmitter emitter, TaskExecutor executor) {
441-
super(emitter, executor);
459+
super(emitter, executor, null);
442460
}
443461

444462
@Override
@@ -518,22 +536,24 @@ public ResolvableType getReturnType() {
518536

519537
private static class ContextSnapshotHelper {
520538

521-
private static final ContextSnapshotFactory factory = ContextSnapshotFactory.builder().build();
522-
523539
@SuppressWarnings("ReactiveStreamsUnusedPublisher")
524-
public static Object writeReactorContext(Object returnValue) {
540+
public static Object writeReactorContext(Object returnValue, ContextSnapshotFactory snapshotFactory) {
525541
if (Mono.class.isAssignableFrom(returnValue.getClass())) {
526-
ContextSnapshot snapshot = factory.captureAll();
542+
ContextSnapshot snapshot = snapshotFactory.captureAll();
527543
return ((Mono<?>) returnValue).contextWrite(snapshot::updateContext);
528544
}
529545
else if (Flux.class.isAssignableFrom(returnValue.getClass())) {
530-
ContextSnapshot snapshot = factory.captureAll();
546+
ContextSnapshot snapshot = snapshotFactory.captureAll();
531547
return ((Flux<?>) returnValue).contextWrite(snapshot::updateContext);
532548
}
533549
else {
534550
return returnValue;
535551
}
536552
}
553+
554+
public static TaskDecorator getTaskDecorator(ContextSnapshotFactory snapshotFactory) {
555+
return new ContextPropagatingTaskDecorator(snapshotFactory);
556+
}
537557
}
538558

539559
}

spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -91,7 +91,7 @@ public ResponseBodyEmitterReturnValueHandler(List<HttpMessageConverter<?>> messa
9191

9292
Assert.notEmpty(messageConverters, "HttpMessageConverter List must not be empty");
9393
this.sseMessageConverters = initSseConverters(messageConverters);
94-
this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager);
94+
this.reactiveHandler = new ReactiveTypeHandler(registry, executor, manager, null);
9595
}
9696

9797
private static List<HttpMessageConverter<?>> initSseConverters(List<HttpMessageConverter<?>> converters) {

spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ReactiveTypeHandlerTests.java

+97-2
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,40 @@
2222
import java.util.Collections;
2323
import java.util.List;
2424
import java.util.Set;
25+
import java.util.concurrent.atomic.AtomicInteger;
2526
import java.util.concurrent.atomic.AtomicReference;
2627
import java.util.function.Consumer;
2728
import java.util.stream.Collectors;
2829

30+
import io.micrometer.context.ContextRegistry;
31+
import io.micrometer.context.ContextSnapshotFactory;
2932
import io.reactivex.rxjava3.core.Single;
3033
import io.reactivex.rxjava3.core.SingleEmitter;
34+
import jakarta.servlet.http.HttpServletRequest;
3135
import org.junit.jupiter.api.BeforeEach;
3236
import org.junit.jupiter.api.Test;
3337
import reactor.core.publisher.Flux;
3438
import reactor.core.publisher.Mono;
3539
import reactor.core.publisher.Sinks;
40+
import reactor.util.context.ReactorContextAccessor;
3641

3742
import org.springframework.core.MethodParameter;
3843
import org.springframework.core.ReactiveAdapterRegistry;
3944
import org.springframework.core.ResolvableType;
45+
import org.springframework.core.task.SimpleAsyncTaskExecutor;
4046
import org.springframework.core.task.SyncTaskExecutor;
47+
import org.springframework.core.task.TaskExecutor;
4148
import org.springframework.http.MediaType;
4249
import org.springframework.http.codec.ServerSentEvent;
4350
import org.springframework.http.server.ServletServerHttpResponse;
51+
import org.springframework.lang.Nullable;
4452
import org.springframework.web.accept.ContentNegotiationManager;
4553
import org.springframework.web.accept.ContentNegotiationManagerFactoryBean;
4654
import org.springframework.web.context.request.NativeWebRequest;
55+
import org.springframework.web.context.request.RequestAttributes;
56+
import org.springframework.web.context.request.RequestAttributesThreadLocalAccessor;
57+
import org.springframework.web.context.request.RequestContextHolder;
58+
import org.springframework.web.context.request.ServletRequestAttributes;
4759
import org.springframework.web.context.request.ServletWebRequest;
4860
import org.springframework.web.context.request.async.AsyncWebRequest;
4961
import org.springframework.web.context.request.async.StandardServletAsyncWebRequest;
@@ -75,12 +87,18 @@ class ReactiveTypeHandlerTests {
7587

7688
@BeforeEach
7789
void setup() throws Exception {
90+
this.handler = initHandler(new SyncTaskExecutor(), null);
91+
resetRequest();
92+
}
93+
94+
private static ReactiveTypeHandler initHandler(
95+
TaskExecutor taskExecutor, @Nullable ContextSnapshotFactory snapshotFactory) {
96+
7897
ContentNegotiationManagerFactoryBean factoryBean = new ContentNegotiationManagerFactoryBean();
7998
factoryBean.afterPropertiesSet();
8099
ContentNegotiationManager manager = factoryBean.getObject();
81100
ReactiveAdapterRegistry adapterRegistry = ReactiveAdapterRegistry.getSharedInstance();
82-
this.handler = new ReactiveTypeHandler(adapterRegistry, new SyncTaskExecutor(), manager);
83-
resetRequest();
101+
return new ReactiveTypeHandler(adapterRegistry, taskExecutor, manager, snapshotFactory);
84102
}
85103

86104
private void resetRequest() {
@@ -414,6 +432,42 @@ void writeFluxOfString() throws Exception {
414432
testEmitterContentType("application/json");
415433
}
416434

435+
@Test
436+
void contextPropagation() throws Exception {
437+
438+
ContextRegistry registry = new ContextRegistry();
439+
registry.registerThreadLocalAccessor(new RequestAttributesThreadLocalAccessor());
440+
registry.registerContextAccessor(new ReactorContextAccessor());
441+
ContextSnapshotFactory snapshotFactory = ContextSnapshotFactory.builder().contextRegistry(registry).build();
442+
443+
ModelAndViewContainer mavContainer = new ModelAndViewContainer();
444+
MethodParameter returnType = on(TestController.class).resolveReturnType(Flux.class, forClass(String.class));
445+
ReactiveTypeHandler handler = initHandler(new SimpleAsyncTaskExecutor(), snapshotFactory);
446+
447+
this.servletRequest.addHeader("Accept", MediaType.TEXT_EVENT_STREAM_VALUE);
448+
this.servletRequest.setAttribute("key", "context value");
449+
RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(this.servletRequest));
450+
451+
try {
452+
Sinks.Many<String> sink = Sinks.many().unicast().onBackpressureBuffer();
453+
ResponseBodyEmitter emitter = handler.handleValue(sink.asFlux(), returnType, mavContainer, this.webRequest);
454+
455+
ContextEmitterHandler emitterHandler = new ContextEmitterHandler();
456+
emitter.initialize(emitterHandler);
457+
458+
sink.tryEmitNext("emitted value");
459+
emitterHandler.awaitMessageCount(1);
460+
461+
sink.tryEmitComplete();
462+
463+
assertThat(emitterHandler.getValuesAsText()).isEqualTo("data:emitted value\n\n");
464+
assertThat(emitterHandler.getSavedRequest()).isSameAs(this.servletRequest);
465+
}
466+
finally {
467+
RequestContextHolder.resetRequestAttributes();
468+
}
469+
}
470+
417471
private void testEmitterContentType(String expected) throws Exception {
418472
ServletServerHttpResponse message = new ServletServerHttpResponse(this.servletResponse);
419473
ResponseBodyEmitter emitter = handleValue(Flux.empty(), Flux.class, forClass(String.class));
@@ -541,6 +595,47 @@ public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOExce
541595
}
542596
}
543597

598+
599+
private static class ContextEmitterHandler extends EmitterHandler {
600+
601+
private final AtomicInteger count = new AtomicInteger();
602+
603+
private HttpServletRequest savedRequest;
604+
605+
public HttpServletRequest getSavedRequest() {
606+
return this.savedRequest;
607+
}
608+
609+
@Override
610+
public void send(Object data, MediaType mediaType) throws IOException {
611+
saveRequest();
612+
super.send(data, mediaType);
613+
this.count.addAndGet(1);
614+
}
615+
616+
@Override
617+
public void send(Set<ResponseBodyEmitter.DataWithMediaType> items) throws IOException {
618+
saveRequest();
619+
for (ResponseBodyEmitter.DataWithMediaType item : items) {
620+
super.send(item.getData(), item.getMediaType());
621+
}
622+
this.count.addAndGet(1);
623+
}
624+
625+
private void saveRequest() {
626+
RequestAttributes attributes = RequestContextHolder.currentRequestAttributes();
627+
this.savedRequest = ((ServletRequestAttributes) attributes).getRequest();
628+
}
629+
630+
public void awaitMessageCount(int count) throws InterruptedException {
631+
for (int i = 0; i < 10 && this.count.get() < count; i++) {
632+
Thread.sleep(10);
633+
}
634+
assertThat(this.count.get()).isGreaterThanOrEqualTo(count);
635+
}
636+
}
637+
638+
544639
private static class Bar {
545640

546641
private final String value;

0 commit comments

Comments
 (0)