Skip to content

Commit fe827cb

Browse files
committed
Stop publishing after Content-Length bytes
Fixes #460
1 parent 899f873 commit fe827cb

File tree

4 files changed

+147
-10
lines changed

4 files changed

+147
-10
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"category": "Netty NIO Async HTTP Client",
3+
"type": "bugfix",
4+
"description": "Fix the Netty async client to stop publishing to the request stream once `Content-Length` is reached."
5+
}

http-clients/netty-nio-client/pom.xml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,22 @@
120120
<artifactId>assertj-core</artifactId>
121121
<scope>test</scope>
122122
</dependency>
123-
124123
<dependency>
125124
<groupId>org.reactivestreams</groupId>
126125
<artifactId>reactive-streams-tck</artifactId>
127126
<version>${reactive-streams.version}</version>
128127
<scope>test</scope>
129128
</dependency>
129+
<dependency>
130+
<groupId>org.slf4j</groupId>
131+
<artifactId>slf4j-log4j12</artifactId>
132+
<scope>test</scope>
133+
</dependency>
134+
<dependency>
135+
<groupId>log4j</groupId>
136+
<artifactId>log4j</artifactId>
137+
<scope>test</scope>
138+
</dependency>
130139
</dependencies>
131140

132141
<build>
@@ -145,4 +154,4 @@
145154
</plugin>
146155
</plugins>
147156
</build>
148-
</project>
157+
</project>

http-clients/netty-nio-client/src/main/java/software/amazon/awssdk/http/nio/netty/internal/RunnableRequest.java

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import io.netty.util.concurrent.Future;
3535
import java.net.URI;
3636
import java.nio.ByteBuffer;
37+
import java.util.Optional;
3738
import java.util.concurrent.TimeUnit;
3839
import java.util.concurrent.TimeoutException;
3940
import java.util.function.Supplier;
@@ -277,44 +278,96 @@ public String toString() {
277278
/**
278279
* Decorator around {@link StreamedHttpRequest} to adapt a publisher of {@link ByteBuffer} (i.e. {@link
279280
* software.amazon.awssdk.http.async.SdkHttpRequestProvider}) to a publisher of {@link HttpContent}.
281+
* <p />
282+
* This publisher also prevents the adapted publisher from publishing more content to the subscriber than
283+
* the specified 'Content-Length' of the request.
280284
*/
281285
private static class StreamedRequest extends DelegateHttpRequest implements StreamedHttpRequest {
282-
283286
private final Publisher<ByteBuffer> publisher;
284287
private final Channel channel;
288+
private final Optional<Long> requestContentLength;
289+
private long written = 0L;
290+
private boolean done;
291+
private Subscription subscription;
285292

286293
StreamedRequest(HttpRequest request, Publisher<ByteBuffer> publisher, Channel channel) {
287294
super(request);
288295
this.publisher = publisher;
289296
this.channel = channel;
297+
this.requestContentLength = contentLength(request);
290298
}
291299

292300
@Override
293301
public void subscribe(Subscriber<? super HttpContent> subscriber) {
294302
publisher.subscribe(new Subscriber<ByteBuffer>() {
295303
@Override
296304
public void onSubscribe(Subscription subscription) {
305+
StreamedRequest.this.subscription = subscription;
297306
subscriber.onSubscribe(subscription);
298307
}
299308

300309
@Override
301310
public void onNext(ByteBuffer byteBuffer) {
311+
if (done) {
312+
return;
313+
}
314+
315+
int newLimit = clampedBufferLimit(byteBuffer.remaining());
316+
byteBuffer.limit(newLimit);
302317
ByteBuf buffer = channel.alloc().buffer(byteBuffer.remaining());
303318
buffer.writeBytes(byteBuffer);
304319
HttpContent content = new DefaultHttpContent(buffer);
320+
305321
subscriber.onNext(content);
322+
written += newLimit;
323+
324+
if (!shouldContinuePublishing()) {
325+
done = true;
326+
subscription.cancel();
327+
subscriber.onComplete();
328+
}
306329
}
307330

308331
@Override
309332
public void onError(Throwable t) {
310-
subscriber.onError(t);
333+
if (!done) {
334+
done = true;
335+
subscriber.onError(t);
336+
337+
}
311338
}
312339

313340
@Override
314341
public void onComplete() {
315-
subscriber.onComplete();
342+
if (!done) {
343+
done = true;
344+
subscriber.onComplete();
345+
}
316346
}
317347
});
318348
}
349+
350+
private int clampedBufferLimit(int bufLen) {
351+
return requestContentLength.map(cl ->
352+
(int) Math.min(cl - written, bufLen)
353+
).orElse(bufLen);
354+
}
355+
356+
private boolean shouldContinuePublishing() {
357+
return requestContentLength.map(cl -> written < cl).orElse(true);
358+
}
359+
360+
private static Optional<Long> contentLength(HttpRequest request) {
361+
String value = request.headers().get("Content-Length");
362+
if (value != null) {
363+
try {
364+
return Optional.of(Long.parseLong(value));
365+
} catch (NumberFormatException e) {
366+
log.warn("Unable to parse 'Content-Length' header. Treating it as non existent.");
367+
}
368+
}
369+
return Optional.empty();
370+
}
371+
319372
}
320373
}

http-clients/netty-nio-client/src/test/java/software/amazon/awssdk/http/nio/netty/NettyNioAsyncHttpClientWireMockTest.java

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,31 @@
3939
import static org.mockito.Mockito.spy;
4040
import static org.mockito.Mockito.times;
4141

