Skip to content

Commit 5af9a8e

Browse files
committed
Ensure WebSocketHttpRequestHandler writes headers
Closes gh-23179
1 parent 6e79dcd commit 5af9a8e

File tree

3 files changed

+163
-34
lines changed

3 files changed

+163
-34
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -166,7 +166,6 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
166166
}
167167
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
168168
chain.applyAfterHandshake(request, response, null);
169-
response.close();
170169
}
171170
catch (HandshakeFailureException ex) {
172171
failure = ex;
@@ -177,8 +176,10 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
177176
finally {
178177
if (failure != null) {
179178
chain.applyAfterHandshake(request, response, failure);
179+
response.close();
180180
throw failure;
181181
}
182+
response.close();
182183
}
183184
}
184185

spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
import org.mockito.Mock;
2727
import org.mockito.MockitoAnnotations;
2828

29+
import org.springframework.http.HttpHeaders;
2930
import org.springframework.web.socket.AbstractHttpRequestTests;
3031
import org.springframework.web.socket.SubProtocolCapable;
3132
import org.springframework.web.socket.WebSocketExtension;
@@ -62,14 +63,9 @@ public void setup() {
6263
public void supportedSubProtocols() {
6364
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
6465
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
65-
this.servletRequest.setMethod("GET");
6666

67-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
68-
headers.setUpgrade("WebSocket");
69-
headers.setConnection("Upgrade");
70-
headers.setSecWebSocketVersion("13");
71-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
72-
headers.setSecWebSocketProtocol("STOMP");
67+
this.servletRequest.setMethod("GET");
68+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP");
7369

7470
WebSocketHandler handler = new TextWebSocketHandler();
7571
Map<String, Object> attributes = Collections.emptyMap();
@@ -88,16 +84,10 @@ public void supportedExtensions() {
8884
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));
8985

9086
this.servletRequest.setMethod("GET");
91-
92-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
93-
headers.setUpgrade("WebSocket");
94-
headers.setConnection("Upgrade");
95-
headers.setSecWebSocketVersion("13");
96-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
97-
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
87+
initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
9888

9989
WebSocketHandler handler = new TextWebSocketHandler();
100-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
90+
Map<String, Object> attributes = Collections.emptyMap();
10191
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
10292

10393
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
@@ -109,16 +99,10 @@ public void subProtocolCapableHandler() {
10999
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
110100

111101
this.servletRequest.setMethod("GET");
112-
113-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
114-
headers.setUpgrade("WebSocket");
115-
headers.setConnection("Upgrade");
116-
headers.setSecWebSocketVersion("13");
117-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
118-
headers.setSecWebSocketProtocol("v11.stomp");
102+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp");
119103

120104
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
121-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
105+
Map<String, Object> attributes = Collections.emptyMap();
122106
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
123107

124108
verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
@@ -130,22 +114,25 @@ public void subProtocolCapableHandlerNoMatch() {
130114
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
131115

132116
this.servletRequest.setMethod("GET");
133-
134-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
135-
headers.setUpgrade("WebSocket");
136-
headers.setConnection("Upgrade");
137-
headers.setSecWebSocketVersion("13");
138-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
139-
headers.setSecWebSocketProtocol("v10.stomp");
117+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp");
140118

141119
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
142-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
120+
Map<String, Object> attributes = Collections.emptyMap();
143121
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
144122

145123
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
146124
Collections.emptyList(), null, handler, attributes);
147125
}
148126

127+
private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) {
128+
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders);
129+
headers.setUpgrade("WebSocket");
130+
headers.setConnection("Upgrade");
131+
headers.setSecWebSocketVersion("13");
132+
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
133+
return headers;
134+
}
135+
149136

