diff --git a/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json b/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json new file mode 100644 index 000000000000..333e4f02b641 --- /dev/null +++ b/.changes/next-release/bugfix-NettyNIOAsyncHTTPClient-2a5d0da.json @@ -0,0 +1,5 @@ +{ + "category": "Netty NIO Async HTTP Client", + "type": "bugfix", + "description": "Fix the Netty async client to stop publishing to the request stream once `Content-Length` is reached." +} diff --git a/http-clients/netty-nio-client/pom.xml b/http-clients/netty-nio-client/pom.xml index f4c85b29fe29..79e9a3dd25db 100644 --- a/http-clients/netty-nio-client/pom.xml +++ b/http-clients/netty-nio-client/pom.xml @@ -120,13 +120,22 @@ assertj-core test - org.reactivestreams reactive-streams-tck ${reactive-streams.version} test + + org.slf4j + slf4j-log4j12 + test + + + log4j + log4j + test + @@ -145,4 +154,4 @@ - \ No newline at end of file + diff --git a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java index d2f70cd5d9bc..57073d10ee91 100644 --- a/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java +++ b/http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java @@ -34,6 +34,7 @@ import io.netty.util.concurrent.Future; import java.net.URI; import java.nio.ByteBuffer; +import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Supplier; @@ -277,16 +278,23 @@ public String toString() { /** * Decorator around {@link StreamedHttpRequest} to adapt a publisher of {@link ByteBuffer} (i.e. {@link * software.amazon.awssdk.http.async.SdkHttpRequestProvider}) to a publisher of {@link HttpContent}. + *

+ * This publisher also prevents the adapted publisher from publishing more content to the subscriber than + * the specified 'Content-Length' of the request. */ private static class StreamedRequest extends DelegateHttpRequest implements StreamedHttpRequest { - private final Publisher publisher; private final Channel channel; + private final Optional requestContentLength; + private long written = 0L; + private boolean done; + private Subscription subscription; StreamedRequest(HttpRequest request, Publisher publisher, Channel channel) { super(request); this.publisher = publisher; this.channel = channel; + this.requestContentLength = contentLength(request); } @Override @@ -294,27 +302,72 @@ public void subscribe(Subscriber subscriber) { publisher.subscribe(new Subscriber() { @Override public void onSubscribe(Subscription subscription) { + StreamedRequest.this.subscription = subscription; subscriber.onSubscribe(subscription); } @Override public void onNext(ByteBuffer byteBuffer) { + if (done) { + return; + } + + int newLimit = clampedBufferLimit(byteBuffer.remaining()); + byteBuffer.limit(newLimit); ByteBuf buffer = channel.alloc().buffer(byteBuffer.remaining()); buffer.writeBytes(byteBuffer); HttpContent content = new DefaultHttpContent(buffer); + subscriber.onNext(content); + written += newLimit; + + if (!shouldContinuePublishing()) { + done = true; + subscription.cancel(); + subscriber.onComplete(); + } } @Override public void onError(Throwable t) { - subscriber.onError(t); + if (!done) { + done = true; + subscriber.onError(t); + + } } @Override public void onComplete() { - subscriber.onComplete(); + if (!done) { + done = true; + subscriber.onComplete(); + } } }); } + + private int clampedBufferLimit(int bufLen) { + return requestContentLength.map(cl -> + (int) Math.min(cl - written, bufLen) + ).orElse(bufLen); + } + + private boolean shouldContinuePublishing() { + return requestContentLength.map(cl -> written < cl).orElse(true); + } + + private static Optional contentLength(HttpRequest request) { + String value = request.headers().get("Content-Length"); + if (value != null) { + try { + return Optional.of(Long.parseLong(value)); + } catch (NumberFormatException e) { + log.warn("Unable to parse 'Content-Length' header. Treating it as non existent."); + } + } + return Optional.empty(); + } + } } diff --git a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java index 47a0d558b763..62159f1a9052 100644 --- a/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java +++ b/http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java @@ -39,11 +39,16 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; +import com.github.tomakehurst.wiremock.http.trafficlistener.WiremockNetworkTrafficListener; import com.github.tomakehurst.wiremock.junit.WireMockRule; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; + +import java.io.IOException; +import java.net.Socket; import java.net.URI; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; @@ -51,11 +56,14 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.stream.Stream; import org.assertj.core.api.Condition; import org.junit.AfterClass; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -76,14 +84,23 @@ @RunWith(MockitoJUnitRunner.class) public class NettyNioAsyncHttpClientWireMockTest { + private final RecordingNetworkTrafficListener wiremockTrafficListener = new RecordingNetworkTrafficListener(); + @Rule - public WireMockRule mockServer = new WireMockRule(wireMockConfig().dynamicPort().dynamicHttpsPort()); + public WireMockRule mockServer = new WireMockRule(wireMockConfig() + .dynamicPort() + .dynamicHttpsPort() + .networkTrafficListener(wiremockTrafficListener)); @Mock private SdkRequestContext requestContext; - private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder() - .buildWithDefaults(mapWithTrustAllCerts()); + private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder().buildWithDefaults(mapWithTrustAllCerts()); + + @Before + public void methodSetup() { + wiremockTrafficListener.reset(); + } @AfterClass public static void tearDown() throws Exception { @@ -227,6 +244,30 @@ public void canSendContentAndGetThatContentBack() throws Exception { assertThat(recorder.fullResponseAsString()).isEqualTo(reverse(body)); } + @Test + public void requestContentOnlyEqualToContentLengthHeaderFromProvider() throws InterruptedException, ExecutionException, TimeoutException, IOException { + final String content = randomAlphabetic(32); + final String streamContent = content + reverse(content); + stubFor(any(urlEqualTo("/echo?reversed=true")) + .withRequestBody(equalTo(content)) + .willReturn(aResponse().withBody(reverse(content)))); + URI uri = URI.create("http://localhost:" + mockServer.port()); + + SdkHttpFullRequest request = createRequest(uri, "/echo", streamContent, SdkHttpMethod.POST, singletonMap("reversed", "true")); + request = request.toBuilder().header("Content-Length", Integer.toString(content.length())).build(); + RecordingResponseHandler recorder = new RecordingResponseHandler(); + + + client.prepareRequest(request, requestContext, createProvider(streamContent), recorder).run(); + + recorder.completeFuture.get(5, TimeUnit.SECONDS); + + // HTTP servers will stop processing the request as soon as it reads + // bytes equal to 'Content-Length' so we need to inspect the raw + // traffic to ensure that there wasn't anything after that. + assertThat(wiremockTrafficListener.requests.toString()).endsWith(content); + } + private void assertCanReceiveBasicRequest(URI uri, String body) throws Exception { stubFor(any(urlPathEqualTo("/")).willReturn(aResponse().withHeader("Some-Header", "With Value").withBody(body))); @@ -275,11 +316,11 @@ public void cancel() { }; } - private SdkHttpRequest createRequest(URI uri) { + private SdkHttpFullRequest createRequest(URI uri) { return createRequest(uri, "/", null, SdkHttpMethod.GET, emptyMap()); } - private SdkHttpRequest createRequest(URI uri, + private SdkHttpFullRequest createRequest(URI uri, String resourcePath, String body, SdkHttpMethod method, @@ -379,4 +420,33 @@ private static AttributeMap mapWithTrustAllCerts() { .put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, true) .build(); } + + private static class RecordingNetworkTrafficListener implements WiremockNetworkTrafficListener { + private final StringBuilder requests = new StringBuilder(); + + + @Override + public void opened(Socket socket) { + + } + + @Override + public void incoming(Socket socket, ByteBuffer byteBuffer) { + requests.append(StandardCharsets.UTF_8.decode(byteBuffer)); + } + + @Override + public void outgoing(Socket socket, ByteBuffer byteBuffer) { + + } + + @Override + public void closed(Socket socket) { + + } + + public void reset() { + requests.setLength(0); + } + } }