42+
import com.github.tomakehurst.wiremock.http.trafficlistener.WiremockNetworkTrafficListener;
4243
import com.github.tomakehurst.wiremock.junit.WireMockRule;
4344
import io.netty.channel.EventLoopGroup;
4445
import io.netty.channel.nio.NioEventLoopGroup;
46+
47+
import java.io.IOException;
48+
import java.net.Socket;
4549
import java.net.URI;
4650
import java.nio.ByteBuffer;
51+
import java.nio.charset.StandardCharsets;
4752
import java.time.Duration;
4853
import java.util.ArrayList;
4954
import java.util.Collection;
5055
import java.util.Collections;
5156
import java.util.List;
5257
import java.util.Map;
5358
import java.util.concurrent.CompletableFuture;
59+
import java.util.concurrent.ExecutionException;
5460
import java.util.concurrent.ThreadFactory;
5561
import java.util.concurrent.TimeUnit;
62+
import java.util.concurrent.TimeoutException;
5663
import java.util.stream.Stream;
5764
import org.assertj.core.api.Condition;
5865
import org.junit.AfterClass;
66+
import org.junit.Before;
5967
import org.junit.Rule;
6068
import org.junit.Test;
6169
import org.junit.runner.RunWith;
@@ -76,14 +84,23 @@
7684
@RunWith(MockitoJUnitRunner.class)
7785
public class NettyNioAsyncHttpClientWireMockTest {
7886

87+
private final RecordingNetworkTrafficListener wiremockTrafficListener = new RecordingNetworkTrafficListener();
88+
7989
@Rule
80-
public WireMockRule mockServer = new WireMockRule(wireMockConfig().dynamicPort().dynamicHttpsPort());
90+
public WireMockRule mockServer = new WireMockRule(wireMockConfig()
91+
.dynamicPort()
92+
.dynamicHttpsPort()
93+
.networkTrafficListener(wiremockTrafficListener));
8194

8295
@Mock
8396
private SdkRequestContext requestContext;
8497

85-
private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder()
86-
.buildWithDefaults(mapWithTrustAllCerts());
98+
private static SdkAsyncHttpClient client = NettyNioAsyncHttpClient.builder().buildWithDefaults(mapWithTrustAllCerts());
99+
100+
@Before
101+
public void methodSetup() {
102+
wiremockTrafficListener.reset();
103+
}
87104

88105
@AfterClass
89106
public static void tearDown() throws Exception {
@@ -227,6 +244,30 @@ public void canSendContentAndGetThatContentBack() throws Exception {
227244
assertThat(recorder.fullResponseAsString()).isEqualTo(reverse(body));
228245
}
229246

247+
@Test
248+
public void requestContentOnlyEqualToContentLengthHeaderFromProvider() throws InterruptedException, ExecutionException, TimeoutException, IOException {
249+
final String content = randomAlphabetic(32);
250+
final String streamContent = content + reverse(content);
251+
stubFor(any(urlEqualTo("/echo?reversed=true"))
252+
.withRequestBody(equalTo(content))
253+
.willReturn(aResponse().withBody(reverse(content))));
254+
URI uri = URI.create("http://localhost:" + mockServer.port());
255+
256+
SdkHttpFullRequest request = createRequest(uri, "/echo", streamContent, SdkHttpMethod.POST, singletonMap("reversed", "true"));
257+
request = request.toBuilder().header("Content-Length", Integer.toString(content.length())).build();
258+
RecordingResponseHandler recorder = new RecordingResponseHandler();
259+
260+
261+
client.prepareRequest(request, requestContext, createProvider(streamContent), recorder).run();
262+
263+
recorder.completeFuture.get(5, TimeUnit.SECONDS);
264+
265+
// HTTP servers will stop processing the request as soon as it reads
266+
// bytes equal to 'Content-Length' so we need to inspect the raw
267+
// traffic to ensure that there wasn't anything after that.
268+
assertThat(wiremockTrafficListener.requests.toString()).endsWith(content);
269+
}
270+
230271
private void assertCanReceiveBasicRequest(URI uri, String body) throws Exception {
231272
stubFor(any(urlPathEqualTo("/")).willReturn(aResponse().withHeader("Some-Header", "With Value").withBody(body)));
232273

@@ -275,11 +316,11 @@ public void cancel() {
275316
};
276317
}
277318

278-
private SdkHttpRequest createRequest(URI uri) {
319+
private SdkHttpFullRequest createRequest(URI uri) {
279320
return createRequest(uri, "/", null, SdkHttpMethod.GET, emptyMap());
280321
}
281322

282-
private SdkHttpRequest createRequest(URI uri,
323+
private SdkHttpFullRequest createRequest(URI uri,
283324
String resourcePath,
284325
String body,
285326
SdkHttpMethod method,
@@ -379,4 +420,33 @@ private static AttributeMap mapWithTrustAllCerts() {
379420
.put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, true)
380421
.build();
381422
}
423+
424+
private static class RecordingNetworkTrafficListener implements WiremockNetworkTrafficListener {
425+
private final StringBuilder requests = new StringBuilder();
426+
427+
428+
@Override
429+
public void opened(Socket socket) {
430+
431+
}
432+
433+
@Override
434+
public void incoming(Socket socket, ByteBuffer byteBuffer) {
435+
requests.append(StandardCharsets.UTF_8.decode(byteBuffer));
436+
}
437+
438+
@Override
439+
public void outgoing(Socket socket, ByteBuffer byteBuffer) {
440+
441+
}
442+
443+
@Override
444+
public void closed(Socket socket) {
445+
446+
}
447+
448+
public void reset() {
449+
requests.setLength(0);
450+
}
451+
}
382452
}

0 commit comments

Comments
 (0)