Skip to content

Commit b0a1a11

Browse files
committed
Merge branch '6.2.x'
2 parents e49d2da + 7a0fe7d commit b0a1a11

File tree

8 files changed

+111
-43
lines changed

8 files changed

+111
-43
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationMessageHandler.java

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@
2020
import java.util.Iterator;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.Set;
2423
import java.util.concurrent.ConcurrentHashMap;
2524

2625
import org.apache.commons.logging.Log;
@@ -275,12 +274,10 @@ public MessageSendingOperations<String> getMessagingTemplate() {
275274
return this.messagingTemplate;
276275
}
277276

278-
public void send(UserDestinationResult destinationResult, Message<?> message) throws MessagingException {
279-
Set<String> sessionIds = destinationResult.getSessionIds();
280-
Iterator<String> itr = (sessionIds != null ? sessionIds.iterator() : null);
281-
282-
for (String target : destinationResult.getTargetDestinations()) {
283-
String sessionId = (itr != null ? itr.next() : null);
277+
public void send(UserDestinationResult result, Message<?> message) throws MessagingException {
278+
Iterator<String> itr = result.getSessionIds().iterator();
279+
for (String target : result.getTargetDestinations()) {
280+
String sessionId = (itr.hasNext() ? itr.next() : null);
284281
getTemplateToUse(sessionId).send(target, message);
285282
}
286283
}

spring-messaging/src/main/java/org/springframework/messaging/simp/user/UserDestinationResult.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -44,7 +44,11 @@ public class UserDestinationResult {
4444
private final Set<String> sessionIds;
4545

4646

47-
public UserDestinationResult(String sourceDestination, Set<String> targetDestinations,
47+
/**
48+
* Main constructor.
49+
*/
50+
public UserDestinationResult(
51+
String sourceDestination, Set<String> targetDestinations,
4852
String subscribeDestination, @Nullable String user) {
4953

5054
this(sourceDestination, targetDestinations, subscribeDestination, user, null);
@@ -113,7 +117,7 @@ public String getSubscribeDestination() {
113117
/**
114118
* Return the session id for the targetDestination.
115119
*/
116-
public @Nullable Set<String> getSessionIds() {
120+
public Set<String> getSessionIds() {
117121
return this.sessionIds;
118122
}
119123

spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserDestinationMessageHandlerTests.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.messaging.simp.user;
1818

1919
import java.nio.charset.StandardCharsets;
20+
import java.util.Set;
2021

2122
import org.jspecify.annotations.Nullable;
2223
import org.junit.jupiter.api.Test;
@@ -98,6 +99,26 @@ void handleMessage() {
9899
assertThat(accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)).isEqualTo("/user/queue/foo");
99100
}
100101

102+
@Test
103+
@SuppressWarnings("rawtypes")
104+
void handleMessageWithoutSessionIds() {
105+
UserDestinationResolver resolver = mock();
106+
Message message = createWith(SimpMessageType.MESSAGE, "joe", null, "/user/joe/queue/foo");
107+
UserDestinationResult result = new UserDestinationResult("/queue/foo-user123", Set.of("/queue/foo-user123"), "/user/queue/foo", "joe");
108+
given(resolver.resolveDestination(message)).willReturn(result);
109+
110+
given(this.brokerChannel.send(Mockito.any(Message.class))).willReturn(true);
111+
UserDestinationMessageHandler handler = new UserDestinationMessageHandler(new StubMessageChannel(), this.brokerChannel, resolver);
112+
handler.handleMessage(message);
113+
114+
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
115+
Mockito.verify(this.brokerChannel).send(captor.capture());
116+
117+
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.wrap(captor.getValue());
118+
assertThat(accessor.getDestination()).isEqualTo("/queue/foo-user123");
119+
assertThat(accessor.getFirstNativeHeader(ORIGINAL_DESTINATION)).isEqualTo("/user/queue/foo");
120+
}
121+
101122
@Test
102123
@SuppressWarnings("rawtypes")
103124
void handleMessageWithoutActiveSession() {

spring-web/src/main/java/org/springframework/web/context/request/async/WebAsyncManager.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@
3434
import org.springframework.util.Assert;
3535
import org.springframework.web.context.request.RequestAttributes;
3636
import org.springframework.web.context.request.async.DeferredResult.DeferredResultHandler;
37+
import org.springframework.web.util.DisconnectedClientHelper;
3738

3839
/**
3940
* The central class for managing asynchronous request processing, mainly intended
@@ -342,6 +343,10 @@ public void startCallableProcessing(final WebAsyncTask<?> webAsyncTask, Object..
342343
if (logger.isDebugEnabled()) {
343344
logger.debug("Servlet container error notification for " + formatUri(this.asyncWebRequest) + ": " + ex);
344345
}
346+
if (DisconnectedClientHelper.isClientDisconnectedException(ex)) {
347+
ex = new AsyncRequestNotUsableException(
348+
"Servlet container error notification for disconnected client", ex);
349+
}
345350
Object result = interceptorChain.triggerAfterError(this.asyncWebRequest, callable, ex);
346351
result = (result != CallableProcessingInterceptor.RESULT_NONE ? result : ex);
347352
setConcurrentResultAndDispatch(result);
@@ -434,6 +439,10 @@ public void startDeferredResultProcessing(
434439
if (logger.isDebugEnabled()) {
435440
logger.debug("Servlet container error notification for " + formatUri(this.asyncWebRequest));
436441
}
442+
if (DisconnectedClientHelper.isClientDisconnectedException(ex)) {
443+
ex = new AsyncRequestNotUsableException(
444+
"Servlet container error notification for disconnected client", ex);
445+
}
437446
try {
438447
interceptorChain.triggerAfterError(this.asyncWebRequest, deferredResult, ex);
439448
synchronized (WebAsyncManager.this) {

spring-web/src/test/java/org/springframework/web/context/request/async/WebAsyncManagerErrorTests.java

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2024 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.web.context.request.async;
1818

19+
import java.io.IOException;
1920
import java.util.concurrent.Callable;
2021

2122
import jakarta.servlet.AsyncEvent;
@@ -152,6 +153,22 @@ void startCallableProcessingAfterException() throws Exception {
152153
verify(interceptor).beforeConcurrentHandling(this.asyncWebRequest, callable);
153154
}
154155

156+
@Test // gh-34363
157+
void startCallableProcessingDisconnectedClient() throws Exception {
158+
StubCallable callable = new StubCallable();
159+
this.asyncManager.startCallableProcessing(callable);
160+
161+
IOException ex = new IOException("broken pipe");
162+
AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), ex);
163+
this.asyncWebRequest.onError(event);
164+
165+
MockAsyncContext asyncContext = (MockAsyncContext) this.servletRequest.getAsyncContext();
166+
assertThat(this.asyncManager.hasConcurrentResult()).isTrue();
167+
assertThat(this.asyncManager.getConcurrentResult())
168+
.as("Disconnected client error not wrapped AsyncRequestNotUsableException")
169+
.isOfAnyClassIn(AsyncRequestNotUsableException.class);
170+
}
171+
155172
@Test
156173
void startDeferredResultProcessingErrorAndComplete() throws Exception {
157174

@@ -259,6 +276,21 @@ public <T> boolean handleError(NativeWebRequest request, DeferredResult<T> resul
259276
assertThat(((MockAsyncContext) this.servletRequest.getAsyncContext()).getDispatchedPath()).isEqualTo("/test");
260277
}
261278

279+
@Test // gh-34363
280+
void startDeferredResultProcessingDisconnectedClient() throws Exception {
281+
DeferredResult<Object> deferredResult = new DeferredResult<>();
282+
this.asyncManager.startDeferredResultProcessing(deferredResult);
283+
284+
IOException ex = new IOException("broken pipe");
285+
AsyncEvent event = new AsyncEvent(new MockAsyncContext(this.servletRequest, this.servletResponse), ex);
286+
this.asyncWebRequest.onError(event);
287+
288+
assertThat(this.asyncManager.hasConcurrentResult()).isTrue();
289+
assertThat(deferredResult.getResult())
290+
.as("Disconnected client error not wrapped AsyncRequestNotUsableException")
291+
.isOfAnyClassIn(AsyncRequestNotUsableException.class);
292+
}
293+
262294

263295
private static final class StubCallable implements Callable<Object> {
264296
@Override

spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java

+13-11
Original file line numberDiff line numberDiff line change
@@ -198,23 +198,25 @@ public Mono<Void> handleRequest(ServerWebExchange exchange, WebSocketHandler han
198198
HttpMethod method = request.getMethod();
199199
HttpHeaders headers = request.getHeaders();
200200

201-
if (HttpMethod.GET != method && CONNECT_METHOD != method) {
201+
if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) {
202202
return Mono.error(new MethodNotAllowedException(
203203
request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD)));
204204
}
205205

206-
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
207-
return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
208-
}
206+
if (HttpMethod.GET == method) {
207+
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
208+
return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers);
209+
}
209210

210-
List<String> connectionValue = headers.getConnection();
211-
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
212-
return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
213-
}
211+
List<String> connectionValue = headers.getConnection();
212+
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
213+
return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers);
214+
}
214215

215-
String key = headers.getFirst(SEC_WEBSOCKET_KEY);
216-
if (key == null) {
217-
return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
216+
String key = headers.getFirst(SEC_WEBSOCKET_KEY);
217+
if (key == null) {
218+
return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header");
219+
}
218220
}
219221

220222
String protocol = selectProtocol(headers, handler);

spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -149,7 +149,7 @@ public void setSecWebSocketProtocol(List<String> secWebSocketProtocols) {
149149
}
150150

151151
/**
152-
* Returns the value of the {@code Sec-WebSocket-Key} header.
152+
* Returns the value of the {@code Sec-WebSocket-Protocol} header.
153153
* @return the value of the header
154154
*/
155155
public List<String> getSecWebSocketProtocol() {

spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java

+19-16
Original file line numberDiff line numberDiff line change
@@ -175,21 +175,32 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
175175
}
176176
try {
177177
HttpMethod httpMethod = request.getMethod();
178-
if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) {
178+
if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
179179
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
180180
response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
181181
if (logger.isErrorEnabled()) {
182182
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
183183
}
184184
return false;
185185
}
186-
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
187-
handleInvalidUpgradeHeader(request, response);
188-
return false;
189-
}
190-
if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) {
191-
handleInvalidConnectHeader(request, response);
192-
return false;
186+
if (HttpMethod.GET == httpMethod) {
187+
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
188+
handleInvalidUpgradeHeader(request, response);
189+
return false;
190+
}
191+
List<String> connectionValue = headers.getConnection();
192+
if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) {
193+
handleInvalidConnectHeader(request, response);
194+
return false;
195+
}
196+
String key = headers.getSecWebSocketKey();
197+
if (key == null) {
198+
if (logger.isErrorEnabled()) {
199+
logger.error("Missing \"Sec-WebSocket-Key\" header");
200+
}
201+
response.setStatusCode(HttpStatus.BAD_REQUEST);
202+
return false;
203+
}
193204
}
194205
if (!isWebSocketVersionSupported(headers)) {
195206
handleWebSocketVersionNotSupported(request, response);
@@ -199,14 +210,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
199210
response.setStatusCode(HttpStatus.FORBIDDEN);
200211
return false;
201212
}
202-
String wsKey = headers.getSecWebSocketKey();
203-
if (wsKey == null) {
204-
if (logger.isErrorEnabled()) {
205-
logger.error("Missing \"Sec-WebSocket-Key\" header");
206-
}
207-
response.setStatusCode(HttpStatus.BAD_REQUEST);
208-
return false;
209-
}
210213
}
211214
catch (IOException ex) {
212215
throw new HandshakeFailureException(

0 commit comments

Comments
 (0)