diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java
index 2ec61043c00..ac7f9d3b4ac 100644
--- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java
+++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java
@@ -208,7 +208,7 @@ public void stop(Runnable callback) {
*
* Opened {@link WebSocketSession} is populated to the wrapping {@link ClientWebSocketContainer}.
*
- * The {@link #webSocketHandler} is used to handle {@link WebSocketSession} events.
+ * The {@link #getWebSocketHandler()} is used to handle {@link WebSocketSession} events.
*/
private final class IntegrationWebSocketConnectionManager extends ConnectionManagerSupport {
@@ -258,8 +258,7 @@ protected void openConnection() {
}
ClientWebSocketContainer.this.headers.setSecWebSocketProtocol(getSubProtocols());
CompletableFuture future =
- this.client.execute(ClientWebSocketContainer.this.webSocketHandler,
- ClientWebSocketContainer.this.headers, getUri());
+ this.client.execute(getWebSocketHandler(), ClientWebSocketContainer.this.headers, getUri());
future.whenComplete((session, throwable) -> {
if (throwable == null) {
diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
index 95d436ddafc..6436a386b22 100644
--- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
+++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2014-2022 the original author or authors.
+ * Copyright 2014-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -67,8 +67,8 @@ public abstract class IntegrationWebSocketContainer implements DisposableBean {
protected final Log logger = LogFactory.getLog(getClass()); // NOSONAR
- protected final WebSocketHandler webSocketHandler = new IntegrationWebSocketHandler(); // NOSONAR
-
+ private WebSocketHandler webSocketHandler = new IntegrationWebSocketHandler();
+
protected final Map sessions = new ConcurrentHashMap<>(); // NOSONAR
private final List supportedProtocols = new ArrayList<>();
@@ -104,6 +104,15 @@ public void addSupportedProtocols(String... protocols) {
}
}
+ /**
+ * Replace the default {@link WebSocketHandler} with the one provided here, e.g. via decoration factories.
+ * @param handler the actual {@link WebSocketHandler} to replace.
+ * @since 5.5.18
+ */
+ protected void setWebSocketHandler(WebSocketHandler handler) {
+ this.webSocketHandler = handler;
+ }
+
public WebSocketHandler getWebSocketHandler() {
return this.webSocketHandler;
}
diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java
index de7b648cfb6..792e36b4207 100644
--- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java
+++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2014-2022 the original author or authors.
+ * Copyright 2014-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -150,11 +150,12 @@ public TaskScheduler getSockJsTaskScheduler() {
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
- WebSocketHandler webSocketHandler = this.webSocketHandler;
+ WebSocketHandler webSocketHandler = getWebSocketHandler();
if (this.decoratorFactories != null) {
for (WebSocketHandlerDecoratorFactory factory : this.decoratorFactories) {
webSocketHandler = factory.decorate(webSocketHandler);
+ setWebSocketHandler(webSocketHandler);
}
}
diff --git a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java
index 2994acf0721..c445769808c 100644
--- a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java
+++ b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2021-2022 the original author or authors.
+ * Copyright 2021-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
package org.springframework.integration.websocket.dsl;
+import java.util.concurrent.atomic.AtomicReference;
+
import jakarta.websocket.DeploymentException;
import org.junit.jupiter.api.Test;
@@ -37,6 +39,7 @@
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
+import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.server.HandshakeHandler;
@@ -46,6 +49,9 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.awaitility.Awaitility.await;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
@SpringJUnitConfig(classes = WebSocketDslTests.ClientConfig.class)
@DirtiesContext
@@ -61,13 +67,19 @@ public class WebSocketDslTests {
IntegrationFlowContext integrationFlowContext;
@Test
- void testDynamicServerEndpointRegistration() {
+ void testDynamicServerEndpointRegistration() throws Exception {
// Dynamic server flow
AnnotationConfigWebApplicationContext serverContext = this.server.getServerContext();
IntegrationFlowContext serverIntegrationFlowContext = serverContext.getBean(IntegrationFlowContext.class);
+ AtomicReference decoratedHandler = new AtomicReference<>();
ServerWebSocketContainer serverWebSocketContainer =
new ServerWebSocketContainer("/dynamic")
.setHandshakeHandler(serverContext.getBean(HandshakeHandler.class))
+ .setDecoratorFactories(handler -> {
+ WebSocketHandler spy = spy(handler);
+ decoratedHandler.set(spy);
+ return spy;
+ })
.withSockJs();
WebSocketInboundChannelAdapter webSocketInboundChannelAdapter =
@@ -106,6 +118,8 @@ void testDynamicServerEndpointRegistration() {
.extracting(Message::getPayload)
.isEqualTo("dynamic test");
+ verify(decoratedHandler.get()).handleMessage(any(), any());
+
dynamicServerFlow.destroy();
await() // Looks like endpoint is removed on the server side somewhat async