diff --git a/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json b/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json
new file mode 100644
index 000000000000..333e4f02b641
--- /dev/null
+++ b/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json
@@ -0,0 +1,5 @@
+{
+ "category": "Netty NIO Async HTTP Client",
+ "type": "bugfix",
+ "description": "Fix the Netty async client to stop publishing to the request stream once `Content-Length` is reached."
+}
diff --git a/http-clients/netty-nio-client/pom.xml b/http-clients/netty-nio-client/pom.xml
index f4c85b29fe29..79e9a3dd25db 100644
--- a/http-clients/netty-nio-client/pom.xml
+++ b/http-clients/netty-nio-client/pom.xml
@@ -120,13 +120,22 @@
assertj-core
test
-
org.reactivestreams
reactive-streams-tck
${reactive-streams.version}
test
+
+ org.slf4j
+ slf4j-log4j12
+ test
+
+
+ log4j
+ log4j
+ test
+
@@ -145,4 +154,4 @@
-
\ No newline at end of file
+
diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java
index d2f70cd5d9bc..57073d10ee91 100644
--- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java
+++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java
@@ -34,6 +34,7 @@
import io.netty.util.concurrent.Future;
import java.net.URI;
import java.nio.ByteBuffer;
+import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
@@ -277,16 +278,23 @@ public String toString() {
/**
* Decorator around {@link StreamedHttpRequest} to adapt a publisher of {@link ByteBuffer} (i.e. {@link
* software.amazon.awssdk.http.async.SdkHttpRequestProvider}) to a publisher of {@link HttpContent}.
+ *
+ * This publisher also prevents the adapted publisher from publishing more content to the subscriber than
+ * the specified 'Content-Length' of the request.
*/
private static class StreamedRequest extends DelegateHttpRequest implements StreamedHttpRequest {
-
private final Publisher publisher;
private final Channel channel;
+ private final Optional requestContentLength;
+ private long written = 0L;
+ private boolean done;
+ private Subscription subscription;
StreamedRequest(HttpRequest request, Publisher publisher, Channel channel) {
super(request);
this.publisher = publisher;
this.channel = channel;
+ this.requestContentLength = contentLength(request);
}
@Override
@@ -294,27 +302,72 @@ public void subscribe(Subscriber super HttpContent> subscriber) {
publisher.subscribe(new Subscriber() {
@Override
public void onSubscribe(Subscription subscription) {
+ StreamedRequest.this.subscription = subscription;
subscriber.onSubscribe(subscription);
}
@Override
public void onNext(ByteBuffer byteBuffer) {
+ if (done) {
+ return;
+ }
+
+ int newLimit = clampedBufferLimit(byteBuffer.remaining());
+ byteBuffer.limit(newLimit);
ByteBuf buffer = channel.alloc().buffer(byteBuffer.remaining());
buffer.writeBytes(byteBuffer);
HttpContent content = new DefaultHttpContent(buffer);
+
subscriber.onNext(content);
+ written += newLimit;
+
+ if (!shouldContinuePublishing()) {
+ done = true;
+ subscription.cancel();
+ subscriber.onComplete();
+ }
}
@Override
public void onError(Throwable t) {
- subscriber.onError(t);
+ if (!done) {
+ done = true;
+ subscriber.onError(t);
+
+ }
}
@Override
public void onComplete() {
- subscriber.onComplete();
+ if (!done) {
+ done = true;
+ subscriber.onComplete();
+ }
}
});
}
+
+ private int clampedBufferLimit(int bufLen) {
+ return requestContentLength.map(cl ->
+ (int) Math.min(cl - written, bufLen)
+ ).orElse(bufLen);
+ }
+
+ private boolean shouldContinuePublishing() {
+ return requestContentLength.map(cl -> written < cl).orElse(true);
+ }
+
+ private static Optional contentLength(HttpRequest request) {
+ String value = request.headers().get("Content-Length");
+ if (value != null) {
+ try {
+ return Optional.of(Long.parseLong(value));
+ } catch (NumberFormatException e) {
+ log.warn("Unable to parse 'Content-Length' header. Treating it as non existent.");
+ }
+ }
+ return Optional.empty();
+ }
+
}
}
diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java
index 47a0d558b763..62159f1a9052 100644
--- a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java
+++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java
@@ -39,11 +39,16 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
+import com.github.tomakehurst.wiremock.http.trafficlistener.WiremockNetworkTrafficListener;
import com.github.tomakehurst.wiremock.junit.WireMockRule;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
+
+import java.io.IOException;
+import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
@@ -51,11 +56,14 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import java.util.stream.Stream;
import org.assertj.core.api.Condition;
import org.junit.AfterClass;
+import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -76,14 +84,23 @@
@RunWith(MockitoJUnitRunner.class)
public class NettyNioAsyncHttpClientWireMockTest {
+ private final RecordingNetworkTrafficListener wiremockTrafficListener = new RecordingNetworkTrafficListener();
+
@Rule
- public WireMockRule mockServer = new WireMockRule(wireMockConfig().dynamicPort().dynamicHttpsPort());
+ public WireMockRule mockServer = new WireMockRule(wireMockConfig()
+ .dynamicPort()
+ .dynamicHttpsPort()
+ .networkTrafficListener(wiremockTrafficListener));
@Mock
private SdkRequestContext requestContext;
- private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder()
- .buildWithDefaults(mapWithTrustAllCerts());
+ private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder().buildWithDefaults(mapWithTrustAllCerts());
+
+ @Before
+ public void methodSetup() {
+ wiremockTrafficListener.reset();
+ }
@AfterClass
public static void tearDown() throws Exception {
@@ -227,6 +244,30 @@ public void canSendContentAndGetThatContentBack() throws Exception {
assertThat(recorder.fullResponseAsString()).isEqualTo(reverse(body));
}
+ @Test
+ public void requestContentOnlyEqualToContentLengthHeaderFromProvider() throws InterruptedException, ExecutionException, TimeoutException, IOException {
+ final String content = randomAlphabetic(32);
+ final String streamContent = content + reverse(content);
+ stubFor(any(urlEqualTo("/echo?reversed=true"))
+ .withRequestBody(equalTo(content))
+ .willReturn(aResponse().withBody(reverse(content))));
+ URI uri = URI.create("http://localhost:" + mockServer.port());
+
+ SdkHttpFullRequest request = createRequest(uri, "/echo", streamContent, SdkHttpMethod.POST, singletonMap("reversed", "true"));
+ request = request.toBuilder().header("Content-Length", Integer.toString(content.length())).build();
+ RecordingResponseHandler recorder = new RecordingResponseHandler();
+
+
+ client.prepareRequest(request, requestContext, createProvider(streamContent), recorder).run();
+
+ recorder.completeFuture.get(5, TimeUnit.SECONDS);
+
+ // HTTP servers will stop processing the request as soon as it reads
+ // bytes equal to 'Content-Length' so we need to inspect the raw
+ // traffic to ensure that there wasn't anything after that.
+ assertThat(wiremockTrafficListener.requests.toString()).endsWith(content);
+ }
+
private void assertCanReceiveBasicRequest(URI uri, String body) throws Exception {
stubFor(any(urlPathEqualTo("/")).willReturn(aResponse().withHeader("Some-Header", "With Value").withBody(body)));
@@ -275,11 +316,11 @@ public void cancel() {
};
}
- private SdkHttpRequest createRequest(URI uri) {
+ private SdkHttpFullRequest createRequest(URI uri) {
return createRequest(uri, "/", null, SdkHttpMethod.GET, emptyMap());
}
- private SdkHttpRequest createRequest(URI uri,
+ private SdkHttpFullRequest createRequest(URI uri,
String resourcePath,
String body,
SdkHttpMethod method,
@@ -379,4 +420,33 @@ private static AttributeMap mapWithTrustAllCerts() {
.put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, true)
.build();
}
+
+ private static class RecordingNetworkTrafficListener implements WiremockNetworkTrafficListener {
+ private final StringBuilder requests = new StringBuilder();
+
+
+ @Override
+ public void opened(Socket socket) {
+
+ }
+
+ @Override
+ public void incoming(Socket socket, ByteBuffer byteBuffer) {
+ requests.append(StandardCharsets.UTF_8.decode(byteBuffer));
+ }
+
+ @Override
+ public void outgoing(Socket socket, ByteBuffer byteBuffer) {
+
+ }
+
+ @Override
+ public void closed(Socket socket) {
+
+ }
+
+ public void reset() {
+ requests.setLength(0);
+ }
+ }
}