Skip to content

Commit 650a023

Browse files
committed
Make ChecksumValidatingInputStream implement close(). Preserve abort semantics through interceptor modification.
1 parent d51b3e3 commit 650a023

File tree

5 files changed

+70
-7
lines changed

5 files changed

+70
-7
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"category": "AWS SDK for Java v2",
3+
"type": "bugfix",
4+
"description": "Fixed an issue where close() and abort() weren't being honored for streaming responses in all cases."
5+
}

http-client-spi/src/main/java/software/amazon/awssdk/http/AbortableInputStream.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ public static AbortableInputStream create(InputStream delegate, Abortable aborta
5454
* @return a new instance of AbortableInputStream
5555
*/
5656
public static AbortableInputStream create(InputStream delegate) {
57+
if (delegate instanceof Abortable) {
58+
return new AbortableInputStream(delegate, (Abortable) delegate);
59+
}
5760
return new AbortableInputStream(delegate, () -> { });
5861
}
5962

services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import software.amazon.awssdk.annotations.SdkInternalApi;
2222
import software.amazon.awssdk.core.checksums.SdkChecksum;
2323
import software.amazon.awssdk.core.exception.SdkClientException;
24+
import software.amazon.awssdk.http.Abortable;
2425

2526
@SdkInternalApi
26-
public class ChecksumValidatingInputStream extends InputStream {
27+
public class ChecksumValidatingInputStream extends InputStream implements Abortable {
2728
private static final int CHECKSUM_SIZE = 16;
2829

2930
private final SdkChecksum checkSum;
@@ -146,6 +147,18 @@ public synchronized void reset() throws IOException {
146147
}
147148
}
148149

150+
@Override
151+
public void abort() {
152+
if (inputStream instanceof Abortable) {
153+
((Abortable) inputStream).abort();
154+
}
155+
}
156+
157+
@Override
158+
public void close() throws IOException {
159+
inputStream.close();
160+
}
161+
149162
/**
150163
* Gets the stream's checksum as an integer.
151164
*
@@ -165,4 +178,5 @@ private void validateAndThrow() {
165178
computedChecksumInt, streamChecksumInt)).build();
166179
}
167180
}
181+
168182
}

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/handlers/SyncChecksumValidationInterceptor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@
3333
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
3434
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
3535
import software.amazon.awssdk.core.sync.RequestBody;
36-
import software.amazon.awssdk.http.SdkHttpResponse;
3736
import software.amazon.awssdk.services.s3.S3Configuration;
3837
import software.amazon.awssdk.services.s3.checksums.ChecksumCalculatingInputStream;
3938
import software.amazon.awssdk.services.s3.checksums.ChecksumValidatingInputStream;
4039
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
4140
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
4241
import software.amazon.awssdk.services.s3.model.PutObjectResponse;
4342
import software.amazon.awssdk.utils.BinaryUtils;
43+
import software.amazon.awssdk.utils.internal.Base16;
4444
import software.amazon.awssdk.utils.internal.Base16Lower;
4545

4646
@SdkInternalApi
@@ -76,7 +76,6 @@ public Optional<InputStream> modifyHttpResponseContent(Context.ModifyHttpRespons
7676
ExecutionAttributes executionAttributes) {
7777

7878
if (context.request() instanceof GetObjectRequest && checksumValidationEnabled(executionAttributes)) {
79-
SdkHttpResponse originalResponse = context.httpResponse();
8079
SdkChecksum checksum = new Md5Checksum();
8180

8281
int contentLength = Integer.valueOf(context.httpResponse().firstMatchingHeader(CONTENT_LENGTH_HEADER).orElse("0"));
@@ -103,7 +102,8 @@ public void afterUnmarshalling(Context.AfterUnmarshalling context, ExecutionAttr
103102

104103
if (!Arrays.equals(digest, ssHash)) {
105104
throw SdkClientException.create(String.format("Data read has a different checksum than expected. " +
106-
"Was %d, but expected %d", digest, ssHash));
105+
"Was 0x%s, but expected 0x%s",
106+
Base16.encodeAsString(digest), Base16.encodeAsString(ssHash)));
107107
}
108108
}
109109
}

test/protocol-tests/src/test/java/software/amazon/awssdk/protocol/tests/ResponseTransformerTest.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
2525

2626
import com.github.tomakehurst.wiremock.junit.WireMockRule;
27+
import com.github.tomakehurst.wiremock.stubbing.StubMapping;
2728
import java.io.ByteArrayOutputStream;
2829
import java.io.File;
2930
import java.io.IOException;
@@ -43,6 +44,7 @@
4344
import software.amazon.awssdk.http.apache.ApacheHttpClient;
4445
import software.amazon.awssdk.regions.Region;
4546
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClient;
47+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonClientBuilder;
4648
import software.amazon.awssdk.services.protocolrestjson.model.StreamingOutputOperationRequest;
4749
import software.amazon.awssdk.services.protocolrestjson.model.StreamingOutputOperationResponse;
4850
import software.amazon.awssdk.utils.BinaryUtils;
@@ -58,7 +60,7 @@ public class ResponseTransformerTest {
5860

5961
@Test
6062
public void bytesMethodConvertsCorrectly() {
61-
stubFor(post(urlPathEqualTo(STREAMING_OUTPUT_PATH)).willReturn(aResponse().withStatus(200).withBody("test \uD83D\uDE02")));
63+
stubForSuccess();
6264

6365
ResponseBytes<StreamingOutputOperationResponse> response =
6466
testClient().streamingOutputOperationAsBytes(StreamingOutputOperationRequest.builder().build());
@@ -120,6 +122,38 @@ public void downloadToOutputStreamDoesNotRetry() throws IOException {
120122
.isInstanceOf(SdkClientException.class);
121123
}
122124

125+
@Test
126+
public void streamingCloseActuallyCloses() throws IOException {
127+
stubForSuccess();
128+
129+
ProtocolRestJsonClient client = testClientBuilder()
130+
.httpClientBuilder(ApacheHttpClient.builder()
131+
.connectionAcquisitionTimeout(Duration.ofSeconds(1))
132+
.maxConnections(1))
133+
.build();
134+
135+
136+
// Two successful requests with a max of one connection means that closing the connection worked.
137+
client.streamingOutputOperation(StreamingOutputOperationRequest.builder().build()).close();
138+
client.streamingOutputOperation(StreamingOutputOperationRequest.builder().build()).close();
139+
}
140+
141+
@Test
142+
public void streamingAbortActuallyAborts() {
143+
stubForSuccess();
144+
145+
ProtocolRestJsonClient client = testClientBuilder()
146+
.httpClientBuilder(ApacheHttpClient.builder()
147+
.connectionAcquisitionTimeout(Duration.ofSeconds(1))
148+
.maxConnections(1))
149+
.build();
150+
151+
152+
// Two successful requests with a max of one connection means that closing the connection worked.
153+
client.streamingOutputOperation(StreamingOutputOperationRequest.builder().build()).abort();
154+
client.streamingOutputOperation(StreamingOutputOperationRequest.builder().build()).abort();
155+
}
156+
123157
private void stubForRetriesTimeoutReadingFromStreams() {
124158
stubFor(post(urlPathEqualTo(STREAMING_OUTPUT_PATH)).inScenario("retries")
125159
.whenScenarioStateIs(STARTED)
@@ -133,11 +167,18 @@ private void stubForRetriesTimeoutReadingFromStreams() {
133167
}
134168

135169
private ProtocolRestJsonClient testClient() {
170+
return testClientBuilder().build();
171+
}
172+
173+
private ProtocolRestJsonClientBuilder testClientBuilder() {
136174
return ProtocolRestJsonClient.builder()
137175
.region(Region.US_WEST_1)
138176
.endpointOverride(URI.create("http://localhost:" + wireMock.port()))
139177
.credentialsProvider(() -> AwsBasicCredentials.create("akid", "skid"))
140-
.httpClientBuilder(ApacheHttpClient.builder().socketTimeout(Duration.ofSeconds(1)))
141-
.build();
178+
.httpClientBuilder(ApacheHttpClient.builder().socketTimeout(Duration.ofSeconds(1)));
179+
}
180+
181+
private StubMapping stubForSuccess() {
182+
return stubFor(post(urlPathEqualTo(STREAMING_OUTPUT_PATH)).willReturn(aResponse().withStatus(200).withBody("test \uD83D\uDE02")));
142183
}
143184
}

0 commit comments

Comments
 (0)