diff --git a/projects/RabbitMQ.Client/client/api/AmqpTcpEndpoint.cs b/projects/RabbitMQ.Client/client/api/AmqpTcpEndpoint.cs index 386ee3ba79..0cb31c3f7f 100644 --- a/projects/RabbitMQ.Client/client/api/AmqpTcpEndpoint.cs +++ b/projects/RabbitMQ.Client/client/api/AmqpTcpEndpoint.cs @@ -175,6 +175,11 @@ public IProtocol Protocol /// public SslOption Ssl { get; set; } + /// + /// Set the maximum size for a message in bytes. The default value is 0 (unlimited) + /// + public uint MaxMessageSize { get; set; } + /// /// Construct an instance from a protocol and an address in "hostname:port" format. /// diff --git a/projects/RabbitMQ.Client/client/impl/Frame.cs b/projects/RabbitMQ.Client/client/impl/Frame.cs index 3311a3ed4b..e83c39418e 100644 --- a/projects/RabbitMQ.Client/client/impl/Frame.cs +++ b/projects/RabbitMQ.Client/client/impl/Frame.cs @@ -205,7 +205,7 @@ private static void ProcessProtocolHeader(Stream reader) } } - internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer, ArrayPool pool) + internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer, ArrayPool pool, uint maxMessageSize) { int type = default; try @@ -239,7 +239,11 @@ internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer, A reader.Read(frameHeaderBuffer, 0, frameHeaderBuffer.Length); int channel = NetworkOrderDeserializer.ReadUInt16(new ReadOnlySpan(frameHeaderBuffer)); - int payloadSize = NetworkOrderDeserializer.ReadInt32(new ReadOnlySpan(frameHeaderBuffer, 2, 4)); // FIXME - throw exn on unreasonable value + int payloadSize = NetworkOrderDeserializer.ReadInt32(new ReadOnlySpan(frameHeaderBuffer, 2, 4)); + if ((maxMessageSize > 0) && (payloadSize > maxMessageSize)) + { + throw new MalformedFrameException($"Frame payload size '{payloadSize}' exceeds maximum of '{maxMessageSize}' bytes"); + } const int EndMarkerLength = 1; // Is returned by InboundFrame.Dispose in Connection.MainLoopIteration diff --git a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs index b9e965421a..141e02aee0 100644 --- a/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs +++ b/projects/RabbitMQ.Client/client/impl/SocketFrameHandler.cs @@ -62,6 +62,7 @@ public static async Task TimeoutAfter(this Task task, TimeSpan timeout) class SocketFrameHandler : IFrameHandler { + private readonly AmqpTcpEndpoint _endpoint; // Socket poll timeout in ms. If the socket does not // become writeable in this amount of time, we throw // an exception. @@ -82,7 +83,7 @@ public SocketFrameHandler(AmqpTcpEndpoint endpoint, Func socketFactory, TimeSpan connectionTimeout, TimeSpan readTimeout, TimeSpan writeTimeout) { - Endpoint = endpoint; + _endpoint = endpoint; _frameHeaderBuffer = new byte[6]; var channel = Channel.CreateUnbounded>( new UnboundedChannelOptions @@ -135,7 +136,11 @@ public SocketFrameHandler(AmqpTcpEndpoint endpoint, WriteTimeout = writeTimeout; _writerTask = Task.Run(WriteLoop, CancellationToken.None); } - public AmqpTcpEndpoint Endpoint { get; set; } + + public AmqpTcpEndpoint Endpoint + { + get { return _endpoint; } + } internal ArrayPool MemoryPool { @@ -229,7 +234,7 @@ public void Close() public InboundFrame ReadFrame() { - return InboundFrame.ReadFrom(_reader, _frameHeaderBuffer, MemoryPool); + return InboundFrame.ReadFrom(_reader, _frameHeaderBuffer, MemoryPool, _endpoint.MaxMessageSize); } public void SendHeader() diff --git a/projects/Unit/APIApproval.Approve.verified.txt b/projects/Unit/APIApproval.Approve.verified.txt index 97ef85bc11..c13b7a270f 100644 --- a/projects/Unit/APIApproval.Approve.verified.txt +++ b/projects/Unit/APIApproval.Approve.verified.txt @@ -13,6 +13,7 @@ namespace RabbitMQ.Client public AmqpTcpEndpoint(string hostName, int portOrMinusOne, RabbitMQ.Client.SslOption ssl) { } public System.Net.Sockets.AddressFamily AddressFamily { get; set; } public string HostName { get; set; } + public uint MaxMessageSize { get; set; } public int Port { get; set; } public RabbitMQ.Client.IProtocol Protocol { get; } public RabbitMQ.Client.SslOption Ssl { get; set; }