Skip to content

Commit 170bc73

Browse files
authored
Honor content-length set on request body (#3123)
This commit ensures that when a RequestBody is created with an InputStream, the content-length also set on the RequestBody is honored, regardless of how much extra data is available in the stream. Note that if less data is actually available in the wrapped stream, this implmenetation will return EOF as well. Fixes #2908
1 parent 8866b9e commit 170bc73

File tree

6 files changed

+462
-4
lines changed

6 files changed

+462
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"category": "AWS SDK for Java v2",
3+
"contributor": "",
4+
"type": "bugfix",
5+
"description": "Fix issue where the `contentLength` specified on the `RequestBody` is not honored. Fixes [#2908](https://github.com/aws/aws-sdk-java-v2/issues/2908)."
6+
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/handler/BaseClientHandler.java

+16-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import software.amazon.awssdk.core.interceptor.InterceptorContext;
3535
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
3636
import software.amazon.awssdk.core.internal.InternalCoreExecutionAttribute;
37+
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
3738
import software.amazon.awssdk.core.internal.util.MetricUtils;
3839
import software.amazon.awssdk.core.metrics.CoreMetric;
3940
import software.amazon.awssdk.core.sync.RequestBody;
@@ -119,11 +120,22 @@ private static void addHttpRequest(ExecutionContext executionContext, SdkHttpFul
119120
}
120121

121122
private static RequestBody getBody(SdkHttpFullRequest request) {
122-
Optional<ContentStreamProvider> contentStreamProvider = request.contentStreamProvider();
123-
if (contentStreamProvider.isPresent()) {
124-
long contentLength = Long.parseLong(request.firstMatchingHeader("Content-Length").orElse("0"));
123+
Optional<ContentStreamProvider> contentStreamProviderOptional = request.contentStreamProvider();
124+
if (contentStreamProviderOptional.isPresent()) {
125+
Optional<String> contentLengthOptional = request.firstMatchingHeader("Content-Length");
126+
long contentLength = Long.parseLong(contentLengthOptional.orElse("0"));
125127
String contentType = request.firstMatchingHeader("Content-Type").orElse("");
126-
return RequestBody.fromContentProvider(contentStreamProvider.get(), contentLength, contentType);
128+
129+
// Enforce the content length specified only if it was present on the request (and not the default).
130+
ContentStreamProvider streamProvider = contentStreamProviderOptional.get();
131+
if (contentLengthOptional.isPresent()) {
132+
ContentStreamProvider toWrap = contentStreamProviderOptional.get();
133+
streamProvider = () -> new SdkLengthAwareInputStream(toWrap.newStream(), contentLength);
134+
}
135+
136+
return RequestBody.fromContentProvider(streamProvider,
137+
contentLength,
138+
contentType);
127139
}
128140

129141
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.core.internal.io;
17+
18+
import static software.amazon.awssdk.utils.NumericUtils.saturatedCast;
19+
20+
import java.io.FilterInputStream;
21+
import java.io.IOException;
22+
import java.io.InputStream;
23+
import software.amazon.awssdk.annotations.SdkInternalApi;
24+
import software.amazon.awssdk.utils.Logger;
25+
import software.amazon.awssdk.utils.Validate;
26+
27+
/**
28+
* An {@code InputStream} that is aware of its length. The main purpose of this class is to support truncating streams to a
29+
* length that is shorter than the total length of the stream.
30+
*/
31+
@SdkInternalApi
32+
public class SdkLengthAwareInputStream extends FilterInputStream {
33+
private static final Logger LOG = Logger.loggerFor(SdkLengthAwareInputStream.class);
34+
private long length;
35+
private long remaining;
36+
37+
public SdkLengthAwareInputStream(InputStream in, long length) {
38+
super(in);
39+
this.length = Validate.isNotNegative(length, "length");
40+
this.remaining = this.length;
41+
}
42+
43+
@Override
44+
public int read() throws IOException {
45+
if (!hasMoreBytes()) {
46+
LOG.debug(() -> String.format("Specified InputStream length of %d has been reached. Returning EOF.", length));
47+
return -1;
48+
}
49+
50+
int read = super.read();
51+
if (read != -1) {
52+
remaining--;
53+
}
54+
return read;
55+
}
56+
57+
@Override
58+
public int read(byte[] b, int off, int len) throws IOException {
59+
if (!hasMoreBytes()) {
60+
LOG.debug(() -> String.format("Specified InputStream length of %d has been reached. Returning EOF.", length));
61+
return -1;
62+
}
63+
64+
len = Math.min(len, saturatedCast(remaining));
65+
int read = super.read(b, off, len);
66+
if (read > 0) {
67+
remaining -= read;
68+
}
69+
70+
return read;
71+
}
72+
73+
@Override
74+
public long skip(long requestedBytesToSkip) throws IOException {
75+
requestedBytesToSkip = Math.min(requestedBytesToSkip, remaining);
76+
long skippedActual = super.skip(requestedBytesToSkip);
77+
remaining -= skippedActual;
78+
return skippedActual;
79+
}
80+
81+
@Override
82+
public int available() throws IOException {
83+
int streamAvailable = super.available();
84+
return Math.min(streamAvailable, saturatedCast(remaining));
85+
}
86+
87+
@Override
88+
public void mark(int readlimit) {
89+
super.mark(readlimit);
90+
// mark() causes reset() to change the stream's position back to the current position. Therefore, when reset() is called,
91+
// the new length of the stream will be equal to the current value of 'remaining'.
92+
length = remaining;
93+
}
94+
95+
@Override
96+
public void reset() throws IOException {
97+
super.reset();
98+
remaining = length;
99+
}
100+
101+
public long remaining() {
102+
return remaining;
103+
}
104+
105+
private boolean hasMoreBytes() {
106+
return remaining > 0;
107+
}
108+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.core.internal.io;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.mockito.ArgumentMatchers.any;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.verify;
22+
import static org.mockito.Mockito.when;
23+
24+
import java.io.ByteArrayInputStream;
25+
import java.io.IOException;
26+
import java.io.InputStream;
27+
import org.junit.jupiter.api.BeforeEach;
28+
import org.junit.jupiter.api.Test;
29+
30+
class SdkLengthAwareInputStreamTest {
31+
private InputStream delegateStream;
32+
33+
@BeforeEach
34+
void setup() {
35+
delegateStream = mock(InputStream.class);
36+
}
37+
38+
@Test
39+
void read_lengthIs0_returnsEof() throws IOException {
40+
when(delegateStream.available()).thenReturn(Integer.MAX_VALUE);
41+
42+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 0);
43+
44+
assertThat(is.read()).isEqualTo(-1);
45+
assertThat(is.read(new byte[16], 0, 16)).isEqualTo(-1);
46+
}
47+
48+
@Test
49+
void read_lengthNonZero_delegateEof_returnsEof() throws IOException {
50+
when(delegateStream.read()).thenReturn(-1);
51+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenReturn(-1);
52+
53+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 0);
54+
55+
assertThat(is.read()).isEqualTo(-1);
56+
assertThat(is.read(new byte[16], 0, 16)).isEqualTo(-1);
57+
}
58+
59+
@Test
60+
void readByte_lengthNonZero_delegateHasAvailable_returnsDelegateData() throws IOException {
61+
when(delegateStream.read()).thenReturn(42);
62+
63+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
64+
65+
assertThat(is.read()).isEqualTo(42);
66+
}
67+
68+
@Test
69+
void readArray_lengthNonZero_delegateHasAvailable_returnsDelegateData() throws IOException {
70+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenReturn(8);
71+
72+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
73+
74+
assertThat(is.read(new byte[16], 0, 16)).isEqualTo(8);
75+
}
76+
77+
@Test
78+
void readArray_lengthNonZero_propagatesCallToDelegate() throws IOException {
79+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenReturn(8);
80+
81+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
82+
byte[] buff = new byte[16];
83+
is.read(buff, 0, 16);
84+
85+
verify(delegateStream).read(buff, 0, 16);
86+
}
87+
88+
@Test
89+
void read_markAndReset_availableReflectsNewLength() throws IOException {
90+
delegateStream = new ByteArrayInputStream(new byte[32]);
91+
92+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
93+
94+
for (int i = 0; i < 4; ++i) {
95+
is.read();
96+
}
97+
assertThat(is.available()).isEqualTo(12);
98+
99+
is.mark(16);
100+
101+
for (int i = 0; i < 4; ++i) {
102+
is.read();
103+
}
104+
assertThat(is.available()).isEqualTo(8);
105+
106+
is.reset();
107+
108+
assertThat(is.available()).isEqualTo(12);
109+
}
110+
111+
@Test
112+
void skip_markAndReset_availableReflectsNewLength() throws IOException {
113+
delegateStream = new ByteArrayInputStream(new byte[32]);
114+
115+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
116+
117+
is.skip(4);
118+
119+
assertThat(is.remaining()).isEqualTo(12);
120+
121+
is.mark(16);
122+
123+
for (int i = 0; i < 4; ++i) {
124+
is.read();
125+
}
126+
127+
assertThat(is.remaining()).isEqualTo(8);
128+
129+
is.reset();
130+
131+
assertThat(is.remaining()).isEqualTo(12);
132+
}
133+
134+
@Test
135+
void skip_delegateSkipsLessThanRequested_availableUpdatedCorrectly() throws IOException {
136+
when(delegateStream.skip(any(long.class))).thenAnswer(i -> {
137+
Long n = i.getArgument(0, Long.class);
138+
return n / 2;
139+
});
140+
141+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenReturn(1);
142+
143+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
144+
145+
long skipped = is.skip(4);
146+
147+
assertThat(skipped).isEqualTo(2);
148+
assertThat(is.remaining()).isEqualTo(14);
149+
}
150+
151+
@Test
152+
void readArray_delegateReadsLessThanRequested_availableUpdatedCorrectly() throws IOException {
153+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenAnswer(i -> {
154+
Integer n = i.getArgument(2, Integer.class);
155+
return n / 2;
156+
});
157+
158+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
159+
160+
long read = is.read(new byte[16], 0, 8);
161+
162+
assertThat(read).isEqualTo(4);
163+
assertThat(is.remaining()).isEqualTo(12);
164+
}
165+
166+
@Test
167+
void read_delegateAtEof_returnsEof() throws IOException {
168+
when(delegateStream.read()).thenReturn(-1);
169+
170+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
171+
172+
assertThat(is.read()).isEqualTo(-1);
173+
}
174+
175+
@Test
176+
void readArray_delegateAtEof_returnsEof() throws IOException {
177+
when(delegateStream.read(any(byte[].class), any(int.class), any(int.class))).thenReturn(-1);
178+
179+
SdkLengthAwareInputStream is = new SdkLengthAwareInputStream(delegateStream, 16);
180+
181+
assertThat(is.read(new byte[8], 0, 8)).isEqualTo(-1);
182+
}
183+
}

test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/SyncHttpChecksumInTrailerTest.java

+31
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
import com.github.tomakehurst.wiremock.junit.WireMockRule;
3434
import com.github.tomakehurst.wiremock.stubbing.Scenario;
3535
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
36+
import java.io.ByteArrayInputStream;
37+
import java.io.ByteArrayOutputStream;
3638
import java.net.URI;
39+
import java.nio.ByteBuffer;
40+
import java.nio.charset.StandardCharsets;
41+
import java.util.Base64;
3742
import java.util.List;
3843
import org.junit.Before;
3944
import org.junit.Rule;
@@ -42,8 +47,11 @@
4247
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
4348
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
4449
import software.amazon.awssdk.core.HttpChecksumConstant;
50+
import software.amazon.awssdk.core.checksums.Algorithm;
51+
import software.amazon.awssdk.core.checksums.SdkChecksum;
4552
import software.amazon.awssdk.core.sync.RequestBody;
4653
import software.amazon.awssdk.core.sync.ResponseTransformer;
54+
import software.amazon.awssdk.http.ContentStreamProvider;
4755
import software.amazon.awssdk.regions.Region;
4856
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
4957
import software.amazon.awssdk.services.protocolrestjson.model.ChecksumAlgorithm;
@@ -86,6 +94,29 @@ public void sync_streaming_NoSigner_appends_trailer_checksum() {
8694
+ "x-amz-checksum-crc32:i9aeUg==" + CRLF + CRLF)));
8795
}
8896

