diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java index dfe8d67e9c9..b386ff21983 100644 --- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java +++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java @@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.DisposableBean; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.web.socket.CloseStatus; @@ -55,6 +57,7 @@ * * @author Artem Bilan * @author Gary Russell + * @author Julian Koch * * @since 4.1 * @@ -83,6 +86,9 @@ public abstract class IntegrationWebSocketContainer implements DisposableBean { private int sendBufferSizeLimit = DEFAULT_SEND_BUFFER_SIZE; + @Nullable + private ConcurrentWebSocketSessionDecorator.OverflowStrategy sendBufferOverflowStrategy; + public void setSendTimeLimit(int sendTimeLimit) { this.sendTimeLimit = sendTimeLimit; } @@ -91,6 +97,19 @@ public void setSendBufferSizeLimit(int sendBufferSizeLimit) { this.sendBufferSizeLimit = sendBufferSizeLimit; } + /** + * Set the send buffer overflow strategy. + *

Concurrently generated outbound messages are buffered if sending is slow. + * This strategy determines the behavior when the buffer has reached the limit + * configured with {@link #setSendBufferSizeLimit}. + * @param overflowStrategy The overflow strategy to use (see {@link ConcurrentWebSocketSessionDecorator.OverflowStrategy}), + * or {@code null} to use the default as specified by {@link ConcurrentWebSocketSessionDecorator}. + * @see ConcurrentWebSocketSessionDecorator + */ + public void setSendBufferOverflowStrategy(@Nullable ConcurrentWebSocketSessionDecorator.OverflowStrategy overflowStrategy) { + this.sendBufferOverflowStrategy = overflowStrategy; + } + public void setMessageListener(WebSocketListener messageListener) { Assert.state(this.messageListener == null || this.messageListener.equals(messageListener), "'messageListener' is already configured"); @@ -187,10 +206,7 @@ public List getSubProtocols() { public void afterConnectionEstablished(WebSocketSession sessionToDecorate) throws Exception { // NOSONAR - WebSocketSession session = - new ConcurrentWebSocketSessionDecorator(sessionToDecorate, - IntegrationWebSocketContainer.this.sendTimeLimit, - IntegrationWebSocketContainer.this.sendBufferSizeLimit); + WebSocketSession session = decorateSession(sessionToDecorate); IntegrationWebSocketContainer.this.sessions.put(session.getId(), session); if (IntegrationWebSocketContainer.this.logger.isDebugEnabled()) { @@ -240,6 +256,16 @@ public boolean supportsPartialMessages() { return false; } + private WebSocketSession decorateSession(@NonNull WebSocketSession sessionToDecorate) { + return (IntegrationWebSocketContainer.this.sendBufferOverflowStrategy == null + ? new ConcurrentWebSocketSessionDecorator(sessionToDecorate, + IntegrationWebSocketContainer.this.sendTimeLimit, + IntegrationWebSocketContainer.this.sendBufferSizeLimit) + : new ConcurrentWebSocketSessionDecorator(sessionToDecorate, + IntegrationWebSocketContainer.this.sendTimeLimit, + IntegrationWebSocketContainer.this.sendBufferSizeLimit, + IntegrationWebSocketContainer.this.sendBufferOverflowStrategy)); + } } } diff --git a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java index 211abc09185..05d0c298333 100644 --- a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java +++ b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java @@ -34,6 +34,8 @@ import org.junit.jupiter.api.Test; import org.springframework.http.HttpHeaders; +import org.springframework.integration.test.util.TestUtils; +import org.springframework.lang.Nullable; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.PongMessage; @@ -42,6 +44,7 @@ import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.client.standard.StandardWebSocketClient; +import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; @@ -138,16 +141,60 @@ protected CompletableFuture executeInternal(WebSocketHandler w assertThat(session.isOpen()).isTrue(); } + @Test + public void testWebSocketContainerOverflowStrategyPropagation() throws Exception { + StandardWebSocketClient webSocketClient = new StandardWebSocketClient(); + + Map userProperties = new HashMap<>(); + userProperties.put(Constants.IO_TIMEOUT_MS_PROPERTY, "" + (Constants.IO_TIMEOUT_MS_DEFAULT * 6)); + webSocketClient.setUserProperties(userProperties); + + ClientWebSocketContainer container = + new ClientWebSocketContainer(webSocketClient, new URI(server.getWsBaseUrl() + "/ws/websocket")); + + // We expect the options we set here to be propagated to the concrete ConcurrentWebSocketSessionDecorator + container.setSendTimeLimit(10_000); + container.setSendBufferSizeLimit(12345); + container.setSendBufferOverflowStrategy(ConcurrentWebSocketSessionDecorator.OverflowStrategy.DROP); + + TestWebSocketListener messageListener = new TestWebSocketListener(); + container.setMessageListener(messageListener); + container.setConnectionTimeout(30); + + container.start(); + + // We must wait at least until the session has been started before we can check the propagated options + assertThat(messageListener.sessionStartedLatch.await(10, TimeUnit.SECONDS)).isTrue(); + + assertThat(messageListener.optionsPropagatedToSession).isNotNull(); + assertThat(messageListener.optionsPropagatedToSession.sendTimeLimit).isEqualTo(10_000); + assertThat(messageListener.optionsPropagatedToSession.sendBufferSizeLimit).isEqualTo(12345); + assertThat(messageListener.optionsPropagatedToSession.sendBufferOverflowStrategy) + .isEqualTo(ConcurrentWebSocketSessionDecorator.OverflowStrategy.DROP); + } + + private record OptionsPropagatedToSession( + int sendTimeLimit, + int sendBufferSizeLimit, + ConcurrentWebSocketSessionDecorator.OverflowStrategy sendBufferOverflowStrategy + ) { + } + private static class TestWebSocketListener implements WebSocketListener { public final CountDownLatch messageLatch = new CountDownLatch(1); + public final CountDownLatch sessionStartedLatch = new CountDownLatch(1); + public final CountDownLatch sessionEndedLatch = new CountDownLatch(1); public WebSocketMessage message; public boolean started; + @Nullable + public OptionsPropagatedToSession optionsPropagatedToSession; + TestWebSocketListener() { } @@ -160,6 +207,15 @@ public void onMessage(WebSocketSession session, WebSocketMessage message) { @Override public void afterSessionStarted(WebSocketSession session) { this.started = true; + + var sessionDecorator = (ConcurrentWebSocketSessionDecorator) session; + this.optionsPropagatedToSession = new OptionsPropagatedToSession( + sessionDecorator.getSendTimeLimit(), + sessionDecorator.getBufferSizeLimit(), + TestUtils.getPropertyValue(sessionDecorator, "overflowStrategy", ConcurrentWebSocketSessionDecorator.OverflowStrategy.class) + ); + + this.sessionStartedLatch.countDown(); } @Override