Skip to content

Commit 44ca9b0

Browse files
committed
SSE handlers support keep-alive
Closes gh-1048
1 parent a5ec819 commit 44ca9b0

File tree

4 files changed

+235
-39
lines changed

4 files changed

+235
-39
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlSseHandler.java

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2024 the original author or authors.
2+
* Copyright 2020-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -51,31 +51,54 @@
5151
*/
5252
public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
5353

54-
private static final Mono<ServerSentEvent<Map<String, Object>>> COMPLETE_EVENT = Mono.just(
55-
ServerSentEvent.<Map<String, Object>>builder(Collections.emptyMap()).event("complete").build());
54+
private static final ServerSentEvent<Map<String, Object>> HEARTBEAT_EVENT =
55+
ServerSentEvent.<Map<String, Object>>builder().comment("").build();
56+
57+
private static final Mono<ServerSentEvent<Map<String, Object>>> COMPLETE_EVENT_MONO =
58+
Mono.just(ServerSentEvent.<Map<String, Object>>builder(Collections.emptyMap()).event("complete").build());
5659

5760
@Nullable
5861
private final Duration timeout;
5962

63+
@Nullable
64+
private final Duration keepAliveDuration;
65+
6066

6167
/**
62-
* Constructor with the handler to delegate to, and no timeout by default,
68+
* Basic constructor with the handler to delegate to, and no timeout by default,
6369
* which results in never timing out.
6470
* @param graphQlHandler the handler to delegate to
6571
*/
6672
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) {
67-
this(graphQlHandler, null);
73+
this(graphQlHandler, null, null);
6874
}
6975

7076
/**
71-
* Variant constructor with a timeout to use for SSE subscriptions.
77+
* Constructor with a timeout on how long to wait for the application to return
78+
* the {@link ServerResponse} that will start the stream.
7279
* @param graphQlHandler the handler to delegate to
73-
* @param timeout the timeout value to use or {@code null} to never time out
80+
* @param timeout the timeout value, or {@code null} to never time out
7481
* @since 1.3.3
7582
*/
7683
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) {
84+
this(graphQlHandler, null, null);
85+
}
86+
87+
/**
88+
* Constructor with a keep-alive duration that determines how frequently to
89+
* heartbeats during periods of inactivity.
90+
* @param graphQlHandler the handler to delegate to
91+
* @param timeout the timeout value to use or {@code null} to never time out
92+
* @param keepAliveDuration how frequently to send empty comment messages
93+
* when no other messages are sent
94+
* @since 1.4.0
95+
*/
96+
public GraphQlSseHandler(
97+
WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) {
98+
7799
super(graphQlHandler, null);
78100
this.timeout = timeout;
101+
this.keepAliveDuration = keepAliveDuration;
79102
}
80103

81104

@@ -104,7 +127,12 @@ protected Mono<ServerResponse> prepareResponse(ServerRequest request, WebGraphQl
104127

105128
Flux<ServerSentEvent<Map<String, Object>>> sseFlux =
106129
resultFlux.map((event) -> ServerSentEvent.builder(event).event("next").build())
107-
.concatWith(COMPLETE_EVENT);
130+
.concatWith(COMPLETE_EVENT_MONO);
131+
132+
if (this.keepAliveDuration != null) {
133+
KeepAliveHandler handler = new KeepAliveHandler(this.keepAliveDuration);
134+
sseFlux = handler.compose(sseFlux);
135+
}
108136

109137
Mono<ServerResponse> responseMono = ServerResponse.ok()
110138
.contentType(MediaType.TEXT_EVENT_STREAM)
@@ -124,4 +152,34 @@ private Mono<Map<String, Object>> exceptionToResultMap(Throwable ex) {
124152
.toSpecification());
125153
}
126154

155+
156+
private static final class KeepAliveHandler {
157+
158+
private final Duration keepAliveDuration;
159+
160+
private boolean eventSent;
161+
162+
KeepAliveHandler(Duration keepAliveDuration) {
163+
this.keepAliveDuration = keepAliveDuration;
164+
}
165+
166+
public Flux<ServerSentEvent<Map<String, Object>>> compose(Flux<ServerSentEvent<Map<String, Object>>> flux) {
167+
return flux.doOnNext((event) -> this.eventSent = true)
168+
.mergeWith(getKeepAliveFlux())
169+
.takeUntil((sse) -> "complete".equals(sse.event()));
170+
}
171+
172+
private Flux<ServerSentEvent<Map<String, Object>>> getKeepAliveFlux() {
173+
return Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
174+
.filter((aLong) -> !checkEventSentAndClear())
175+
.map((aLong) -> HEARTBEAT_EVENT);
176+
}
177+
178+
private boolean checkEventSentAndClear() {
179+
boolean result = this.eventSent;
180+
this.eventSent = false;
181+
return result;
182+
}
183+
}
184+
127185
}

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlSseHandler.java

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2024 the original author or authors.
2+
* Copyright 2020-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.time.Duration;
21+
import java.util.LinkedHashMap;
2122
import java.util.Map;
2223
import java.util.function.Consumer;
2324

