diff --git a/src/main/java/com/rabbitmq/client/Channel.java b/src/main/java/com/rabbitmq/client/Channel.java index cdff101fd4..f6c1c240d4 100644 --- a/src/main/java/com/rabbitmq/client/Channel.java +++ b/src/main/java/com/rabbitmq/client/Channel.java @@ -193,42 +193,49 @@ public interface Channel extends ShutdownNotifier, AutoCloseable { /** * Request specific "quality of service" settings. - * + *

* These settings impose limits on the amount of data the server * will deliver to consumers before requiring acknowledgements. * Thus they provide a means of consumer-initiated flow control. - * @see com.rabbitmq.client.AMQP.Basic.Qos - * @param prefetchSize maximum amount of content (measured in - * octets) that the server will deliver, 0 if unlimited + *

+ * Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1). + * + * @param prefetchSize maximum amount of content (measured in + * octets) that the server will deliver, 0 if unlimited * @param prefetchCount maximum number of messages that the server - * will deliver, 0 if unlimited - * @param global true if the settings should be applied to the - * entire channel rather than each consumer + * will deliver, 0 if unlimited + * @param global true if the settings should be applied to the + * entire channel rather than each consumer * @throws java.io.IOException if an error is encountered + * @see com.rabbitmq.client.AMQP.Basic.Qos */ void basicQos(int prefetchSize, int prefetchCount, boolean global) throws IOException; /** * Request a specific prefetchCount "quality of service" settings * for this channel. + *

+ * Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1). * - * @see #basicQos(int, int, boolean) * @param prefetchCount maximum number of messages that the server - * will deliver, 0 if unlimited - * @param global true if the settings should be applied to the - * entire channel rather than each consumer + * will deliver, 0 if unlimited + * @param global true if the settings should be applied to the + * entire channel rather than each consumer * @throws java.io.IOException if an error is encountered + * @see #basicQos(int, int, boolean) */ void basicQos(int prefetchCount, boolean global) throws IOException; /** * Request a specific prefetchCount "quality of service" settings * for this channel. + *

+ * Note the prefetch count must be between 0 and 65535 (unsigned short in AMQP 0-9-1). * - * @see #basicQos(int, int, boolean) * @param prefetchCount maximum number of messages that the server - * will deliver, 0 if unlimited + * will deliver, 0 if unlimited * @throws java.io.IOException if an error is encountered + * @see #basicQos(int, int, boolean) */ void basicQos(int prefetchCount) throws IOException; diff --git a/src/main/java/com/rabbitmq/client/ConnectionFactory.java b/src/main/java/com/rabbitmq/client/ConnectionFactory.java index 9c1dfa3fe0..22d468432e 100644 --- a/src/main/java/com/rabbitmq/client/ConnectionFactory.java +++ b/src/main/java/com/rabbitmq/client/ConnectionFactory.java @@ -47,6 +47,8 @@ */ public class ConnectionFactory implements Cloneable { + private static final int MAX_UNSIGNED_SHORT = 65535; + /** Default user name */ public static final String DEFAULT_USER = "guest"; /** Default password */ @@ -384,10 +386,16 @@ public int getRequestedChannelMax() { } /** - * Set the requested maximum channel number + * Set the requested maximum channel number. + *

+ * Note the value must be between 0 and 65535 (unsigned short in AMQP 0-9-1). + * * @param requestedChannelMax initially requested maximum channel number; zero for unlimited */ public void setRequestedChannelMax(int requestedChannelMax) { + if (requestedChannelMax < 0 || requestedChannelMax > MAX_UNSIGNED_SHORT) { + throw new IllegalArgumentException("Requested channel max must be between 0 and " + MAX_UNSIGNED_SHORT); + } this.requestedChannelMax = requestedChannelMax; } @@ -477,10 +485,16 @@ public int getShutdownTimeout() { * Set the requested heartbeat timeout. Heartbeat frames will be sent at about 1/2 the timeout interval. * If server heartbeat timeout is configured to a non-zero value, this method can only be used * to lower the value; otherwise any value provided by the client will be used. + *

+ * Note the value must be between 0 and 65535 (unsigned short in AMQP 0-9-1). + * * @param requestedHeartbeat the initially requested heartbeat timeout, in seconds; zero for none * @see RabbitMQ Heartbeats Guide */ public void setRequestedHeartbeat(int requestedHeartbeat) { + if (requestedHeartbeat < 0 || requestedHeartbeat > MAX_UNSIGNED_SHORT) { + throw new IllegalArgumentException("Requested heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT); + } this.requestedHeartbeat = requestedHeartbeat; } diff --git a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java index 28c03a847d..7c09e6900c 100644 --- a/src/main/java/com/rabbitmq/client/impl/AMQConnection.java +++ b/src/main/java/com/rabbitmq/client/impl/AMQConnection.java @@ -15,13 +15,12 @@ package com.rabbitmq.client.impl; -import com.rabbitmq.client.*; import com.rabbitmq.client.Method; +import com.rabbitmq.client.*; import com.rabbitmq.client.impl.AMQChannel.BlockingRpcContinuation; import com.rabbitmq.client.impl.recovery.RecoveryCanBeginListener; import com.rabbitmq.utility.BlockingCell; import com.rabbitmq.utility.Utility; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,6 +46,8 @@ final class Copyright { */ public class AMQConnection extends ShutdownNotifierComponent implements Connection, NetworkConnection { + private static final int MAX_UNSIGNED_SHORT = 65535; + private static final Logger LOGGER = LoggerFactory.getLogger(AMQConnection.class); // we want socket write and channel shutdown timeouts to kick in after // the heartbeat one, so we use a value of 105% of the effective heartbeat timeout @@ -399,6 +400,11 @@ public void start() int channelMax = negotiateChannelMax(this.requestedChannelMax, connTune.getChannelMax()); + + if (!checkUnsignedShort(channelMax)) { + throw new IllegalArgumentException("Negotiated channel max must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + channelMax); + } + _channelManager = instantiateChannelManager(channelMax, threadFactory); int frameMax = @@ -410,6 +416,10 @@ public void start() negotiatedMaxValue(this.requestedHeartbeat, connTune.getHeartbeat()); + if (!checkUnsignedShort(heartbeat)) { + throw new IllegalArgumentException("Negotiated heartbeat must be between 0 and " + MAX_UNSIGNED_SHORT + ": " + heartbeat); + } + setHeartbeat(heartbeat); _channel0.transmit(new AMQP.Connection.TuneOk.Builder() @@ -626,6 +636,10 @@ private static int negotiatedMaxValue(int clientValue, int serverValue) { Math.min(clientValue, serverValue); } + private static boolean checkUnsignedShort(int value) { + return value >= 0 && value <= MAX_UNSIGNED_SHORT; + } + private class MainLoop implements Runnable { /** diff --git a/src/main/java/com/rabbitmq/client/impl/ChannelN.java b/src/main/java/com/rabbitmq/client/impl/ChannelN.java index a3f7f5f794..db4d9b86e3 100644 --- a/src/main/java/com/rabbitmq/client/impl/ChannelN.java +++ b/src/main/java/com/rabbitmq/client/impl/ChannelN.java @@ -15,30 +15,24 @@ package com.rabbitmq.client.impl; -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.SortedSet; -import java.util.TreeSet; -import java.util.concurrent.*; - -import com.rabbitmq.client.ConfirmCallback; import com.rabbitmq.client.*; -import com.rabbitmq.client.AMQP.BasicProperties; +import com.rabbitmq.client.Connection; import com.rabbitmq.client.Method; -import com.rabbitmq.client.impl.AMQImpl.Basic; +import com.rabbitmq.client.AMQP.BasicProperties; import com.rabbitmq.client.impl.AMQImpl.Channel; -import com.rabbitmq.client.impl.AMQImpl.Confirm; -import com.rabbitmq.client.impl.AMQImpl.Exchange; import com.rabbitmq.client.impl.AMQImpl.Queue; -import com.rabbitmq.client.impl.AMQImpl.Tx; +import com.rabbitmq.client.impl.AMQImpl.*; import com.rabbitmq.utility.Utility; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeoutException; + /** * Main interface to AMQP protocol functionality. Public API - * Implementation of all AMQChannels except channel zero. @@ -50,6 +44,7 @@ * */ public class ChannelN extends AMQChannel implements com.rabbitmq.client.Channel { + private static final int MAX_UNSIGNED_SHORT = 65535; private static final String UNSPECIFIED_OUT_OF_BAND = ""; private static final Logger LOGGER = LoggerFactory.getLogger(ChannelN.class); @@ -647,7 +642,10 @@ public AMQCommand transformReply(AMQCommand command) { public void basicQos(int prefetchSize, int prefetchCount, boolean global) throws IOException { - exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global)); + if (prefetchCount < 0 || prefetchCount > MAX_UNSIGNED_SHORT) { + throw new IllegalArgumentException("Prefetch count must be between 0 and " + MAX_UNSIGNED_SHORT); + } + exnWrappingRpc(new Basic.Qos(prefetchSize, prefetchCount, global)); } /** Public API - {@inheritDoc} */ diff --git a/src/test/java/com/rabbitmq/client/test/ChannelNTest.java b/src/test/java/com/rabbitmq/client/test/ChannelNTest.java index 34346366c8..80c7902be4 100644 --- a/src/test/java/com/rabbitmq/client/test/ChannelNTest.java +++ b/src/test/java/com/rabbitmq/client/test/ChannelNTest.java @@ -24,6 +24,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class ChannelNTest { @@ -57,4 +60,32 @@ public void callingBasicCancelForUnknownConsumerDoesNotThrowException() throws E channel.basicCancel("does-not-exist"); } + @Test + public void qosShouldBeUnsignedShort() { + AMQConnection connection = Mockito.mock(AMQConnection.class); + ChannelN channel = new ChannelN(connection, 1, consumerWorkService); + class TestConfig { + int value; + Consumer call; + + public TestConfig(int value, Consumer call) { + this.value = value; + this.call = call; + } + } + Consumer qos = value -> channel.basicQos(value); + Consumer qosGlobal = value -> channel.basicQos(value, true); + Consumer qosPrefetchSize = value -> channel.basicQos(10, value, true); + Stream.of( + new TestConfig(-1, qos), new TestConfig(65536, qos) + ).flatMap(config -> Stream.of(config, new TestConfig(config.value, qosGlobal), new TestConfig(config.value, qosPrefetchSize))) + .forEach(config -> assertThatThrownBy(() -> config.call.apply(config.value)).isInstanceOf(IllegalArgumentException.class)); + } + + interface Consumer { + + void apply(int value) throws Exception; + + } + } diff --git a/src/test/java/com/rabbitmq/client/test/ClientTests.java b/src/test/java/com/rabbitmq/client/test/ClientTests.java index 02a0aebadb..77e2a75f83 100644 --- a/src/test/java/com/rabbitmq/client/test/ClientTests.java +++ b/src/test/java/com/rabbitmq/client/test/ClientTests.java @@ -52,7 +52,6 @@ ConnectionFactoryTest.class, RecoveryAwareAMQConnectionFactoryTest.class, RpcTest.class, - SslContextFactoryTest.class, LambdaCallbackTest.class, ChannelAsyncCompletableFutureTest.class, RecoveryDelayHandlerTest.class, diff --git a/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java b/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java index f77cee0a7e..7a9dd3d320 100644 --- a/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java +++ b/src/test/java/com/rabbitmq/client/test/ConnectionFactoryTest.java @@ -15,19 +15,8 @@ package com.rabbitmq.client.test; -import com.rabbitmq.client.Address; -import com.rabbitmq.client.AddressResolver; -import com.rabbitmq.client.Connection; -import com.rabbitmq.client.ConnectionFactory; -import com.rabbitmq.client.DnsRecordIpAddressResolver; -import com.rabbitmq.client.ListAddressResolver; -import com.rabbitmq.client.MetricsCollector; -import com.rabbitmq.client.impl.AMQConnection; -import com.rabbitmq.client.impl.ConnectionParams; -import com.rabbitmq.client.impl.CredentialsProvider; -import com.rabbitmq.client.impl.FrameHandler; -import com.rabbitmq.client.impl.FrameHandlerFactory; -import org.junit.AfterClass; +import com.rabbitmq.client.*; +import com.rabbitmq.client.impl.*; import org.junit.Test; import java.io.IOException; @@ -37,17 +26,18 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; -import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.notNullValue; -import static org.junit.Assert.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.*; public class ConnectionFactoryTest { // see https://github.com/rabbitmq/rabbitmq-java-client/issues/262 - @Test public void tryNextAddressIfTimeoutExceptionNoAutoRecovery() throws IOException, TimeoutException { + @Test + public void tryNextAddressIfTimeoutExceptionNoAutoRecovery() throws IOException, TimeoutException { final AMQConnection connectionThatThrowsTimeout = mock(AMQConnection.class); final AMQConnection connectionThatSucceeds = mock(AMQConnection.class); final Queue connections = new ArrayBlockingQueue(10); @@ -69,22 +59,23 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() { doThrow(TimeoutException.class).when(connectionThatThrowsTimeout).start(); doNothing().when(connectionThatSucceeds).start(); Connection returnedConnection = connectionFactory.newConnection( - new Address[] { new Address("host1"), new Address("host2") } + new Address[]{new Address("host1"), new Address("host2")} ); - assertSame(connectionThatSucceeds, returnedConnection); + assertThat(returnedConnection).isSameAs(connectionThatSucceeds); } - + // see https://github.com/rabbitmq/rabbitmq-java-client/pull/350 - @Test public void customizeCredentialsProvider() throws Exception { + @Test + public void customizeCredentialsProvider() throws Exception { final CredentialsProvider provider = mock(CredentialsProvider.class); final AMQConnection connection = mock(AMQConnection.class); final AtomicBoolean createCalled = new AtomicBoolean(false); - + ConnectionFactory connectionFactory = new ConnectionFactory() { @Override protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler, - MetricsCollector metricsCollector) { - assertSame(provider, params.getCredentialsProvider()); + MetricsCollector metricsCollector) { + assertThat(provider).isSameAs(params.getCredentialsProvider()); createCalled.set(true); return connection; } @@ -96,22 +87,23 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() { }; connectionFactory.setCredentialsProvider(provider); connectionFactory.setAutomaticRecoveryEnabled(false); - + doNothing().when(connection).start(); - + Connection returnedConnection = connectionFactory.newConnection(); - assertSame(returnedConnection, connection); - assertTrue(createCalled.get()); + assertThat(returnedConnection).isSameAs(connection); + assertThat(createCalled).isTrue(); } - @Test public void shouldNotUseDnsResolutionWhenOneAddressAndNoTls() throws Exception { + @Test + public void shouldNotUseDnsResolutionWhenOneAddressAndNoTls() throws Exception { AMQConnection connection = mock(AMQConnection.class); AtomicReference addressResolver = new AtomicReference<>(); ConnectionFactory connectionFactory = new ConnectionFactory() { @Override protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler, - MetricsCollector metricsCollector) { + MetricsCollector metricsCollector) { return connection; } @@ -131,18 +123,18 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() { doNothing().when(connection).start(); connectionFactory.newConnection(); - - assertThat(addressResolver.get(), allOf(notNullValue(), instanceOf(ListAddressResolver.class))); + assertThat(addressResolver.get()).isNotNull().isInstanceOf(ListAddressResolver.class); } - @Test public void shouldNotUseDnsResolutionWhenOneAddressAndTls() throws Exception { + @Test + public void shouldNotUseDnsResolutionWhenOneAddressAndTls() throws Exception { AMQConnection connection = mock(AMQConnection.class); AtomicReference addressResolver = new AtomicReference<>(); ConnectionFactory connectionFactory = new ConnectionFactory() { @Override protected AMQConnection createConnection(ConnectionParams params, FrameHandler frameHandler, - MetricsCollector metricsCollector) { + MetricsCollector metricsCollector) { return connection; } @@ -164,7 +156,42 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() { connectionFactory.useSslProtocol(); connectionFactory.newConnection(); - assertThat(addressResolver.get(), allOf(notNullValue(), instanceOf(ListAddressResolver.class))); + assertThat(addressResolver.get()).isNotNull().isInstanceOf(ListAddressResolver.class); + } + + @Test + public void heartbeatAndChannelMaxMustBeUnsignedShorts() { + class TestConfig { + int value; + Consumer call; + boolean expectException; + + public TestConfig(int value, Consumer call, boolean expectException) { + this.value = value; + this.call = call; + this.expectException = expectException; + } + } + + ConnectionFactory cf = new ConnectionFactory(); + Consumer setHeartbeat = cf::setRequestedHeartbeat; + Consumer setChannelMax = cf::setRequestedChannelMax; + + Stream.of( + new TestConfig(0, setHeartbeat, false), + new TestConfig(10, setHeartbeat, false), + new TestConfig(65535, setHeartbeat, false), + new TestConfig(-1, setHeartbeat, true), + new TestConfig(65536, setHeartbeat, true)) + .flatMap(config -> Stream.of(config, new TestConfig(config.value, setChannelMax, config.expectException))) + .forEach(config -> { + if (config.expectException) { + assertThatThrownBy(() -> config.call.accept(config.value)).isInstanceOf(IllegalArgumentException.class); + } else { + config.call.accept(config.value); + } + }); + } } diff --git a/src/test/java/com/rabbitmq/client/test/ssl/SSLTests.java b/src/test/java/com/rabbitmq/client/test/ssl/SSLTests.java index 0dbb808584..1dddf62e38 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/SSLTests.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/SSLTests.java @@ -17,6 +17,7 @@ package com.rabbitmq.client.test.ssl; import com.rabbitmq.client.test.AbstractRMQTestSuite; +import com.rabbitmq.client.test.SslContextFactoryTest; import org.junit.runner.RunWith; import org.junit.runner.Runner; import org.junit.runners.Suite; @@ -34,7 +35,8 @@ ConnectionFactoryDefaultTlsVersion.class, NioTlsUnverifiedConnection.class, HostnameVerification.class, - TlsConnectionLogging.class + TlsConnectionLogging.class, + SslContextFactoryTest.class }) public class SSLTests {