diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java index 2ec61043c00..ac7f9d3b4ac 100644 --- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java +++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ClientWebSocketContainer.java @@ -208,7 +208,7 @@ public void stop(Runnable callback) { *

* Opened {@link WebSocketSession} is populated to the wrapping {@link ClientWebSocketContainer}. *

- * The {@link #webSocketHandler} is used to handle {@link WebSocketSession} events. + * The {@link #getWebSocketHandler()} is used to handle {@link WebSocketSession} events. */ private final class IntegrationWebSocketConnectionManager extends ConnectionManagerSupport { @@ -258,8 +258,7 @@ protected void openConnection() { } ClientWebSocketContainer.this.headers.setSecWebSocketProtocol(getSubProtocols()); CompletableFuture future = - this.client.execute(ClientWebSocketContainer.this.webSocketHandler, - ClientWebSocketContainer.this.headers, getUri()); + this.client.execute(getWebSocketHandler(), ClientWebSocketContainer.this.headers, getUri()); future.whenComplete((session, throwable) -> { if (throwable == null) { diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java index 95d436ddafc..6436a386b22 100644 --- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java +++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2022 the original author or authors. + * Copyright 2014-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,8 +67,8 @@ public abstract class IntegrationWebSocketContainer implements DisposableBean { protected final Log logger = LogFactory.getLog(getClass()); // NOSONAR - protected final WebSocketHandler webSocketHandler = new IntegrationWebSocketHandler(); // NOSONAR - + private WebSocketHandler webSocketHandler = new IntegrationWebSocketHandler(); + protected final Map sessions = new ConcurrentHashMap<>(); // NOSONAR private final List supportedProtocols = new ArrayList<>(); @@ -104,6 +104,15 @@ public void addSupportedProtocols(String... protocols) { } } + /** + * Replace the default {@link WebSocketHandler} with the one provided here, e.g. via decoration factories. + * @param handler the actual {@link WebSocketHandler} to replace. + * @since 5.5.18 + */ + protected void setWebSocketHandler(WebSocketHandler handler) { + this.webSocketHandler = handler; + } + public WebSocketHandler getWebSocketHandler() { return this.webSocketHandler; } diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java index de7b648cfb6..792e36b4207 100644 --- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java +++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/ServerWebSocketContainer.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2022 the original author or authors. + * Copyright 2014-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -150,11 +150,12 @@ public TaskScheduler getSockJsTaskScheduler() { @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { - WebSocketHandler webSocketHandler = this.webSocketHandler; + WebSocketHandler webSocketHandler = getWebSocketHandler(); if (this.decoratorFactories != null) { for (WebSocketHandlerDecoratorFactory factory : this.decoratorFactories) { webSocketHandler = factory.decorate(webSocketHandler); + setWebSocketHandler(webSocketHandler); } } diff --git a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java index 2994acf0721..c445769808c 100644 --- a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java +++ b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/dsl/WebSocketDslTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2021-2022 the original author or authors. + * Copyright 2021-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.integration.websocket.dsl; +import java.util.concurrent.atomic.AtomicReference; + import jakarta.websocket.DeploymentException; import org.junit.jupiter.api.Test; @@ -37,6 +39,7 @@ import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.client.WebSocketClient; import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.server.HandshakeHandler; @@ -46,6 +49,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; @SpringJUnitConfig(classes = WebSocketDslTests.ClientConfig.class) @DirtiesContext @@ -61,13 +67,19 @@ public class WebSocketDslTests { IntegrationFlowContext integrationFlowContext; @Test - void testDynamicServerEndpointRegistration() { + void testDynamicServerEndpointRegistration() throws Exception { // Dynamic server flow AnnotationConfigWebApplicationContext serverContext = this.server.getServerContext(); IntegrationFlowContext serverIntegrationFlowContext = serverContext.getBean(IntegrationFlowContext.class); + AtomicReference decoratedHandler = new AtomicReference<>(); ServerWebSocketContainer serverWebSocketContainer = new ServerWebSocketContainer("/dynamic") .setHandshakeHandler(serverContext.getBean(HandshakeHandler.class)) + .setDecoratorFactories(handler -> { + WebSocketHandler spy = spy(handler); + decoratedHandler.set(spy); + return spy; + }) .withSockJs(); WebSocketInboundChannelAdapter webSocketInboundChannelAdapter = @@ -106,6 +118,8 @@ void testDynamicServerEndpointRegistration() { .extracting(Message::getPayload) .isEqualTo("dynamic test"); + verify(decoratedHandler.get()).handleMessage(any(), any()); + dynamicServerFlow.destroy(); await() // Looks like endpoint is removed on the server side somewhat async