diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumInHeaderInterceptor.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumInHeaderInterceptor.java index 0ddf70959cae..f3c92a254bec 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumInHeaderInterceptor.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumInHeaderInterceptor.java @@ -15,7 +15,6 @@ package software.amazon.awssdk.core.internal.interceptor; -import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE; import static software.amazon.awssdk.core.HttpChecksumConstant.SIGNING_METHOD; import static software.amazon.awssdk.core.internal.util.HttpChecksumResolver.getResolvedChecksumSpecs; @@ -23,7 +22,6 @@ import java.io.UncheckedIOException; import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.core.checksums.Algorithm; import software.amazon.awssdk.core.checksums.ChecksumSpecs; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; @@ -47,49 +45,27 @@ @SdkInternalApi public class HttpChecksumInHeaderInterceptor implements ExecutionInterceptor { - @Override - public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) { - ChecksumSpecs headerChecksumSpecs = HttpChecksumUtils.checksumSpecWithRequestAlgorithm(executionAttributes).orElse(null); - - if (shouldSkipHttpChecksumInHeader(context, executionAttributes, headerChecksumSpecs)) { - return; - } - Optional syncContent = context.requestBody(); - syncContent.ifPresent( - requestBody -> saveContentChecksum(requestBody, executionAttributes, headerChecksumSpecs.algorithm())); - } - - @Override - public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { - ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes); - - if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs)) { - return context.httpRequest(); - } - - String httpChecksumValue = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE); - if (httpChecksumValue != null) { - return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), httpChecksumValue)); - } - return context.httpRequest(); - - } - /** - * Calculates the checksumSpecs of the provided request (and base64 encodes it), storing the result in - * executionAttribute "HttpChecksumValue". + * Calculates the checksum of the provided request (and base64 encodes it), and adds the header to the request. * *

Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with * an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the * request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do * that yet. */ - private static void saveContentChecksum(RequestBody requestBody, ExecutionAttributes executionAttributes, - Algorithm algorithm) { + @Override + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { + ChecksumSpecs checksumSpecs = getResolvedChecksumSpecs(executionAttributes); + Optional syncContent = context.requestBody(); + + if (shouldSkipHttpChecksumInHeader(context, executionAttributes, checksumSpecs) || !syncContent.isPresent()) { + return context.httpRequest(); + } + try { String payloadChecksum = BinaryUtils.toBase64(HttpChecksumUtils.computeChecksum( - requestBody.contentStreamProvider().newStream(), algorithm)); - executionAttributes.putAttribute(HTTP_CHECKSUM_VALUE, payloadChecksum); + syncContent.get().contentStreamProvider().newStream(), checksumSpecs.algorithm())); + return context.httpRequest().copy(r -> r.putHeader(checksumSpecs.headerName(), payloadChecksum)); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumRequiredInterceptor.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumRequiredInterceptor.java index c98cde397f0c..9729cd2076d7 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumRequiredInterceptor.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/interceptor/HttpChecksumRequiredInterceptor.java @@ -21,7 +21,6 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttribute; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; @@ -41,10 +40,17 @@ */ @SdkInternalApi public class HttpChecksumRequiredInterceptor implements ExecutionInterceptor { - private static final ExecutionAttribute CONTENT_MD5_VALUE = new ExecutionAttribute<>("ContentMd5"); + /** + * Calculates the MD5 checksum of the provided request (and base64 encodes it), and adds the header to the request. + * + *

Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with + * an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the + * request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do + * that yet. + */ @Override - public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttributes executionAttributes) { + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { boolean isHttpChecksumRequired = isHttpChecksumRequired(executionAttributes); boolean requestAlreadyHasMd5 = context.httpRequest().firstMatchingHeader(Header.CONTENT_MD5).isPresent(); @@ -52,7 +58,7 @@ public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttribut Optional asyncContent = context.asyncRequestBody(); if (!isHttpChecksumRequired || requestAlreadyHasMd5) { - return; + return context.httpRequest(); } if (asyncContent.isPresent()) { @@ -60,14 +66,13 @@ public void afterMarshalling(Context.AfterMarshalling context, ExecutionAttribut + "for non-blocking content."); } - syncContent.ifPresent(requestBody -> saveContentMd5(requestBody, executionAttributes)); - } - - @Override - public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { - String contentMd5 = executionAttributes.getAttribute(CONTENT_MD5_VALUE); - if (contentMd5 != null) { - return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, contentMd5)); + if (syncContent.isPresent()) { + try { + String payloadMd5 = Md5Utils.md5AsBase64(syncContent.get().contentStreamProvider().newStream()); + return context.httpRequest().copy(r -> r.putHeader(Header.CONTENT_MD5, payloadMd5)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } } return context.httpRequest(); } @@ -76,22 +81,4 @@ private boolean isHttpChecksumRequired(ExecutionAttributes executionAttributes) return executionAttributes.getAttribute(SdkInternalExecutionAttribute.HTTP_CHECKSUM_REQUIRED) != null || HttpChecksumUtils.isMd5ChecksumRequired(executionAttributes); } - - /** - * Calculates the MD5 checksum of the provided request (and base64 encodes it), storing the result in - * {@link #CONTENT_MD5_VALUE}. - * - *

Note: This assumes that the content stream provider can create multiple new streams. If it only supports one (e.g. with - * an input stream that doesn't support mark/reset), we could consider buffering the content in memory here and updating the - * request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do - * that yet. - */ - private void saveContentMd5(RequestBody requestBody, ExecutionAttributes executionAttributes) { - try { - String payloadMd5 = Md5Utils.md5AsBase64(requestBody.contentStreamProvider().newStream()); - executionAttributes.putAttribute(CONTENT_MD5_VALUE, payloadMd5); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } } diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/HttpChecksumInHeaderTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/HttpChecksumInHeaderTest.java index 029a19447047..8ba47dee79cf 100644 --- a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/HttpChecksumInHeaderTest.java +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/HttpChecksumInHeaderTest.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static software.amazon.awssdk.core.HttpChecksumConstant.HTTP_CHECKSUM_VALUE; import io.reactivex.Flowable; import java.io.IOException; @@ -28,7 +27,6 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; -import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -38,9 +36,6 @@ import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; import software.amazon.awssdk.awscore.client.builder.AwsSyncClientBuilder; import software.amazon.awssdk.core.checksums.Algorithm; -import software.amazon.awssdk.core.interceptor.Context; -import software.amazon.awssdk.core.interceptor.ExecutionAttributes; -import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteRequest; import software.amazon.awssdk.http.HttpExecuteResponse; @@ -103,11 +98,6 @@ public void setup() throws IOException { }); } - @After - public void clear() { - CaptureChecksumValueInterceptor.reset(); - } - @Test public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() { // jsonClient.flexibleCheckSumOperationWithShaChecksum(r -> r.stringMember("Hello world")); @@ -118,9 +108,6 @@ public void sync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() { assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA="); // Assertion to make sure signer was not executed assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA="); - } @Test @@ -133,9 +120,6 @@ public void aync_json_nonStreaming_unsignedPayload_with_Sha1_in_header() { assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("M68rRwFal7o7B3KEMt3m0w39TaA="); // Assertion to make sure signer was not executed assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("M68rRwFal7o7B3KEMt3m0w39TaA="); - - } @Test @@ -148,9 +132,6 @@ public void sync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() { assertThat(getSyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU="); // Assertion to make sure signer was not executed assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU="); - } @Test @@ -169,9 +150,6 @@ public void sync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() { // Assertion to make sure signer was not executed assertThat(getSyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull(); - } @Test @@ -185,8 +163,6 @@ public void aync_xml_nonStreaming_unsignedPayload_with_Sha1_in_header() { assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).hasValue("FB/utBbwFLbIIt5ul3Ojuy5dKgU="); // Assertion to make sure signer was not executed assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isEqualTo("FB/utBbwFLbIIt5ul3Ojuy5dKgU="); - } @Test @@ -206,8 +182,6 @@ public void aync_xml_nonStreaming_unsignedEmptyPayload_with_Sha1_in_header() { assertThat(getAsyncRequest().firstMatchingHeader("x-amz-checksum-sha1")).isNotPresent(); // Assertion to make sure signer was not executed assertThat(getAsyncRequest().firstMatchingHeader("x-amz-content-sha256")).isNotPresent(); - assertThat(CaptureChecksumValueInterceptor.interceptorComputedChecksum).isNull(); - } private SdkHttpRequest getSyncRequest() { @@ -224,32 +198,15 @@ private SdkHttpRequest getAsyncRequest() { private & AwsClientBuilder> T initializeSync(T syncClientBuilder) { - return initialize(syncClientBuilder.httpClient(httpClient) - .overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor()))); + return initialize(syncClientBuilder.httpClient(httpClient)); } private & AwsClientBuilder> T initializeAsync(T asyncClientBuilder) { - return initialize(asyncClientBuilder.httpClient(httpAsyncClient) - .overrideConfiguration(o -> o.addExecutionInterceptor(new CaptureChecksumValueInterceptor()))); + return initialize(asyncClientBuilder.httpClient(httpAsyncClient)); } private > T initialize(T clientBuilder) { return clientBuilder.credentialsProvider(AnonymousCredentialsProvider.create()) .region(Region.US_WEST_2); } - - - private static class CaptureChecksumValueInterceptor implements ExecutionInterceptor { - private static String interceptorComputedChecksum; - - private static void reset() { - interceptorComputedChecksum = null; - } - - @Override - public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { - interceptorComputedChecksum = executionAttributes.getAttribute(HTTP_CHECKSUM_VALUE); - - } - } } \ No newline at end of file