Skip to content

Commit 3d913b8

Browse files
committed
Merge branch '5.1.x'
2 parents 03a3423 + 4e6e47b commit 3d913b8

File tree

7 files changed

+242
-48
lines changed

7 files changed

+242
-48
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.security.Principal;
2020
import java.util.List;
2121
import java.util.Map;
22+
import java.util.function.Consumer;
2223

2324
import org.springframework.lang.Nullable;
2425
import org.springframework.messaging.Message;
@@ -84,6 +85,10 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
8485
public static final String IGNORE_ERROR = "simpIgnoreError";
8586

8687

88+
@Nullable
89+
private Consumer<Principal> userCallback;
90+
91+
8792
/**
8893
* A constructor for creating new message headers.
8994
* This constructor is protected. See factory methods in this and sub-classes.
@@ -171,6 +176,9 @@ public Map<String, Object> getSessionAttributes() {
171176

172177
public void setUser(@Nullable Principal principal) {
173178
setHeader(USER_HEADER, principal);
179+
if (this.userCallback != null) {
180+
this.userCallback.accept(principal);
181+
}
174182
}
175183

176184
/**
@@ -181,6 +189,18 @@ public Principal getUser() {
181189
return (Principal) getHeader(USER_HEADER);
182190
}
183191

192+
/**
193+
* Provide a callback to be invoked if and when {@link #setUser(Principal)}
194+
* is called. This is used internally on the inbound channel to detect
195+
* token-based authentications through an interceptor.
196+
* @param callback the callback to invoke
197+
* @since 5.1.9
198+
*/
199+
public void setUserChangeCallback(Consumer<Principal> callback) {
200+
Assert.notNull(callback, "'callback' is required");
201+
this.userCallback = this.userCallback != null ? this.userCallback.andThen(callback) : callback;
202+
}
203+
184204
@Override
185205
public String getShortLogMessage(Object payload) {
186206
if (getMessageType() == null) {

spring-messaging/src/test/java/org/springframework/messaging/simp/SimpMessageHeaderAccessorTests.java

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616

1717
package org.springframework.messaging.simp;
1818

19+
import java.security.Principal;
1920
import java.util.Collections;
21+
import java.util.function.Consumer;
2022

2123
import org.junit.Test;
2224

2325
import static org.assertj.core.api.Assertions.assertThat;
26+
import static org.mockito.Mockito.mock;
2427

2528
/**
2629
* Unit tests for SimpMessageHeaderAccessor.
@@ -32,7 +35,8 @@ public class SimpMessageHeaderAccessorTests {
3235

3336
@Test
3437
public void getShortLogMessage() {
35-
assertThat(SimpMessageHeaderAccessor.create().getShortLogMessage("p")).isEqualTo("MESSAGE session=null payload=p");
38+
assertThat(SimpMessageHeaderAccessor.create().getShortLogMessage("p"))
39+
.isEqualTo("MESSAGE session=null payload=p");
3640
}
3741

3842
@Test
@@ -44,8 +48,9 @@ public void getLogMessageWithValuesSet() {
4448
accessor.setUser(new TestPrincipal("user"));
4549
accessor.setSessionAttributes(Collections.<String, Object>singletonMap("key", "value"));
4650

47-
assertThat(accessor.getShortLogMessage("p")).isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
48-
"session=session user=user attributes[1] payload=p"));
51+
assertThat(accessor.getShortLogMessage("p"))
52+
.isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
53+
"session=session user=user attributes[1] payload=p"));
4954
}
5055

5156
@Test
@@ -58,9 +63,41 @@ public void getDetailedLogMessageWithValuesSet() {
5863
accessor.setSessionAttributes(Collections.<String, Object>singletonMap("key", "value"));
5964
accessor.setNativeHeader("nativeKey", "nativeValue");
6065

61-
assertThat(accessor.getDetailedLogMessage("p")).isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
62-
"session=session user=user attributes={key=value} nativeHeaders=" +
63-
"{nativeKey=[nativeValue]} payload=p"));
66+
assertThat(accessor.getDetailedLogMessage("p"))
67+
.isEqualTo(("MESSAGE destination=/destination subscriptionId=subscription " +
68+
"session=session user=user attributes={key=value} nativeHeaders=" +
69+
"{nativeKey=[nativeValue]} payload=p"));
70+
}
71+
72+
@Test
73+
public void userChangeCallback() {
74+
UserCallback userCallback = new UserCallback();
75+
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create();
76+
accessor.setUserChangeCallback(userCallback);
77+
78+
Principal user1 = mock(Principal.class);
79+
accessor.setUser(user1);
80+
assertThat(userCallback.getUser()).isEqualTo(user1);
81+
82+
Principal user2 = mock(Principal.class);
83+
accessor.setUser(user2);
84+
assertThat(userCallback.getUser()).isEqualTo(user2);
85+
}
86+
87+
88+
private static class UserCallback implements Consumer<Principal> {
89+
90+
private Principal user;
91+
92+
93+
public Principal getUser() {
94+
return this.user;
95+
}
96+
97+
@Override
98+
public void accept(Principal principal) {
99+
this.user = principal;
100+
}
64101
}
65102

66103
}

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,19 @@ else if (webSocketMessage instanceof BinaryMessage) {
267267
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
268268
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
269269

270+
StompCommand command = headerAccessor.getCommand();
271+
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
272+
270273
headerAccessor.setSessionId(session.getId());
271274
headerAccessor.setSessionAttributes(session.getAttributes());
272275
headerAccessor.setUser(getUser(session));
276+
if (isConnect) {
277+
headerAccessor.setUserChangeCallback(user -> {
278+
if (user != null && user != session.getPrincipal()) {
279+
this.stompAuthentications.put(session.getId(), user);
280+
}
281+
});
282+
}
273283
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
274284
if (!detectImmutableMessageInterceptor(outputChannel)) {
275285
headerAccessor.setImmutable();
@@ -279,8 +289,6 @@ else if (webSocketMessage instanceof BinaryMessage) {
279289
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
280290
}
281291

282-
StompCommand command = headerAccessor.getCommand();
283-
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
284292
if (isConnect) {
285293
this.stats.incrementConnectCount();
286294
}
@@ -293,12 +301,6 @@ else if (StompCommand.DISCONNECT.equals(command)) {
293301
boolean sent = outputChannel.send(message);
294302

295303
if (sent) {
296-
if (isConnect) {
297-
Principal user = headerAccessor.getUser();
298-
if (user != null && user != session.getPrincipal()) {
299-
this.stompAuthentications.put(session.getId(), user);
300-
}
301-
}
302304
if (this.eventPublisher != null) {
303305
Principal user = getUser(session);
304306
if (isConnect) {

spring-websocket/src/main/java/org/springframework/web/socket/server/support/WebSocketHttpRequestHandler.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -166,7 +166,6 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
166166
}
167167
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
168168
chain.applyAfterHandshake(request, response, null);
169-
response.close();
170169
}
171170
catch (HandshakeFailureException ex) {
172171
failure = ex;
@@ -177,8 +176,10 @@ public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse
177176
finally {
178177
if (failure != null) {
179178
chain.applyAfterHandshake(request, response, failure);
179+
response.close();
180180
throw failure;
181181
}
182+
response.close();
182183
}
183184
}
184185

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,15 @@ public void handleMessageFromClientWithTokenAuthentication() {
383383
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
384384
assertThat(user).isNotNull();
385385
assertThat(user.getName()).isEqualTo("[email protected]");
386+
387+
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
388+
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
389+
handler.handleMessageToClient(this.session, message);
390+
391+
assertThat(this.session.getSentMessages()).hasSize(1);
392+
WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
393+
assertThat(textMessage.getPayload())
394+
.isEqualTo("CONNECTED\n" + "user-name:[email protected]\n" + "\n" + "\u0000");
386395
}
387396

388397
@Test

spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.junit.Test;
2525

26+
import org.springframework.http.HttpHeaders;
2627
import org.springframework.web.socket.AbstractHttpRequestTests;
2728
import org.springframework.web.socket.SubProtocolCapable;
2829
import org.springframework.web.socket.WebSocketExtension;
@@ -51,14 +52,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
5152
public void supportedSubProtocols() {
5253
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
5354
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
54-
this.servletRequest.setMethod("GET");
5555

56-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
57-
headers.setUpgrade("WebSocket");
58-
headers.setConnection("Upgrade");
59-
headers.setSecWebSocketVersion("13");
60-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
61-
headers.setSecWebSocketProtocol("STOMP");
56+
this.servletRequest.setMethod("GET");
57+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP");
6258

6359
WebSocketHandler handler = new TextWebSocketHandler();
6460
Map<String, Object> attributes = Collections.emptyMap();
@@ -77,16 +73,10 @@ public void supportedExtensions() {
7773
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));
7874

7975
this.servletRequest.setMethod("GET");
80-
81-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
82-
headers.setUpgrade("WebSocket");
83-
headers.setConnection("Upgrade");
84-
headers.setSecWebSocketVersion("13");
85-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
86-
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
76+
initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
8777

8878
WebSocketHandler handler = new TextWebSocketHandler();
89-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
79+
Map<String, Object> attributes = Collections.emptyMap();
9080
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
9181

9282
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
@@ -98,16 +88,10 @@ public void subProtocolCapableHandler() {
9888
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
9989

10090
this.servletRequest.setMethod("GET");
101-
102-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
103-
headers.setUpgrade("WebSocket");
104-
headers.setConnection("Upgrade");
105-
headers.setSecWebSocketVersion("13");
106-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
107-
headers.setSecWebSocketProtocol("v11.stomp");
91+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp");
10892

10993
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
110-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
94+
Map<String, Object> attributes = Collections.emptyMap();
11195
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
11296

11397
verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
@@ -119,22 +103,25 @@ public void subProtocolCapableHandlerNoMatch() {
119103
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
120104

121105
this.servletRequest.setMethod("GET");
122-
123-
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
124-
headers.setUpgrade("WebSocket");
125-
headers.setConnection("Upgrade");
126-
headers.setSecWebSocketVersion("13");
127-
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
128-
headers.setSecWebSocketProtocol("v10.stomp");
106+
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp");
129107

130108
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
131-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
109+
Map<String, Object> attributes = Collections.emptyMap();
132110
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
133111

134112
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
135113
Collections.emptyList(), null, handler, attributes);
136114
}
137115

116+
private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) {
117+
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders);
118+
headers.setUpgrade("WebSocket");
119+
headers.setConnection("Upgrade");
120+
headers.setSecWebSocketVersion("13");
121+
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
122+
return headers;
123+
}
124+
138125

139126
private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {
140127

0 commit comments

Comments
 (0)