diff --git a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
index dfe8d67e9c9..b386ff21983 100644
--- a/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
+++ b/spring-integration-websocket/src/main/java/org/springframework/integration/websocket/IntegrationWebSocketContainer.java
@@ -28,6 +28,8 @@
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.DisposableBean;
+import org.springframework.lang.NonNull;
+import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.socket.CloseStatus;
@@ -55,6 +57,7 @@
*
* @author Artem Bilan
* @author Gary Russell
+ * @author Julian Koch
*
* @since 4.1
*
@@ -83,6 +86,9 @@ public abstract class IntegrationWebSocketContainer implements DisposableBean {
private int sendBufferSizeLimit = DEFAULT_SEND_BUFFER_SIZE;
+ @Nullable
+ private ConcurrentWebSocketSessionDecorator.OverflowStrategy sendBufferOverflowStrategy;
+
public void setSendTimeLimit(int sendTimeLimit) {
this.sendTimeLimit = sendTimeLimit;
}
@@ -91,6 +97,19 @@ public void setSendBufferSizeLimit(int sendBufferSizeLimit) {
this.sendBufferSizeLimit = sendBufferSizeLimit;
}
+ /**
+ * Set the send buffer overflow strategy.
+ *
Concurrently generated outbound messages are buffered if sending is slow.
+ * This strategy determines the behavior when the buffer has reached the limit
+ * configured with {@link #setSendBufferSizeLimit}.
+ * @param overflowStrategy The overflow strategy to use (see {@link ConcurrentWebSocketSessionDecorator.OverflowStrategy}),
+ * or {@code null} to use the default as specified by {@link ConcurrentWebSocketSessionDecorator}.
+ * @see ConcurrentWebSocketSessionDecorator
+ */
+ public void setSendBufferOverflowStrategy(@Nullable ConcurrentWebSocketSessionDecorator.OverflowStrategy overflowStrategy) {
+ this.sendBufferOverflowStrategy = overflowStrategy;
+ }
+
public void setMessageListener(WebSocketListener messageListener) {
Assert.state(this.messageListener == null || this.messageListener.equals(messageListener),
"'messageListener' is already configured");
@@ -187,10 +206,7 @@ public List getSubProtocols() {
public void afterConnectionEstablished(WebSocketSession sessionToDecorate)
throws Exception { // NOSONAR
- WebSocketSession session =
- new ConcurrentWebSocketSessionDecorator(sessionToDecorate,
- IntegrationWebSocketContainer.this.sendTimeLimit,
- IntegrationWebSocketContainer.this.sendBufferSizeLimit);
+ WebSocketSession session = decorateSession(sessionToDecorate);
IntegrationWebSocketContainer.this.sessions.put(session.getId(), session);
if (IntegrationWebSocketContainer.this.logger.isDebugEnabled()) {
@@ -240,6 +256,16 @@ public boolean supportsPartialMessages() {
return false;
}
+ private WebSocketSession decorateSession(@NonNull WebSocketSession sessionToDecorate) {
+ return (IntegrationWebSocketContainer.this.sendBufferOverflowStrategy == null
+ ? new ConcurrentWebSocketSessionDecorator(sessionToDecorate,
+ IntegrationWebSocketContainer.this.sendTimeLimit,
+ IntegrationWebSocketContainer.this.sendBufferSizeLimit)
+ : new ConcurrentWebSocketSessionDecorator(sessionToDecorate,
+ IntegrationWebSocketContainer.this.sendTimeLimit,
+ IntegrationWebSocketContainer.this.sendBufferSizeLimit,
+ IntegrationWebSocketContainer.this.sendBufferOverflowStrategy));
+ }
}
}
diff --git a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java
index 211abc09185..05d0c298333 100644
--- a/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java
+++ b/spring-integration-websocket/src/test/java/org/springframework/integration/websocket/ClientWebSocketContainerTests.java
@@ -34,6 +34,8 @@
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
+import org.springframework.integration.test.util.TestUtils;
+import org.springframework.lang.Nullable;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
@@ -42,6 +44,7 @@
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
+import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
@@ -138,16 +141,60 @@ protected CompletableFuture executeInternal(WebSocketHandler w
assertThat(session.isOpen()).isTrue();
}
+ @Test
+ public void testWebSocketContainerOverflowStrategyPropagation() throws Exception {
+ StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
+
+ Map userProperties = new HashMap<>();
+ userProperties.put(Constants.IO_TIMEOUT_MS_PROPERTY, "" + (Constants.IO_TIMEOUT_MS_DEFAULT * 6));
+ webSocketClient.setUserProperties(userProperties);
+
+ ClientWebSocketContainer container =
+ new ClientWebSocketContainer(webSocketClient, new URI(server.getWsBaseUrl() + "/ws/websocket"));
+
+ // We expect the options we set here to be propagated to the concrete ConcurrentWebSocketSessionDecorator
+ container.setSendTimeLimit(10_000);
+ container.setSendBufferSizeLimit(12345);
+ container.setSendBufferOverflowStrategy(ConcurrentWebSocketSessionDecorator.OverflowStrategy.DROP);
+
+ TestWebSocketListener messageListener = new TestWebSocketListener();
+ container.setMessageListener(messageListener);
+ container.setConnectionTimeout(30);
+
+ container.start();
+
+ // We must wait at least until the session has been started before we can check the propagated options
+ assertThat(messageListener.sessionStartedLatch.await(10, TimeUnit.SECONDS)).isTrue();
+
+ assertThat(messageListener.optionsPropagatedToSession).isNotNull();
+ assertThat(messageListener.optionsPropagatedToSession.sendTimeLimit).isEqualTo(10_000);
+ assertThat(messageListener.optionsPropagatedToSession.sendBufferSizeLimit).isEqualTo(12345);
+ assertThat(messageListener.optionsPropagatedToSession.sendBufferOverflowStrategy)
+ .isEqualTo(ConcurrentWebSocketSessionDecorator.OverflowStrategy.DROP);
+ }
+
+ private record OptionsPropagatedToSession(
+ int sendTimeLimit,
+ int sendBufferSizeLimit,
+ ConcurrentWebSocketSessionDecorator.OverflowStrategy sendBufferOverflowStrategy
+ ) {
+ }
+
private static class TestWebSocketListener implements WebSocketListener {
public final CountDownLatch messageLatch = new CountDownLatch(1);
+ public final CountDownLatch sessionStartedLatch = new CountDownLatch(1);
+
public final CountDownLatch sessionEndedLatch = new CountDownLatch(1);
public WebSocketMessage> message;
public boolean started;
+ @Nullable
+ public OptionsPropagatedToSession optionsPropagatedToSession;
+
TestWebSocketListener() {
}
@@ -160,6 +207,15 @@ public void onMessage(WebSocketSession session, WebSocketMessage> message) {
@Override
public void afterSessionStarted(WebSocketSession session) {
this.started = true;
+
+ var sessionDecorator = (ConcurrentWebSocketSessionDecorator) session;
+ this.optionsPropagatedToSession = new OptionsPropagatedToSession(
+ sessionDecorator.getSendTimeLimit(),
+ sessionDecorator.getBufferSizeLimit(),
+ TestUtils.getPropertyValue(sessionDecorator, "overflowStrategy", ConcurrentWebSocketSessionDecorator.OverflowStrategy.class)
+ );
+
+ this.sessionStartedLatch.countDown();
}
@Override