@@ -29,6 +30,7 @@
2930
import reactor.core.publisher.BaseSubscriber;
3031
import reactor.core.publisher.Flux;
3132
import reactor.core.publisher.Mono;
33+
import reactor.core.publisher.Sinks;
3234

3335
import org.springframework.graphql.execution.SubscriptionPublisherException;
3436
import org.springframework.graphql.server.WebGraphQlHandler;
@@ -50,17 +52,23 @@
5052
*/
5153
public class GraphQlSseHandler extends AbstractGraphQlHttpHandler {
5254

55+
private static final Map<String, Object> HEARTBEAT_MAP = new LinkedHashMap<>(0);
56+
57+
5358
@Nullable
5459
private final Duration timeout;
5560

61+
@Nullable
62+
private final Duration keepAliveDuration;
63+
5664

5765
/**
5866
* Constructor with the handler to delegate to, and no timeout,
5967
* i.e. relying on underlying Server async request timeout.
6068
* @param graphQlHandler the handler to delegate to
6169
*/
6270
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) {
63-
this(graphQlHandler, null);
71+
this(graphQlHandler, null, null);
6472
}
6573

6674
/**
@@ -71,8 +79,24 @@ public GraphQlSseHandler(WebGraphQlHandler graphQlHandler) {
7179
* @since 1.3.3
7280
*/
7381
public GraphQlSseHandler(WebGraphQlHandler graphQlHandler, @Nullable Duration timeout) {
82+
this(graphQlHandler, timeout, null);
83+
}
84+
85+
/**
86+
* Variant constructor with a timeout to use for SSE subscriptions.
87+
* @param graphQlHandler the handler to delegate to
88+
* @param timeout the timeout value to set on
89+
* @param keepAliveDuration how frequently to send empty comment messages
90+
* when no other messages are sent
91+
* {@link org.springframework.web.context.request.async.AsyncWebRequest#setTimeout(Long)}
92+
* @since 1.4.0
93+
*/
94+
public GraphQlSseHandler(
95+
WebGraphQlHandler graphQlHandler, @Nullable Duration timeout, @Nullable Duration keepAliveDuration) {
96+
7497
super(graphQlHandler, null);
7598
this.timeout = timeout;
99+
this.keepAliveDuration = keepAliveDuration;
76100
}
77101

78102

@@ -101,8 +125,8 @@ protected ServerResponse prepareResponse(
101125
});
102126

103127
return ((this.timeout != null) ?
104-
ServerResponse.sse(SseSubscriber.connect(resultFlux), this.timeout) :
105-
ServerResponse.sse(SseSubscriber.connect(resultFlux)));
128+
ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration), this.timeout) :
129+
ServerResponse.sse(SseSubscriber.connect(resultFlux, this.keepAliveDuration)));
106130
}
107131

108132

@@ -120,6 +144,10 @@ private SseSubscriber(ServerResponse.SseBuilder sseBuilder) {
120144

121145
@Override
122146
protected void hookOnNext(Map<String, Object> value) {
147+
if (value == HEARTBEAT_MAP) {
148+
sendHeartbeat();
149+
return;
150+
}
123151
sendNext(value);
124152
}
125153

@@ -133,6 +161,18 @@ private void sendNext(Map<String, Object> value) {
133161
}
134162
}
135163

164+
private void sendHeartbeat() {
165+
try {
166+
// Currently, comment cannot be empty:
167+
// https://github.com/spring-projects/spring-framework/issues/34608
168+
this.sseBuilder.comment(" ");
169+
this.sseBuilder.send();
170+
}
171+
catch (IOException exception) {
172+
cancelWithError(exception);
173+
}
174+
}
175+
136176
private void cancelWithError(Throwable ex) {
137177
this.cancel();
138178
this.sseBuilder.error(ex);
@@ -169,12 +209,53 @@ protected void hookOnComplete() {
169209
sendComplete();
170210
}
171211

172-
static Consumer<ServerResponse.SseBuilder> connect(Flux<Map<String, Object>> resultFlux) {
212+
static Consumer<ServerResponse.SseBuilder> connect(
213+
Flux<Map<String, Object>> resultFlux, @Nullable Duration keepAliveDuration) {
214+
173215
return (sseBuilder) -> {
174216
SseSubscriber subscriber = new SseSubscriber(sseBuilder);
175-
resultFlux.subscribe(subscriber);
217+
if (keepAliveDuration != null) {
218+
KeepAliveHandler handler = new KeepAliveHandler(keepAliveDuration);
219+
handler.compose(resultFlux).subscribe(subscriber);
220+
}
221+
else {
222+
resultFlux.subscribe(subscriber);
223+
}
176224
};
177225
}
178226
}
179227

228+
229+
private static final class KeepAliveHandler {
230+
231+
private final Duration keepAliveDuration;
232+
233+
private boolean eventSent;
234+
235+
private final Sinks.Empty<Void> completionSink = Sinks.empty();
236+
237+
KeepAliveHandler(Duration keepAliveDuration) {
238+
this.keepAliveDuration = keepAliveDuration;
239+
}
240+
241+
public Flux<Map<String, Object>> compose(Flux<Map<String, Object>> flux) {
242+
return flux.doOnNext((event) -> this.eventSent = true)
243+
.doOnComplete(this.completionSink::tryEmitEmpty)
244+
.mergeWith(getKeepAliveFlux())
245+
.takeUntilOther(this.completionSink.asMono());
246+
}
247+
248+
private Flux<Map<String, Object>> getKeepAliveFlux() {
249+
return Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
250+
.filter((aLong) -> !checkEventSentAndClear())
251+
.map((aLong) -> HEARTBEAT_MAP);
252+
}
253+
254+
private boolean checkEventSentAndClear() {
255+
boolean result = this.eventSent;
256+
this.eventSent = false;
257+
return result;
258+
}
259+
}
260+
180261
}

spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlSseHandlerTests.java

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2020-2024 the original author or authors.
2+
* Copyright 2020-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.graphql.server.webflux;
1818

1919

20+
import java.time.Duration;
2021
import java.util.Collections;
2122
import java.util.List;
2223

@@ -28,6 +29,7 @@
2829
import org.springframework.graphql.BookSource;
2930
import org.springframework.graphql.GraphQlRequest;
3031
import org.springframework.graphql.GraphQlSetup;
32+
import org.springframework.graphql.server.WebGraphQlHandler;
3133
import org.springframework.graphql.server.support.SerializableGraphQlRequest;
3234
import org.springframework.http.MediaType;
3335
import org.springframework.http.codec.HttpMessageWriter;
@@ -67,7 +69,7 @@ class GraphQlSseHandlerTests {
6769
@Test
6870
void shouldRejectQueryOperations() {
6971
SerializableGraphQlRequest request = initRequest("{ bookById(id: 42) {name} }");
70-
GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER);
72+
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
7173
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
7274

7375
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -87,7 +89,7 @@ void shouldWriteMultipleEventsForSubscription() {
8789
SerializableGraphQlRequest request = initRequest(
8890
"subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }");
8991

90-
GraphQlSseHandler handler = createHandler(SEARCH_DATA_FETCHER);
92+
GraphQlSseHandler handler = createSseHandler(SEARCH_DATA_FETCHER);
9193
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
9294

9395
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -113,7 +115,7 @@ void shouldWriteEventsAndTerminalError() {
113115
DataFetcher<?> errorDataFetcher = env ->
114116
Flux.just(BookSource.getBook(1L)).concatWith(Flux.error(new IllegalStateException("test error")));
115117

116-
GraphQlSseHandler handler = createHandler(errorDataFetcher);
118+
GraphQlSseHandler handler = createSseHandler(errorDataFetcher);
117119
MockServerHttpResponse response = handleRequest(this.httpRequest, handler, request);
118120

119121
assertThat(response.getHeaders().getContentType().isCompatibleWith(MediaType.TEXT_EVENT_STREAM)).isTrue();
@@ -130,12 +132,40 @@ void shouldWriteEventsAndTerminalError() {
130132
""");
131133
}
132134

133-
private GraphQlSseHandler createHandler(DataFetcher<?> subscriptionDataFetcher) {
134-
return new GraphQlSseHandler(
135-
GraphQlSetup.schemaResource(BookSource.schema)
136-
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))
137-
.subscriptionFetcher("bookSearch", subscriptionDataFetcher)
138-
.toWebGraphQlHandler());
135+
@Test
136+
void shouldSendKeepAlivePings() {
137+
SerializableGraphQlRequest request = initRequest(
138+
"subscription TestSubscription { bookSearch(author:\"Orwell\") { id name } }");
139+
140+
WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(env -> Mono.delay(Duration.ofMillis(50)).then());
141+
GraphQlSseHandler handler = new GraphQlSseHandler(webGraphQlHandler, null, Duration.ofMillis(10));
142+
143+
assertThat(handleRequest(this.httpRequest, handler, request).getBodyAsString().block())
144+
.startsWith("""
145+
:
146+
147+
:
148+
149+
""")
150+
.endsWith("""
151+
:
152+
153+
event:complete
154+
data:{}
155+
156+
""");
157+
}
158+
159+
private GraphQlSseHandler createSseHandler(DataFetcher<?> subscriptionDataFetcher) {
160+
WebGraphQlHandler webGraphQlHandler = createWebGraphQlHandler(subscriptionDataFetcher);
161+
return new GraphQlSseHandler(webGraphQlHandler);
162+
}
163+
164+
private static WebGraphQlHandler createWebGraphQlHandler(DataFetcher<?> subscriptionDataFetcher) {
165+
return GraphQlSetup.schemaResource(BookSource.schema)
166+
.queryFetcher("bookById", (env) -> BookSource.getBookWithoutAuthor(1L))
167+
.subscriptionFetcher("bookSearch", subscriptionDataFetcher)
168+
.toWebGraphQlHandler();
139169
}
140170

141171
private static SerializableGraphQlRequest initRequest(String document) {

0 commit comments

Comments
 (0)