97+
@Test
98+
public void sync_streaming_specifiedLengthIsLess_NoSigner_appends_trailer_checksum() {
99+
stubResponseWithHeaders();
100+
101+
ContentStreamProvider provider = () -> new ByteArrayInputStream("Hello world".getBytes(StandardCharsets.UTF_8));
102+
// length of 5 truncates to "Hello"
103+
RequestBody requestBody = RequestBody.fromContentProvider(provider, 5, "text/plain");
104+
client.putOperationWithChecksum(r -> r.checksumAlgorithm(ChecksumAlgorithm.CRC32),
105+
requestBody,
106+
ResponseTransformer.toBytes());
107+
verify(putRequestedFor(anyUrl()).withHeader(CONTENT_LENGTH, equalTo("46")));
108+
verify(putRequestedFor(anyUrl()).withHeader(HttpChecksumConstant.HEADER_FOR_TRAILER_REFERENCE, equalTo("x-amz-checksum-crc32")));
109+
verify(putRequestedFor(anyUrl()).withHeader("x-amz-content-sha256", equalTo("STREAMING-UNSIGNED-PAYLOAD-TRAILER")));
110+
verify(putRequestedFor(anyUrl()).withHeader("x-amz-decoded-content-length", equalTo("5")));
111+
verify(putRequestedFor(anyUrl()).withHeader("Content-Encoding", equalTo("aws-chunked")));
112+
verify(putRequestedFor(anyUrl()).withRequestBody(
113+
containing(
114+
"5" + CRLF + "Hello" + CRLF
115+
+ "0" + CRLF
116+
// 99GJgg== is the base64 encoded CRC32 of "Hello"
117+
+ "x-amz-checksum-crc32:99GJgg==" + CRLF + CRLF)));
118+
}
119+
89120
@Test
90121
public void syncStreaming_withRetry_NoSigner_shouldContainChecksum_fromInterceptors() {
91122
stubForFailureThenSuccess(500, "500");

0 commit comments

Comments
 (0)