Skip to content

Commit 5f2af77

Browse files
dagnirzoewangg
authored andcommitted
Validate sent body len, remove buffering
This commit removes the buffering of `RequestBody.fromContentProvider(ContentStreamProvider,long,String)`, which was used to ensure that the provider always provided the expected number of bytes given by `length`. Instead of doing the buffering, this commit updates `SdkLengthAwareInputstream` to validate the number of bytes read is correct, and throws an error if EOF is reached prematurely. We use this stream to ensure that we sent the correct amount of bytes to the server when sending the HTTP request. Note that this a weaker check because we're only testing for the *number* of bytes sent rather than their contents but has the benefit of sidestepping buffering.
1 parent c121fca commit 5f2af77

File tree

6 files changed

+311
-60
lines changed

6 files changed

+311
-60
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStage.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,26 @@
1616
package software.amazon.awssdk.core.internal.http.pipeline.stages;
1717

1818
import java.time.Duration;
19+
import java.util.Optional;
1920
import software.amazon.awssdk.annotations.SdkInternalApi;
2021
import software.amazon.awssdk.core.client.config.SdkClientOption;
2122
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
2223
import software.amazon.awssdk.core.internal.http.HttpClientDependencies;
2324
import software.amazon.awssdk.core.internal.http.InterruptMonitor;
2425
import software.amazon.awssdk.core.internal.http.RequestExecutionContext;
2526
import software.amazon.awssdk.core.internal.http.pipeline.RequestPipeline;
27+
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
2628
import software.amazon.awssdk.core.internal.util.MetricUtils;
2729
import software.amazon.awssdk.core.metrics.CoreMetric;
30+
import software.amazon.awssdk.http.ContentStreamProvider;
2831
import software.amazon.awssdk.http.ExecutableHttpRequest;
2932
import software.amazon.awssdk.http.HttpExecuteRequest;
3033
import software.amazon.awssdk.http.HttpExecuteResponse;
3134
import software.amazon.awssdk.http.SdkHttpClient;
3235
import software.amazon.awssdk.http.SdkHttpFullRequest;
3336
import software.amazon.awssdk.http.SdkHttpFullResponse;
3437
import software.amazon.awssdk.metrics.MetricCollector;
38+
import software.amazon.awssdk.utils.Logger;
3539
import software.amazon.awssdk.utils.Pair;
3640

3741
/**
@@ -40,6 +44,7 @@
4044
@SdkInternalApi
4145
public class MakeHttpRequestStage
4246
implements RequestPipeline<SdkHttpFullRequest, Pair<SdkHttpFullRequest, SdkHttpFullResponse>> {
47+
private static final Logger LOG = Logger.loggerFor(MakeHttpRequestStage.class);
4348

4449
private final SdkHttpClient sdkHttpClient;
4550

@@ -65,6 +70,8 @@ private HttpExecuteResponse executeHttpRequest(SdkHttpFullRequest request, Reque
6570

6671
MetricCollector httpMetricCollector = MetricUtils.createHttpMetricsCollector(context);
6772

73+
request = enforceContentLengthIfPresent(request);
74+
6875
ExecutableHttpRequest requestCallable = sdkHttpClient
6976
.prepareRequest(HttpExecuteRequest.builder()
7077
.request(request)
@@ -94,4 +101,39 @@ private static long updateMetricCollectionAttributes(RequestExecutionContext con
94101
now);
95102
return now;
96103
}
104+
105+
private static SdkHttpFullRequest enforceContentLengthIfPresent(SdkHttpFullRequest request) {
106+
Optional<ContentStreamProvider> requestContentStreamProviderOptional = request.contentStreamProvider();
107+
108+
if (!requestContentStreamProviderOptional.isPresent()) {
109+
return request;
110+
}
111+
112+
Optional<Long> contentLength = contentLength(request);
113+
if (!contentLength.isPresent()) {
114+
LOG.warn(() -> String.format("Request contains a body but does not have a Content-Length header. Not validating "
115+
+ "the amount of data sent to the service: %s", request));
116+
return request;
117+
}
118+
119+
ContentStreamProvider requestContentProvider = requestContentStreamProviderOptional.get();
120+
ContentStreamProvider lengthVerifyingProvider = () -> new SdkLengthAwareInputStream(requestContentProvider.newStream(),
121+
contentLength.get());
122+
return request.toBuilder()
123+
.contentStreamProvider(lengthVerifyingProvider)
124+
.build();
125+
}
126+
127+
private static Optional<Long> contentLength(SdkHttpFullRequest request) {
128+
Optional<String> contentLengthHeader = request.firstMatchingHeader("Content-Length");
129+
130+
if (contentLengthHeader.isPresent()) {
131+
try {
132+
return Optional.of(Long.parseLong(contentLengthHeader.get()));
133+
} catch (NumberFormatException e) {
134+
LOG.warn(() -> "Unable to parse 'Content-Length' header. Treating it as non existent.");
135+
}
136+
}
137+
return Optional.empty();
138+
}
97139
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/SdkLengthAwareInputStream.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
import software.amazon.awssdk.utils.Validate;
2626

2727
/**
28-
* An {@code InputStream} that is aware of its length. The main purpose of this class is to support truncating streams to a
29-
* length that is shorter than the total length of the stream.
28+
* An {@code InputStream} that is aware of its length. This class enforces that we sent exactly the number of bytes equal to
29+
* the input length. If the wrapped stream has more bytes than the expected length, it will be truncated to length. If the stream
30+
* has less bytes (i.e. reaches EOF) before the expected length is reached, it will throw {@code IOException}.
3031
*/
3132
@SdkInternalApi
3233
public class SdkLengthAwareInputStream extends FilterInputStream {
@@ -48,8 +49,11 @@ public int read() throws IOException {
4849
}
4950

5051
int read = super.read();
52+
5153
if (read != -1) {
5254
remaining--;
55+
} else if (remaining != 0) { // EOF, ensure we've read the number of expected bytes
56+
throw new IOException("Reached EOF before reading entire expected content");
5357
}
5458
return read;
5559
}
@@ -61,12 +65,18 @@ public int read(byte[] b, int off, int len) throws IOException {
6165
return -1;
6266
}
6367