150137
private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {
151138

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.web.socket.server.support;
17+
18+
import java.io.IOException;
19+
import java.util.Collections;
20+
import java.util.Map;
21+
import javax.servlet.ServletException;
22+
23+
import org.junit.Before;
24+
import org.junit.Test;
25+
26+
import org.springframework.http.server.ServerHttpRequest;
27+
import org.springframework.http.server.ServerHttpResponse;
28+
import org.springframework.mock.web.test.MockHttpServletRequest;
29+
import org.springframework.mock.web.test.MockHttpServletResponse;
30+
import org.springframework.web.socket.WebSocketHandler;
31+
import org.springframework.web.socket.server.HandshakeFailureException;
32+
import org.springframework.web.socket.server.HandshakeHandler;
33+
import org.springframework.web.socket.server.HandshakeInterceptor;
34+
35+
import static org.junit.Assert.assertEquals;
36+
import static org.junit.Assert.assertSame;
37+
import static org.junit.Assert.fail;
38+
import static org.mockito.ArgumentMatchers.any;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.verify;
41+
import static org.mockito.Mockito.verifyNoMoreInteractions;
42+
import static org.mockito.Mockito.when;
43+
44+
/**
45+
* Unit tests for {@link WebSocketHttpRequestHandler}.
46+
* @author Rossen Stoyanchev
47+
* @since 5.1.9
48+
*/
49+
public class WebSocketHttpRequestHandlerTests {
50+
51+
private HandshakeHandler handshakeHandler;
52+
53+
private WebSocketHttpRequestHandler requestHandler;
54+
55+
private MockHttpServletResponse response;
56+
57+
58+
@Before
59+
public void setUp() {
60+
this.handshakeHandler = mock(HandshakeHandler.class);
61+
this.requestHandler = new WebSocketHttpRequestHandler(mock(WebSocketHandler.class), this.handshakeHandler);
62+
this.response = new MockHttpServletResponse();
63+
}
64+
65+
66+
@Test
67+
public void success() throws ServletException, IOException {
68+
TestInterceptor interceptor = new TestInterceptor(true);
69+
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
70+
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
71+
72+
verify(this.handshakeHandler).doHandshake(any(), any(), any(), any());
73+
assertEquals("headerValue", this.response.getHeader("headerName"));
74+
}
75+
76+
@Test
77+
public void failure() throws ServletException, IOException {
78+
TestInterceptor interceptor = new TestInterceptor(true);
79+
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
80+
81+
when(this.handshakeHandler.doHandshake(any(), any(), any(), any()))
82+
.thenThrow(new IllegalStateException("bad state"));
83+
84+
try {
85+
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
86+
fail();
87+
}
88+
catch (HandshakeFailureException ex) {
89+
assertSame(ex, interceptor.getException());
90+
assertEquals("headerValue", this.response.getHeader("headerName"));
91+
assertEquals("exceptionHeaderValue", this.response.getHeader("exceptionHeaderName"));
92+
}
93+
}
94+
95+
@Test // gh-23179
96+
public void handshakeNotAllowed() throws ServletException, IOException {
97+
TestInterceptor interceptor = new TestInterceptor(false);
98+
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
99+
100+
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
101+
102+
verifyNoMoreInteractions(this.handshakeHandler);
103+
assertEquals("headerValue", this.response.getHeader("headerName"));
104+
}
105+
106+
107+
private static class TestInterceptor implements HandshakeInterceptor {
108+
109+
private final boolean allowHandshake;
110+
111+
private Exception exception;
112+
113+
114+
private TestInterceptor(boolean allowHandshake) {
115+
this.allowHandshake = allowHandshake;
116+
}
117+
118+
119+
public Exception getException() {
120+
return this.exception;
121+
}
122+
123+
124+
@Override
125+
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
126+
WebSocketHandler wsHandler, Map<String, Object> attributes) {
127+
128+
response.getHeaders().add("headerName", "headerValue");
129+
return this.allowHandshake;
130+
}
131+
132+
@Override
133+
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
134+
WebSocketHandler wsHandler, Exception exception) {
135+
136+
response.getHeaders().add("exceptionHeaderName", "exceptionHeaderValue");
137+
this.exception = exception;
138+
}
139+
}
140+
141+
}

0 commit comments

Comments
 (0)