64-
len = Math.min(len, saturatedCast(remaining));
65-
int read = super.read(b, off, len);
66-
if (read > 0) {
68+
int readLen = Math.min(len, saturatedCast(remaining));
69+
70+
int read = super.read(b, off, readLen);
71+
if (read != -1) {
6772
remaining -= read;
6873
}
6974

75+
// EOF, ensure we've read the number of expected bytes
76+
if (read == -1 && remaining != 0) {
77+
throw new IOException("Reached EOF before reading entire expected content");
78+
}
79+
7080
return read;
7181
}
7282

core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,14 @@ public static RequestBody fromFile(File file) {
136136
* @return RequestBody instance.
137137
*/
138138
public static RequestBody fromInputStream(InputStream inputStream, long contentLength) {
139-
// NOTE: does not have an effect if mark not supported
140139
IoUtils.markStreamWithMaxReadLimit(inputStream);
141140
InputStream nonCloseable = nonCloseableInputStream(inputStream);
142-
ContentStreamProvider provider;
143-
if (nonCloseable.markSupported()) {
144-
// stream supports mark + reset
145-
provider = () -> {
141+
return fromContentProvider(() -> {
142+
if (nonCloseable.markSupported()) {
146143
invokeSafely(nonCloseable::reset);
147-
return nonCloseable;
148-
};
149-
} else {
150-
// stream doesn't support mark + reset, make sure to buffer it
151-
provider = new BufferingContentStreamProvider(() -> nonCloseable, contentLength);
152-
}
153-
return new RequestBody(provider, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
144+
}
145+
return nonCloseable;
146+
}, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
154147
}
155148

156149
/**
@@ -224,9 +217,6 @@ public static RequestBody empty() {
224217
/**
225218
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
226219
* <p>
227-
* Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
228-
* cause increased memory usage.
229-
* <p>
230220
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
231221
* S3's documentation for
232222
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
@@ -239,7 +229,7 @@ public static RequestBody empty() {
239229
* @return The created {@code RequestBody}.
240230
*/
241231
public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
242-
return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
232+
return new RequestBody(provider, contentLength, mimeType);
243233
}
244234

245235
/**

core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/MakeHttpRequestStageTest.java

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,42 +16,48 @@
1616
package software.amazon.awssdk.core.internal.http.pipeline.stages;
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1920
import static org.mockito.ArgumentMatchers.any;
2021
import static org.mockito.ArgumentMatchers.eq;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.verify;
2324
import static org.mockito.Mockito.when;
2425
import static software.amazon.awssdk.core.client.config.SdkClientOption.SYNC_HTTP_CLIENT;
26+
27+
import java.io.ByteArrayInputStream;
2528
import java.io.IOException;
26-
import org.junit.Before;
27-
import org.junit.Test;
28-
import org.junit.runner.RunWith;
29+
import java.io.InputStream;
30+
import java.util.stream.Stream;
31+
import org.junit.jupiter.api.BeforeEach;
32+
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.params.ParameterizedTest;
34+
import org.junit.jupiter.params.provider.Arguments;
35+
import org.junit.jupiter.params.provider.MethodSource;
2936
import org.mockito.ArgumentCaptor;
30-
import org.mockito.Mock;
31-
import org.mockito.junit.MockitoJUnitRunner;
3237
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
3338
import software.amazon.awssdk.core.http.ExecutionContext;
3439
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
3540
import software.amazon.awssdk.core.internal.http.HttpClientDependencies;
3641
import software.amazon.awssdk.core.internal.http.RequestExecutionContext;
3742
import software.amazon.awssdk.core.internal.http.timers.TimeoutTracker;
43+
import software.amazon.awssdk.core.internal.io.SdkLengthAwareInputStream;
44+
import software.amazon.awssdk.http.ContentStreamProvider;
3845
import software.amazon.awssdk.http.HttpExecuteRequest;
3946
import software.amazon.awssdk.http.SdkHttpClient;
4047
import software.amazon.awssdk.http.SdkHttpFullRequest;
4148
import software.amazon.awssdk.http.SdkHttpMethod;
4249
import software.amazon.awssdk.metrics.MetricCollector;
4350
import utils.ValidSdkObjects;
4451

45-
@RunWith(MockitoJUnitRunner.class)
4652
public class MakeHttpRequestStageTest {
4753

48-
@Mock
4954
private SdkHttpClient mockClient;
5055

5156
private MakeHttpRequestStage stage;
5257

53-
@Before
58+
@BeforeEach
5459
public void setup() throws IOException {
60+
mockClient = mock(SdkHttpClient.class);
5561
SdkClientConfiguration config = SdkClientConfiguration.builder().option(SYNC_HTTP_CLIENT, mockClient).build();
5662
stage = new MakeHttpRequestStage(HttpClientDependencies.builder().clientConfiguration(config).build());
5763
}
@@ -94,4 +100,92 @@ public void testExecute_contextContainsMetricCollector_addsChildToExecuteRequest
94100
assertThat(httpRequestCaptor.getValue().metricCollector()).contains(childCollector);
95101
}
96102
}
103+
104+
@ParameterizedTest
105+
@MethodSource("contentLengthVerificationInputs")
106+
public void execute_testLengthChecking(String description,
107+
ContentStreamProvider provider,
108+
Long contentLength,
109+
boolean expectLengthAware) {
110+
SdkHttpFullRequest.Builder requestBuilder = SdkHttpFullRequest.builder()
111+
.method(SdkHttpMethod.PUT)
112+
.host("mybucket.s3.us-west-2.amazonaws.com")
113+
.protocol("https");
114+
115+
if (provider != null) {
116+
requestBuilder.contentStreamProvider(provider);
117+
}
118+
119+
if (contentLength != null) {
120+
requestBuilder.putHeader("Content-Length", String.valueOf(contentLength));
121+
}
122+
123+
when(mockClient.prepareRequest(any()))
124+
.thenThrow(new RuntimeException("BOOM"));
125+
126+
assertThatThrownBy(() -> stage.execute(requestBuilder.build(), createContext())).hasMessage("BOOM");
127+
128+
ArgumentCaptor<HttpExecuteRequest> requestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class);
129+
130+
verify(mockClient).prepareRequest(requestCaptor.capture());
131+
132+
HttpExecuteRequest capturedRequest = requestCaptor.getValue();
133+
134+
if (provider != null) {
135+
InputStream requestContentStream = capturedRequest.contentStreamProvider().get().newStream();
136+
137+
if (expectLengthAware) {
138+
assertThat(requestContentStream).isInstanceOf(SdkLengthAwareInputStream.class);
139+
} else {
140+
assertThat(requestContentStream).isNotInstanceOf(SdkLengthAwareInputStream.class);
141+
}
142+
} else {
143+
assertThat(capturedRequest.contentStreamProvider()).isEmpty();
144+
}
145+
}
146+
147+
private static Stream<Arguments> contentLengthVerificationInputs() {
148+
return Stream.of(
149+
Arguments.of(
150+
"Provider present, ContentLength present",
151+
(ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]),
152+
16L,
153+
true
154+
),
155+
Arguments.of(
156+
"Provider present, ContentLength not present",
157+
(ContentStreamProvider) () -> new ByteArrayInputStream(new byte[16]),
158+
null,
159+
false
160+
),
161+
Arguments.of(
162+
"Provider not present, ContentLength present",
163+
null,
164+
16L,
165+
false
166+
),
167+
Arguments.of(
168+
"Provider not present, ContentLength not present",
169+
null,
170+
null,
171+
false
172+
)
173+
);
174+
}
175+
176+
private static RequestExecutionContext createContext() {
177+
ExecutionContext executionContext = ExecutionContext.builder()
178+
.executionAttributes(new ExecutionAttributes())
179+
.build();
180+
181+
RequestExecutionContext context = RequestExecutionContext.builder()
182+
.originalRequest(ValidSdkObjects.sdkRequest())
183+
.executionContext(executionContext)
184+
.build();
185+
186+
context.apiCallAttemptTimeoutTracker(mock(TimeoutTracker.class));
187+
context.apiCallTimeoutTracker(mock(TimeoutTracker.class));
188+
189+
return context;
190+
}
97191
}

0 commit comments

Comments
 (0)