diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java new file mode 100644 index 000000000000..095d69ac5e7d --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -0,0 +1,298 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.SimplePublisher; + +/** + * Splits an {@link SdkPublisher} into multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the + * original data. + * // TODO: create a default method in AsyncRequestBody for this + * // TODO: fix the case where content length is null + */ +@SdkInternalApi +public class SplittingPublisher implements SdkPublisher<AsyncRequestBody> { + private static final Logger log = Logger.loggerFor(SplittingPublisher.class); + private final AsyncRequestBody upstreamPublisher; + private final SplittingSubscriber splittingSubscriber; + private final SimplePublisher<AsyncRequestBody> downstreamPublisher = new SimplePublisher<>(); + private final long chunkSizeInBytes; + private final long maxMemoryUsageInBytes; + private final CompletableFuture<Void> future; + + private SplittingPublisher(Builder builder) { + this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); + this.chunkSizeInBytes = Validate.paramNotNull(builder.chunkSizeInBytes, "chunkSizeInBytes"); + this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); + this.maxMemoryUsageInBytes = builder.maxMemoryUsageInBytes == null ? Long.MAX_VALUE : builder.maxMemoryUsageInBytes; + this.future = builder.future; + + // We need to cancel the upstream subscription if the returned future gets cancelled.
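+ // Note: cancelling the returned future completes it exceptionally (with a CancellationException), so that case is handled by the callback below.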
+ future.whenComplete((r, t) -> { + if (t != null) { + if (splittingSubscriber.upstreamSubscription != null) { + log.trace(() -> "Cancelling subscription because return future completed exceptionally ", t); + splittingSubscriber.upstreamSubscription.cancel(); + } + } + }); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void subscribe(Subscriber<? super AsyncRequestBody> downstreamSubscriber) { + downstreamPublisher.subscribe(downstreamSubscriber); + upstreamPublisher.subscribe(splittingSubscriber); + } + + private class SplittingSubscriber implements Subscriber<ByteBuffer> { + private Subscription upstreamSubscription; + private final Long upstreamSize; + private final AtomicInteger chunkNumber = new AtomicInteger(0); + private volatile DownstreamBody currentBody; + private final AtomicBoolean hasOpenUpstreamDemand = new AtomicBoolean(false); + private final AtomicLong dataBuffered = new AtomicLong(0); + + /** + * A hint to determine whether we will exceed maxMemoryUsage by the next onNext call. + */ + private int byteBufferSizeHint; + + SplittingSubscriber(Long upstreamSize) { + this.upstreamSize = upstreamSize; + } + + @Override + public void onSubscribe(Subscription s) { + this.upstreamSubscription = s; + this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get()); + sendCurrentBody(); + // We need to request subscription *after* we set currentBody because onNext could be invoked right away. + upstreamSubscription.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + hasOpenUpstreamDemand.set(false); + byteBufferSizeHint = byteBuffer.remaining(); + + while (true) { + int amountRemainingInPart = amountRemainingInPart(); + int finalAmountRemainingInPart = amountRemainingInPart; + if (amountRemainingInPart == 0) { + currentBody.complete(); + int currentChunk = chunkNumber.incrementAndGet(); + Long partSize = calculateChunkSize(); + currentBody = new DownstreamBody(partSize, currentChunk); + sendCurrentBody(); + } + + amountRemainingInPart = amountRemainingInPart(); + if (amountRemainingInPart >= byteBuffer.remaining()) { + currentBody.send(byteBuffer.duplicate()); + break; + } + + ByteBuffer firstHalf = byteBuffer.duplicate(); + int newLimit = firstHalf.position() + amountRemainingInPart; + firstHalf.limit(newLimit); + byteBuffer.position(newLimit); + currentBody.send(firstHalf); + } + + maybeRequestMoreUpstreamData(); + } + + private int amountRemainingInPart() { + return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength); + } + + @Override + public void onComplete() { + log.trace(() -> "Received onComplete()"); + downstreamPublisher.complete().thenRun(() -> future.complete(null)); + currentBody.complete(); + } + + @Override + public void onError(Throwable t) { + currentBody.error(t); + } + + private void sendCurrentBody() { + downstreamPublisher.send(currentBody).exceptionally(t -> { + downstreamPublisher.error(t); + return null; + }); + } + + private Long calculateChunkSize() { + Long dataRemaining = dataRemaining(); + if (dataRemaining == null) { + return null; + } + + return Math.min(chunkSizeInBytes, dataRemaining); + } + + private void maybeRequestMoreUpstreamData() { + long buffered = dataBuffered.get(); + if (shouldRequestMoreData(buffered) && + hasOpenUpstreamDemand.compareAndSet(false, true)) { + log.trace(() -> "Requesting more data, current data buffered: " + buffered); + upstreamSubscription.request(1); + } + } + + private boolean shouldRequestMoreData(long buffered) { + return buffered == 0 || buffered 
+ byteBufferSizeHint < maxMemoryUsageInBytes; + } + + private Long dataRemaining() { + if (upstreamSize == null) { + return null; + } + return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); + } + + private class DownstreamBody implements AsyncRequestBody { + private final Long totalLength; + private final SimplePublisher<ByteBuffer> delegate = new SimplePublisher<>(); + private final int chunkNumber; + private volatile long transferredLength = 0; + + private DownstreamBody(Long totalLength, int chunkNumber) { + this.totalLength = totalLength; + this.chunkNumber = chunkNumber; + } + + @Override + public Optional<Long> contentLength() { + return Optional.ofNullable(totalLength); + } + + public void send(ByteBuffer data) { + log.trace(() -> "Sending bytebuffer " + data); + int length = data.remaining(); + transferredLength += length; + addDataBuffered(length); + delegate.send(data).whenComplete((r, t) -> { + addDataBuffered(-length); + if (t != null) { + error(t); + } + }); + } + + public void complete() { + log.debug(() -> "Received complete() for chunk number: " + chunkNumber); + delegate.complete(); + } + + public void error(Throwable error) { + delegate.error(error); + } + + @Override + public void subscribe(Subscriber<? super ByteBuffer> s) { + delegate.subscribe(s); + } + + private void addDataBuffered(int length) { + dataBuffered.addAndGet(length); + if (length < 0) { + maybeRequestMoreUpstreamData(); + } + } + } + } + + public static final class Builder { + private AsyncRequestBody asyncRequestBody; + private Long chunkSizeInBytes; + private Long maxMemoryUsageInBytes; + private CompletableFuture<Void> future; + + /** + * Configures the asyncRequestBody to split + * + * @param asyncRequestBody The new asyncRequestBody value. + * @return This object for method chaining. + */ + public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.asyncRequestBody = asyncRequestBody; + return this; + } + + /** + * Configures the size of the chunk for each {@link AsyncRequestBody} to publish + * + * @param chunkSizeInBytes The new chunkSizeInBytes value. + * @return This object for method chaining. + */ + public Builder chunkSizeInBytes(Long chunkSizeInBytes) { + this.chunkSizeInBytes = chunkSizeInBytes; + return this; + } + + /** + * Sets the maximum memory usage in bytes. By default, it uses unlimited memory. + * + * @param maxMemoryUsageInBytes The new maxMemoryUsageInBytes value. + * @return This object for method chaining. + */ + // TODO: max memory usage might not be the best name, since we may technically go a little above this limit when we add + // on a new byte buffer. But we don't know for sure what the size of a buffer we request will be (we do use the size + // for the last byte buffer as a hint), so I don't think we can have a truly accurate max. Maybe we call it minimum + // buffer size instead? + public Builder maxMemoryUsageInBytes(Long maxMemoryUsageInBytes) { + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + return this; + } + + /** + * Sets the result future. The future will be completed when all request bodies + * have been sent. + * + * @param future The new future value. + * @return This object for method chaining. 
+ */ + public Builder resultFuture(CompletableFuture future) { + this.future = future; + return this; + } + + public SplittingPublisher build() { + return new SplittingPublisher(this); + } + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java new file mode 100644 index 000000000000..df318190b92d --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -0,0 +1,215 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; + +import java.io.ByteArrayOutputStream; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.BinaryUtils; + +public class SplittingPublisherTest { + private static final int CHUNK_SIZE = 5; + + private static final int CONTENT_SIZE = 101; + + private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE); + + private static RandomTempFile testFile; + + @BeforeAll + public static void beforeAll() throws IOException { + testFile = new RandomTempFile("testfile.dat", CONTENT_SIZE); + } + + @AfterAll + public static void afterAll() throws Exception { + testFile.delete(); + } + + @ParameterizedTest + @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) + void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int upstreamByteBufferSize) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .chunkSizeInBytes(upstreamByteBufferSize) + .build()) + + .resultFuture(future) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes((long) CHUNK_SIZE * 4) + .build(); + + List> futures = new ArrayList<>(); + + splittingPublisher.subscribe(requestBody -> { + CompletableFuture baosFuture = new CompletableFuture<>(); + BaosSubscriber subscriber = new 
BaosSubscriber(baosFuture); + futures.add(baosFuture); + requestBody.subscribe(subscriber); + }).get(5, TimeUnit.SECONDS); + + assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK); + + for (int i = 0; i < futures.size(); i++) { + try (FileInputStream fileInputStream = new FileInputStream(testFile)) { + byte[] expected; + if (i == futures.size() - 1) { + expected = new byte[1]; + } else { + expected = new byte[5]; + } + fileInputStream.skip(i * 5); + fileInputStream.read(expected); + byte[] actualBytes = futures.get(i).join(); + assertThat(actualBytes).isEqualTo(expected); + }; + } + assertThat(future).isCompleted(); + } + + + @Test + void cancelFuture_shouldCancelUpstream() throws IOException { + CompletableFuture future = new CompletableFuture<>(); + TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes(10L) + .build(); + + OnlyRequestOnceSubscriber downstreamSubscriber = new OnlyRequestOnceSubscriber(); + splittingPublisher.subscribe(downstreamSubscriber); + + future.completeExceptionally(new RuntimeException("test")); + assertThat(asyncRequestBody.cancelled).isTrue(); + assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); + } + + private static final class TestAsyncRequestBody implements AsyncRequestBody { + private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset()); + private boolean cancelled; + + @Override + public Optional contentLength() { + return Optional.of((long) CONTENT.length); + } + + @Override + public void subscribe(Subscriber s) { + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + s.onNext(ByteBuffer.wrap(CONTENT)); + s.onComplete(); + } + + @Override + public void cancel() { + cancelled = true; + } + }); + + } + } + + private static final class OnlyRequestOnceSubscriber implements Subscriber { + private List asyncRequestBodies = new ArrayList<>(); + + @Override + public void onSubscribe(Subscription s) { + s.request(1); + } + + @Override + public void onNext(AsyncRequestBody requestBody) { + asyncRequestBodies.add(requestBody); + } + + @Override + public void onError(Throwable t) { + + } + + @Override + public void onComplete() { + + } + } + + private static final class BaosSubscriber implements Subscriber { + private final CompletableFuture resultFuture; + + private ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + private Subscription subscription; + + BaosSubscriber(CompletableFuture resultFuture) { + this.resultFuture = resultFuture; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + invokeSafely(() -> baos.write(BinaryUtils.copyBytesFrom(byteBuffer))); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + baos = null; + resultFuture.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + resultFuture.complete(baos.toByteArray()); + } + } +} diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java new 
file mode 100644 index 000000000000..4174b87883dc --- /dev/null +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; + +import java.nio.file.Files; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.sync.ResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3IntegrationTestBase; +import software.amazon.awssdk.services.s3.internal.multipart.MultipartS3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.utils.ChecksumUtils; +import software.amazon.awssdk.testutils.RandomTempFile; + +public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTestBase { + + private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientPutObjectIntegrationTest.class); + private static final String TEST_KEY = "testfile.dat"; + private static final int OBJ_SIZE = 19 * 1024 * 1024; + + private static RandomTempFile testFile; + private static S3AsyncClient mpuS3Client; + + @BeforeAll + public static void setup() throws Exception { + S3IntegrationTestBase.setUp(); + S3IntegrationTestBase.createBucket(TEST_BUCKET); + + testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE); + mpuS3Client = new MultipartS3AsyncClient(s3Async); + } + + @AfterAll + public static void teardown() throws Exception { + mpuS3Client.close(); + testFile.delete(); + deleteBucketAndAllContents(TEST_BUCKET); + } + + @Test + @Timeout(value = 20, unit = SECONDS) + void putObject_fileRequestBody_objectSentCorrectly() throws Exception { + AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); + + ResponseInputStream objContent = S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), + ResponseTransformer.toInputStream()); + + assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); + byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + } + +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java 
b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java index e3e125c9d084..414262b7bffa 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelper.java @@ -19,15 +19,11 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.BiFunction; -import java.util.function.Supplier; import java.util.stream.IntStream; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.internal.multipart.GenericMultipartHelper; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; @@ -50,17 +46,16 @@ public final class CopyObjectHelper { private static final Logger log = Logger.loggerFor(S3AsyncClient.class); - /** - * The max number of parts on S3 side is 10,000 - */ - private static final long MAX_UPLOAD_PARTS = 10_000; - private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; public CopyObjectHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes) { this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + RequestConversionUtils::toAbortMultipartUploadRequest, + RequestConversionUtils::toCopyObjectResponse); } public CompletableFuture copyObject(CopyObjectRequest copyObjectRequest) { @@ -69,14 +64,15 @@ public CompletableFuture copyObject(CopyObjectRequest copyOb try { CompletableFuture headFuture = - s3AsyncClient.headObject(CopyRequestConversionUtils.toHeadObjectRequest(copyObjectRequest)); + s3AsyncClient.headObject(RequestConversionUtils.toHeadObjectRequest(copyObjectRequest)); // Ensure cancellations are forwarded to the head future CompletableFutureUtils.forwardExceptionTo(returnFuture, headFuture); headFuture.whenComplete((headObjectResponse, throwable) -> { if (throwable != null) { - handleException(returnFuture, () -> "Failed to retrieve metadata from the source object", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Failed to retrieve metadata from the source " + + "object", throwable); } else { doCopyObject(copyObjectRequest, returnFuture, headObjectResponse); } @@ -105,7 +101,7 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, Long contentLength, CompletableFuture returnFuture) { - CreateMultipartUploadRequest request = CopyRequestConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); + CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(copyObjectRequest); CompletableFuture createMultipartUploadFuture = s3AsyncClient.createMultipartUpload(request); @@ -114,7 +110,7 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { if (throwable != null) { - 
handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); } else { log.debug(() -> "Initiated new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); doCopyInParts(copyObjectRequest, contentLength, returnFuture, createMultipartUploadResponse.uploadId()); @@ -122,17 +118,14 @@ private void copyInParts(CopyObjectRequest copyObjectRequest, }); } - private int determinePartCount(long contentLength, long partSize) { - return (int) Math.ceil(contentLength / (double) partSize); - } - private void doCopyInParts(CopyObjectRequest copyObjectRequest, Long contentLength, CompletableFuture returnFuture, String uploadId) { - long optimalPartSize = calculateOptimalPartSizeForCopy(contentLength); - int partCount = determinePartCount(contentLength, optimalPartSize); + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + + int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); log.debug(() -> String.format("Starting multipart copy with partCount: %s, optimalPartSize: %s", partCount, optimalPartSize)); @@ -147,32 +140,15 @@ private void doCopyInParts(CopyObjectRequest copyObjectRequest, optimalPartSize); CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) .thenCompose(ignore -> completeMultipartUpload(copyObjectRequest, uploadId, completedParts)) - .handle(handleExceptionOrResponse(copyObjectRequest, returnFuture, uploadId)) + .handle(genericMultipartHelper.handleExceptionOrResponse(copyObjectRequest, returnFuture, + uploadId)) .exceptionally(throwable -> { - handleException(returnFuture, () -> "Unexpected exception occurred", throwable); + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); return null; }); } - private BiFunction handleExceptionOrResponse( - CopyObjectRequest copyObjectRequest, - CompletableFuture returnFuture, - String uploadId) { - - return (completeMultipartUploadResponse, throwable) -> { - if (throwable != null) { - cleanUpParts(copyObjectRequest, uploadId); - handleException(returnFuture, () -> "Failed to send multipart copy requests.", - throwable); - } else { - returnFuture.complete(CopyRequestConversionUtils.toCopyObjectResponse( - completeMultipartUploadResponse)); - } - - return null; - }; - } - private CompletableFuture completeMultipartUpload( CopyObjectRequest copyObjectRequest, String uploadId, AtomicReferenceArray completedParts) { log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", @@ -194,35 +170,6 @@ private CompletableFuture completeMultipartUplo return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); } - private void cleanUpParts(CopyObjectRequest copyObjectRequest, String uploadId) { - AbortMultipartUploadRequest abortMultipartUploadRequest = - CopyRequestConversionUtils.toAbortMultipartUploadRequest(copyObjectRequest, uploadId); - s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest) - .exceptionally(throwable -> { - log.warn(() -> String.format("Failed to abort previous multipart upload " - + "(id: %s)" - + ". You may need to call " - + "S3AsyncClient#abortMultiPartUpload to " - + "free all storage consumed by" - + " all parts. 
", - uploadId), throwable); - return null; - }); - } - - private static void handleException(CompletableFuture returnFuture, - Supplier message, - Throwable throwable) { - Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; - - if (cause instanceof Error) { - returnFuture.completeExceptionally(cause); - } else { - SdkClientException exception = SdkClientException.create(message.get(), cause); - returnFuture.completeExceptionally(exception); - } - } - private List> sendUploadPartCopyRequests(CopyObjectRequest copyObjectRequest, long contentLength, String uploadId, @@ -265,23 +212,13 @@ private static CompletedPart convertUploadPartCopyResponse(AtomicReferenceArray< UploadPartCopyResponse uploadPartCopyResponse) { CopyPartResult copyPartResult = uploadPartCopyResponse.copyPartResult(); CompletedPart completedPart = - CopyRequestConversionUtils.toCompletedPart(copyPartResult, - partNumber); + RequestConversionUtils.toCompletedPart(copyPartResult, + partNumber); completedParts.set(partNumber - 1, completedPart); return completedPart; } - /** - * Calculates the optimal part size of each part request if the copy operation is carried out as multipart copy. - */ - private long calculateOptimalPartSizeForCopy(long contentLengthOfSource) { - double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; - - optimalPartSize = Math.ceil(optimalPartSize); - return (long) Math.max(optimalPartSize, partSizeInBytes); - } - private void copyInOneChunk(CopyObjectRequest copyObjectRequest, CompletableFuture returnFuture) { CompletableFuture copyObjectFuture = diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java similarity index 61% rename from services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java rename to services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java index 2a464b10f499..f4a3aaf60d4a 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtils.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/RequestConversionUtils.java @@ -24,15 +24,47 @@ import software.amazon.awssdk.services.s3.model.CopyPartResult; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartCopyRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; /** - * Request conversion utility method for POJO classes associated with {@link S3CrtAsyncClient#copyObject(CopyObjectRequest)} + * Request conversion utility method for POJO classes associated with multipart feature. 
*/ +//TODO: iterate over SDK fields to get the data @SdkInternalApi -public final class CopyRequestConversionUtils { +public final class RequestConversionUtils { - private CopyRequestConversionUtils() { + private RequestConversionUtils() { + } + + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(PutObjectRequest putObjectRequest) { + + return CreateMultipartUploadRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .requestPayer(putObjectRequest.requestPayer()) + .acl(putObjectRequest.acl()) + .cacheControl(putObjectRequest.cacheControl()) + .metadata(putObjectRequest.metadata()) + .contentDisposition(putObjectRequest.contentDisposition()) + .contentEncoding(putObjectRequest.contentEncoding()) + .contentType(putObjectRequest.contentType()) + .contentLanguage(putObjectRequest.contentLanguage()) + .grantFullControl(putObjectRequest.grantFullControl()) + .expires(putObjectRequest.expires()) + .grantRead(putObjectRequest.grantRead()) + .grantFullControl(putObjectRequest.grantFullControl()) + .grantReadACP(putObjectRequest.grantReadACP()) + .grantWriteACP(putObjectRequest.grantWriteACP()) + //TODO filter out headers + //.overrideConfiguration(putObjectRequest.overrideConfiguration()) + .build(); } public static HeadObjectRequest toHeadObjectRequest(CopyObjectRequest copyObjectRequest) { @@ -63,6 +95,18 @@ public static CompletedPart toCompletedPart(CopyPartResult copyPartResult, int p .build(); } + public static CompletedPart toCompletedPart(UploadPartResponse partResponse, int partNumber) { + return CompletedPart.builder() + .partNumber(partNumber) + .eTag(partResponse.eTag()) + .checksumCRC32C(partResponse.checksumCRC32C()) + .checksumCRC32(partResponse.checksumCRC32()) + .checksumSHA1(partResponse.checksumSHA1()) + .checksumSHA256(partResponse.checksumSHA256()) + .eTag(partResponse.eTag()) + .build(); + } + public static CreateMultipartUploadRequest toCreateMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { return CreateMultipartUploadRequest.builder() .bucket(copyObjectRequest.destinationBucket()) @@ -124,15 +168,20 @@ public static CopyObjectResponse toCopyObjectResponse(CompleteMultipartUploadRes return builder.build(); } - public static AbortMultipartUploadRequest toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest, - String uploadId) { + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(CopyObjectRequest copyObjectRequest) { return AbortMultipartUploadRequest.builder() - .uploadId(uploadId) .bucket(copyObjectRequest.destinationBucket()) .key(copyObjectRequest.destinationKey()) .requestPayer(copyObjectRequest.requestPayerAsString()) - .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()) - .build(); + .expectedBucketOwner(copyObjectRequest.expectedBucketOwner()); + } + + public static AbortMultipartUploadRequest.Builder toAbortMultipartUploadRequest(PutObjectRequest putObjectRequest) { + return AbortMultipartUploadRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .requestPayer(putObjectRequest.requestPayerAsString()) + .expectedBucketOwner(putObjectRequest.expectedBucketOwner()); } public static UploadPartCopyRequest toUploadPartCopyRequest(CopyObjectRequest copyObjectRequest, @@ -165,4 +214,47 @@ public static UploadPartCopyRequest 
toUploadPartCopyRequest(CopyObjectRequest co .build(); } + public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) { + return UploadPartRequest.builder() + .bucket(putObjectRequest.bucket()) + .key(putObjectRequest.key()) + .uploadId(uploadId) + .partNumber(partNumber) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .expectedBucketOwner(putObjectRequest.expectedBucketOwner()) + .requestPayer(putObjectRequest.requestPayerAsString()) + .sseCustomerKey(putObjectRequest.sseCustomerKey()) + .sseCustomerAlgorithm(putObjectRequest.sseCustomerAlgorithm()) + .sseCustomerKeyMD5(putObjectRequest.sseCustomerKeyMD5()) + .build(); + } + + public static PutObjectResponse toPutObjectResponse(CompleteMultipartUploadResponse response) { + PutObjectResponse.Builder builder = PutObjectResponse.builder() + .versionId(response.versionId()) + .checksumCRC32(response.checksumCRC32()) + .checksumSHA1(response.checksumSHA1()) + .checksumSHA256(response.checksumSHA256()) + .checksumCRC32C(response.checksumCRC32C()) + .eTag(response.eTag()) + .expiration(response.expiration()) + .bucketKeyEnabled(response.bucketKeyEnabled()) + .serverSideEncryption(response.serverSideEncryption()) + .ssekmsKeyId(response.ssekmsKeyId()) + .serverSideEncryption(response.serverSideEncryptionAsString()) + .requestCharged(response.requestChargedAsString()); + + // TODO: check why we have to do null check + if (response.responseMetadata() != null) { + builder.responseMetadata(response.responseMetadata()); + } + + if (response.sdkHttpResponse() != null) { + builder.sdkHttpResponse(response.sdkHttpResponse()); + } + + return builder.build(); + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java index 84d3c6ac5305..f929bc3fc8f4 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/UploadPartCopyRequestIterable.java @@ -65,10 +65,10 @@ public UploadPartCopyRequest next() { long partSize = Math.min(optimalPartSize, remainingBytes); String range = range(partSize); UploadPartCopyRequest uploadPartCopyRequest = - CopyRequestConversionUtils.toUploadPartCopyRequest(copyObjectRequest, - partNumber, - uploadId, - range); + RequestConversionUtils.toUploadPartCopyRequest(copyObjectRequest, + partNumber, + uploadId, + range); partNumber++; offset += partSize; remainingBytes -= partSize; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java new file mode 100644 index 000000000000..4ab4b22a0e79 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/GenericMultipartHelper.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.S3Request; +import software.amazon.awssdk.services.s3.model.S3Response; +import software.amazon.awssdk.utils.Logger; + +@SdkInternalApi +public final class GenericMultipartHelper { + private static final Logger log = Logger.loggerFor(GenericMultipartHelper.class); + /** + * The max number of parts on S3 side is 10,000 + */ + private static final long MAX_UPLOAD_PARTS = 10_000; + + private final S3AsyncClient s3AsyncClient; + private final Function abortMultipartUploadRequestConverter; + private final Function responseConverter; + + public GenericMultipartHelper(S3AsyncClient s3AsyncClient, + Function abortMultipartUploadRequestConverter, + Function responseConverter) { + this.s3AsyncClient = s3AsyncClient; + this.abortMultipartUploadRequestConverter = abortMultipartUploadRequestConverter; + this.responseConverter = responseConverter; + } + + public void handleException(CompletableFuture returnFuture, + Supplier message, + Throwable throwable) { + Throwable cause = throwable instanceof CompletionException ? 
throwable.getCause() : throwable; + + if (cause instanceof Error) { + returnFuture.completeExceptionally(cause); + } else { + SdkClientException exception = SdkClientException.create(message.get(), cause); + returnFuture.completeExceptionally(exception); + } + } + + public long calculateOptimalPartSizeFor(long contentLengthOfSource, long partSizeInBytes) { + double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; + + optimalPartSize = Math.ceil(optimalPartSize); + return (long) Math.max(optimalPartSize, partSizeInBytes); + } + + public int determinePartCount(long contentLength, long partSize) { + return (int) Math.ceil(contentLength / (double) partSize); + } + + public CompletableFuture completeMultipartUpload( + RequestT request, String uploadId, AtomicReferenceArray completedParts) { + log.debug(() -> String.format("Sending completeMultipartUploadRequest, uploadId: %s", + uploadId)); + CompletedPart[] parts = + IntStream.range(0, completedParts.length()) + .mapToObj(completedParts::get) + .toArray(CompletedPart[]::new); + CompleteMultipartUploadRequest completeMultipartUploadRequest = + CompleteMultipartUploadRequest.builder() + .bucket(request.getValueForField("Bucket", String.class).get()) + .key(request.getValueForField("Key", String.class).get()) + .uploadId(uploadId) + .multipartUpload(CompletedMultipartUpload.builder() + .parts(parts) + .build()) + .build(); + + return s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest); + } + + public BiFunction handleExceptionOrResponse( + RequestT request, + CompletableFuture returnFuture, + String uploadId) { + + return (completeMultipartUploadResponse, throwable) -> { + if (throwable != null) { + cleanUpParts(uploadId, abortMultipartUploadRequestConverter.apply(request)); + handleException(returnFuture, () -> "Failed to send multipart requests", + throwable); + } else { + returnFuture.complete(responseConverter.apply( + completeMultipartUploadResponse)); + } + + return null; + }; + } + + public void cleanUpParts(String uploadId, AbortMultipartUploadRequest.Builder abortMultipartUploadRequest) { + s3AsyncClient.abortMultipartUpload(abortMultipartUploadRequest.uploadId(uploadId).build()) + .exceptionally(throwable -> { + log.warn(() -> String.format("Failed to abort previous multipart upload " + + "(id: %s)" + + ". You may need to call " + + "S3AsyncClient#abortMultiPartUpload to " + + "free all storage consumed by" + + " all parts. ", + uploadId), throwable); + return null; + }); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java new file mode 100644 index 000000000000..f2895d65fcd2 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -0,0 +1,47 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.DelegatingS3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +// This is just a temporary class for testing +//TODO: change this +@SdkInternalApi +public class MultipartS3AsyncClient extends DelegatingS3AsyncClient { + private static final long DEFAULT_PART_SIZE_IN_BYTES = 8L * 1024 * 1024; + private static final long DEFAULT_THRESHOLD = 8L * 1024 * 1024; + + private static final long DEFAULT_MAX_MEMORY = DEFAULT_PART_SIZE_IN_BYTES * 2; + private final MultipartUploadHelper mpuHelper; + + public MultipartS3AsyncClient(S3AsyncClient delegate) { + super(delegate); + // TODO: pass a config object to the upload helper instead + mpuHelper = new MultipartUploadHelper(delegate, DEFAULT_PART_SIZE_IN_BYTES, DEFAULT_THRESHOLD, DEFAULT_MAX_MEMORY); + } + + @Override + public CompletableFuture putObject(PutObjectRequest putObjectRequest, AsyncRequestBody requestBody) { + return mpuHelper.uploadObject(putObjectRequest, requestBody); + } +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java new file mode 100644 index 000000000000..d043d88936c6 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -0,0 +1,274 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils.toAbortMultipartUploadRequest; + +import java.util.Collection; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.function.Function; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.internal.async.SplittingPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.internal.crt.RequestConversionUtils; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * An internal helper class that automatically uses multipart upload based on the size of the object. + */ +@SdkInternalApi +public final class MultipartUploadHelper { + private static final Logger log = Logger.loggerFor(MultipartUploadHelper.class); + + private final S3AsyncClient s3AsyncClient; + private final long partSizeInBytes; + private final GenericMultipartHelper genericMultipartHelper; + + private final long maxMemoryUsageInBytes; + private final long multipartUploadThresholdInBytes; + + public MultipartUploadHelper(S3AsyncClient s3AsyncClient, + long partSizeInBytes, + long multipartUploadThresholdInBytes, + long maxMemoryUsageInBytes) { + this.s3AsyncClient = s3AsyncClient; + this.partSizeInBytes = partSizeInBytes; + this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, + RequestConversionUtils::toAbortMultipartUploadRequest, + RequestConversionUtils::toPutObjectResponse); + this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; + this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; + } + + public CompletableFuture uploadObject(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody) { + Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); + + // TODO: support null content length. 
Should be trivial to support it now + if (contentLength == null) { + throw new IllegalArgumentException("Content-length is required"); + } + + CompletableFuture returnFuture = new CompletableFuture<>(); + + try { + if (contentLength > multipartUploadThresholdInBytes && contentLength > partSizeInBytes) { + log.debug(() -> "Starting the upload as multipart upload request"); + uploadInParts(putObjectRequest, contentLength, asyncRequestBody, returnFuture); + } else { + log.debug(() -> "Starting the upload as a single upload part request"); + uploadInOneChunk(putObjectRequest, asyncRequestBody, returnFuture); + } + + } catch (Throwable throwable) { + returnFuture.completeExceptionally(throwable); + } + + return returnFuture; + } + + private void uploadInParts(PutObjectRequest putObjectRequest, long contentLength, AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + + CreateMultipartUploadRequest request = RequestConversionUtils.toCreateMultipartUploadRequest(putObjectRequest); + CompletableFuture createMultipartUploadFuture = + s3AsyncClient.createMultipartUpload(request); + + // Ensure cancellations are forwarded to the createMultipartUploadFuture future + CompletableFutureUtils.forwardExceptionTo(returnFuture, createMultipartUploadFuture); + + createMultipartUploadFuture.whenComplete((createMultipartUploadResponse, throwable) -> { + if (throwable != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to initiate multipart upload", throwable); + } else { + log.debug(() -> "Initiated a new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); + doUploadInParts(Pair.of(putObjectRequest, asyncRequestBody), contentLength, returnFuture, + createMultipartUploadResponse.uploadId()); + } + }); + } + + private void doUploadInParts(Pair request, + long contentLength, + CompletableFuture returnFuture, + String uploadId) { + + long optimalPartSize = genericMultipartHelper.calculateOptimalPartSizeFor(contentLength, partSizeInBytes); + int partCount = genericMultipartHelper.determinePartCount(contentLength, optimalPartSize); + + log.debug(() -> String.format("Starting multipart upload with partCount: %d, optimalPartSize: %d", partCount, + optimalPartSize)); + + // The list of completed parts must be sorted + AtomicReferenceArray completedParts = new AtomicReferenceArray<>(partCount); + + PutObjectRequest putObjectRequest = request.left(); + + Collection> futures = new ConcurrentLinkedQueue<>(); + + MpuRequestContext mpuRequestContext = new MpuRequestContext(request, contentLength, optimalPartSize, uploadId); + + CompletableFuture requestsFuture = sendUploadPartRequests(mpuRequestContext, + completedParts, + returnFuture, + futures); + requestsFuture.whenComplete((r, t) -> { + if (t != null) { + genericMultipartHelper.handleException(returnFuture, () -> "Failed to send multipart upload requests", t); + genericMultipartHelper.cleanUpParts(uploadId, toAbortMultipartUploadRequest(putObjectRequest)); + cancelingOtherOngoingRequests(futures, t); + return; + } + CompletableFutureUtils.allOfExceptionForwarded(futures.toArray(new CompletableFuture[0])) + .thenCompose(ignore -> genericMultipartHelper.completeMultipartUpload(putObjectRequest, + uploadId, + completedParts)) + .handle(genericMultipartHelper.handleExceptionOrResponse(putObjectRequest, returnFuture, + uploadId)) + .exceptionally(throwable -> { + genericMultipartHelper.handleException(returnFuture, () -> "Unexpected exception occurred", + throwable); + return null; + }); + }); + } + + 
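/** Fails all in-flight part futures with the given throwable; the failure is forwarded to their uploadPart calls so those requests are cancelled as well. */ +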
private static void cancelingOtherOngoingRequests(Collection> futures, Throwable t) { + log.trace(() -> "cancelling other ongoing requests " + futures.size()); + futures.forEach(f -> f.completeExceptionally(t)); + } + + private CompletableFuture sendUploadPartRequests(MpuRequestContext mpuRequestContext, + AtomicReferenceArray completedParts, + CompletableFuture returnFuture, + Collection> futures) { + + CompletableFuture splittingPublisherFuture = new CompletableFuture<>(); + + AsyncRequestBody asyncRequestBody = mpuRequestContext.request.right(); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes(mpuRequestContext.partSize) + .maxMemoryUsageInBytes(maxMemoryUsageInBytes) + .resultFuture(splittingPublisherFuture) + .build(); + + splittingPublisher.map(new BodyToRequestConverter(mpuRequestContext.request.left(), mpuRequestContext.uploadId)) + .subscribe(pair -> sendIndividualUploadPartRequest(mpuRequestContext.uploadId, + completedParts, + futures, + pair, + splittingPublisherFuture)) + .exceptionally(throwable -> { + returnFuture.completeExceptionally(throwable); + return null; + }); + return splittingPublisherFuture; + } + + private void sendIndividualUploadPartRequest(String uploadId, + AtomicReferenceArray completedParts, + Collection> futures, + Pair requestPair, + CompletableFuture sendUploadPartRequestsFuture) { + UploadPartRequest uploadPartRequest = requestPair.left(); + Integer partNumber = uploadPartRequest.partNumber(); + log.debug(() -> "Sending uploadPartRequest: " + uploadPartRequest.partNumber() + " uploadId: " + uploadId + " " + + "contentLength " + requestPair.right().contentLength()); + + CompletableFuture uploadPartFuture = s3AsyncClient.uploadPart(uploadPartRequest, requestPair.right()); + + CompletableFuture convertFuture = + uploadPartFuture.thenApply(uploadPartResponse -> convertUploadPartResponse(completedParts, partNumber, + uploadPartResponse)); + futures.add(convertFuture); + CompletableFutureUtils.forwardExceptionTo(convertFuture, uploadPartFuture); + CompletableFutureUtils.forwardExceptionTo(uploadPartFuture, sendUploadPartRequestsFuture); + } + + private static CompletedPart convertUploadPartResponse(AtomicReferenceArray completedParts, + Integer partNumber, + UploadPartResponse uploadPartResponse) { + CompletedPart completedPart = RequestConversionUtils.toCompletedPart(uploadPartResponse, partNumber); + + completedParts.set(partNumber - 1, completedPart); + return completedPart; + } + + private void uploadInOneChunk(PutObjectRequest putObjectRequest, + AsyncRequestBody asyncRequestBody, + CompletableFuture returnFuture) { + CompletableFuture putObjectResponseCompletableFuture = s3AsyncClient.putObject(putObjectRequest, + asyncRequestBody); + CompletableFutureUtils.forwardExceptionTo(returnFuture, putObjectResponseCompletableFuture); + CompletableFutureUtils.forwardResultTo(putObjectResponseCompletableFuture, returnFuture); + } + + private static final class BodyToRequestConverter implements Function> { + private int partNumber = 1; + private final PutObjectRequest putObjectRequest; + private final String uploadId; + + BodyToRequestConverter(PutObjectRequest putObjectRequest, String uploadId) { + this.putObjectRequest = putObjectRequest; + this.uploadId = uploadId; + } + + @Override + public Pair apply(AsyncRequestBody asyncRequestBody) { + log.trace(() -> "Generating uploadPartRequest for partNumber " + partNumber); + UploadPartRequest uploadRequest = + 
RequestConversionUtils.toUploadPartRequest(putObjectRequest, + partNumber, + uploadId); + ++partNumber; + return Pair.of(uploadRequest, asyncRequestBody); + } + } + + private static final class MpuRequestContext { + private final Pair request; + private final long contentLength; + private final long partSize; + + private final String uploadId; + + private MpuRequestContext(Pair request, + long contentLength, + long partSize, + String uploadId) { + this.request = request; + this.contentLength = contentLength; + this.partSize = partSize; + this.uploadId = uploadId; + } + } + +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java index d3593570a6e6..ec78d7b15eb6 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java @@ -175,7 +175,7 @@ void multiPartCopy_onePartFailed_shouldFailOtherPartsAndAbort() { CompletableFuture future = copyHelper.copyObject(copyObjectRequest); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); @@ -213,7 +213,7 @@ void multiPartCopy_completeMultipartFailed_shouldFailAndAbort() { CompletableFuture future = copyHelper.copyObject(copyObjectRequest); - assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); verify(s3AsyncClient).abortMultipartUpload(argumentCaptor.capture()); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java index 94071ad115fd..104d5f6e045f 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java @@ -48,14 +48,14 @@ import software.amazon.awssdk.utils.Logger; class CopyRequestConversionUtilsTest { - private static final Logger log = Logger.loggerFor(CopyRequestConversionUtils.class); + private static final Logger log = Logger.loggerFor(RequestConversionUtils.class); private static final Random RNG = new Random(); @Test void toHeadObject_shouldCopyProperties() { CopyObjectRequest randomCopyObject = randomCopyObjectRequest(); - HeadObjectRequest convertedToHeadObject = CopyRequestConversionUtils.toHeadObjectRequest(randomCopyObject); + HeadObjectRequest convertedToHeadObject = RequestConversionUtils.toHeadObjectRequest(randomCopyObject); Set fieldsToIgnore = new HashSet<>(Arrays.asList("ExpectedBucketOwner", "RequestPayer", "Bucket", @@ -74,7 +74,7 @@ void toCompletedPart_shouldCopyProperties() { setFieldsToRandomValues(fromObject.sdkFields(), fromObject); CopyPartResult result = fromObject.build(); - CompletedPart 
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java
index d3593570a6e6..ec78d7b15eb6 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java
@@ -175,7 +175,7 @@ void multiPartCopy_onePartFailed_shouldFailOtherPartsAndAbort() {
 
         CompletableFuture<CopyObjectResponse> future = copyHelper.copyObject(copyObjectRequest);
 
-        assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception);
+        assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception);
 
         verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
 
@@ -213,7 +213,7 @@ void multiPartCopy_completeMultipartFailed_shouldFailAndAbort() {
 
         CompletableFuture<CopyObjectResponse> future = copyHelper.copyObject(copyObjectRequest);
 
-        assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart copy requests").hasRootCause(exception);
+        assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception);
 
         ArgumentCaptor<AbortMultipartUploadRequest> argumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class);
         verify(s3AsyncClient).abortMultipartUpload(argumentCaptor.capture());
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java
index 94071ad115fd..104d5f6e045f 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyRequestConversionUtilsTest.java
@@ -48,14 +48,14 @@
 import software.amazon.awssdk.utils.Logger;
 
 class CopyRequestConversionUtilsTest {
-    private static final Logger log = Logger.loggerFor(CopyRequestConversionUtils.class);
+    private static final Logger log = Logger.loggerFor(RequestConversionUtils.class);
     private static final Random RNG = new Random();
 
     @Test
     void toHeadObject_shouldCopyProperties() {
         CopyObjectRequest randomCopyObject = randomCopyObjectRequest();
-        HeadObjectRequest convertedToHeadObject = CopyRequestConversionUtils.toHeadObjectRequest(randomCopyObject);
+        HeadObjectRequest convertedToHeadObject = RequestConversionUtils.toHeadObjectRequest(randomCopyObject);
         Set<String> fieldsToIgnore = new HashSet<>(Arrays.asList("ExpectedBucketOwner",
                                                                  "RequestPayer",
                                                                  "Bucket",
@@ -74,7 +74,7 @@ void toCompletedPart_shouldCopyProperties() {
         setFieldsToRandomValues(fromObject.sdkFields(), fromObject);
         CopyPartResult result = fromObject.build();
 
-        CompletedPart convertedCompletedPart = CopyRequestConversionUtils.toCompletedPart(result, 1);
+        CompletedPart convertedCompletedPart = RequestConversionUtils.toCompletedPart(result, 1);
 
         verifyFieldsAreCopied(result, convertedCompletedPart, new HashSet<>(),
                               CopyPartResult.builder().sdkFields(),
                               CompletedPart.builder().sdkFields());
@@ -84,7 +84,7 @@ void toCompletedPart_shouldCopyProperties() {
     @Test
     void toCreateMultipartUploadRequest_shouldCopyProperties() {
         CopyObjectRequest randomCopyObject = randomCopyObjectRequest();
-        CreateMultipartUploadRequest convertedRequest = CopyRequestConversionUtils.toCreateMultipartUploadRequest(randomCopyObject);
+        CreateMultipartUploadRequest convertedRequest = RequestConversionUtils.toCreateMultipartUploadRequest(randomCopyObject);
         Set<String> fieldsToIgnore = new HashSet<>();
         verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore,
                               CopyObjectRequest.builder().sdkFields(),
@@ -100,7 +100,7 @@ void toCopyObjectResponse_shouldCopyProperties() {
         responseBuilder.responseMetadata(s3ResponseMetadata).sdkHttpResponse(sdkHttpFullResponse);
         CompleteMultipartUploadResponse result = responseBuilder.build();
 
-        CopyObjectResponse convertedRequest = CopyRequestConversionUtils.toCopyObjectResponse(result);
+        CopyObjectResponse convertedRequest = RequestConversionUtils.toCopyObjectResponse(result);
 
         Set<String> fieldsToIgnore = new HashSet<>();
         verifyFieldsAreCopied(result, convertedRequest, fieldsToIgnore,
                               CompleteMultipartUploadResponse.builder().sdkFields(),
@@ -113,21 +113,20 @@ void toCopyObjectResponse_shouldCopyProperties() {
     @Test
     void toAbortMultipartUploadRequest_shouldCopyProperties() {
         CopyObjectRequest randomCopyObject = randomCopyObjectRequest();
-        AbortMultipartUploadRequest convertedRequest = CopyRequestConversionUtils.toAbortMultipartUploadRequest(randomCopyObject,
-                                                                                                                "id");
+        AbortMultipartUploadRequest convertedRequest = RequestConversionUtils.toAbortMultipartUploadRequest(randomCopyObject).build();
         Set<String> fieldsToIgnore = new HashSet<>();
         verifyFieldsAreCopied(randomCopyObject, convertedRequest, fieldsToIgnore,
                               CopyObjectRequest.builder().sdkFields(),
                               AbortMultipartUploadRequest.builder().sdkFields());
-        assertThat(convertedRequest.uploadId()).isEqualTo("id");
+        //assertThat(convertedRequest.uploadId()).isEqualTo("id");
     }
 
     @Test
     void toUploadPartCopyRequest_shouldCopyProperties() {
         CopyObjectRequest randomCopyObject = randomCopyObjectRequest();
-        UploadPartCopyRequest convertedObject = CopyRequestConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id",
-                                                                                                   "bytes=0-1024");
+        UploadPartCopyRequest convertedObject = RequestConversionUtils.toUploadPartCopyRequest(randomCopyObject, 1, "id",
+                                                                                               "bytes=0-1024");
         Set<String> fieldsToIgnore = new HashSet<>(Collections.singletonList("CopySource"));
         verifyFieldsAreCopied(randomCopyObject, convertedObject, fieldsToIgnore,
                               CopyObjectRequest.builder().sdkFields(),
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java
new file mode 100644
index 000000000000..435d5b406189
--- /dev/null
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MpuTestUtils.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.services.s3.internal.multipart;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+import java.util.concurrent.CompletableFuture;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.HeadObjectRequest;
+import software.amazon.awssdk.services.s3.model.HeadObjectResponse;
+
+public final class MpuTestUtils {
+
+    private MpuTestUtils() {
+    }
+
+    public static void stubSuccessfulHeadObjectCall(long contentLength, S3AsyncClient s3AsyncClient) {
+        CompletableFuture<HeadObjectResponse> headFuture =
+            CompletableFuture.completedFuture(HeadObjectResponse.builder()
+                                                                .contentLength(contentLength)
+                                                                .build());
+
+        when(s3AsyncClient.headObject(any(HeadObjectRequest.class)))
+            .thenReturn(headFuture);
+    }
+
+    public static void stubSuccessfulCreateMultipartCall(String mpuId, S3AsyncClient s3AsyncClient) {
+        CompletableFuture<CreateMultipartUploadResponse> createMultipartUploadFuture =
+            CompletableFuture.completedFuture(CreateMultipartUploadResponse.builder()
+                                                                           .uploadId(mpuId)
+                                                                           .build());
+
+        when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
+            .thenReturn(createMultipartUploadFuture);
+    }
+
+    public static void stubSuccessfulCompleteMultipartCall(String bucket, String key, S3AsyncClient s3AsyncClient) {
+        CompletableFuture<CompleteMultipartUploadResponse> completeMultipartUploadFuture =
+            CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder()
+                                                                             .bucket(bucket)
+                                                                             .key(key)
+                                                                             .build());
+
+        when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
+            .thenReturn(completeMultipartUploadFuture);
+    }
+}
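These helpers are plain Mockito stubs. A test that exercises the multipart path would typically wire them roughly as follows (the upload ID, bucket, and key values are placeholders; uploadPart stubbing stays test-specific, see stubSuccessfulUploadPartCalls() in the test class added below):

S3AsyncClient s3AsyncClient = Mockito.mock(S3AsyncClient.class);
MpuTestUtils.stubSuccessfulCreateMultipartCall("upload-id", s3AsyncClient);
MpuTestUtils.stubSuccessfulCompleteMultipartCall("bucket", "key", s3AsyncClient);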
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java
new file mode 100644
index 000000000000..0db53c246e03
--- /dev/null
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelperTest.java
@@ -0,0 +1,250 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.services.s3.internal.multipart;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCompleteMultipartCall;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.OngoingStubbing;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.PutObjectResponse;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
+import software.amazon.awssdk.testutils.RandomTempFile;
+import software.amazon.awssdk.utils.CompletableFutureUtils;
+
+public class MultipartUploadHelperTest {
+
+    private static final String BUCKET = "bucket";
+    private static final String KEY = "key";
+    private static final long PART_SIZE = 8 * 1024;
+
+    // Should contain four parts: [8KB, 8KB, 8KB, 1KB]
+    private static final long MPU_CONTENT_SIZE = 25 * 1024;
+    private static final long THRESHOLD = 10 * 1024;
+    private static final String UPLOAD_ID = "1234";
+
+    private static RandomTempFile testFile;
+    private MultipartUploadHelper uploadHelper;
+    private S3AsyncClient s3AsyncClient;
+
+    @BeforeAll
+    public static void beforeAll() throws IOException {
+        testFile = new RandomTempFile("testfile.dat", MPU_CONTENT_SIZE);
+    }
+
+    @AfterAll
+    public static void afterAll() throws Exception {
+        testFile.delete();
+    }
+
+    @BeforeEach
+    public void beforeEach() {
+        s3AsyncClient = Mockito.mock(S3AsyncClient.class);
+        uploadHelper = new MultipartUploadHelper(s3AsyncClient, PART_SIZE, THRESHOLD, PART_SIZE * 2);
+    }
+
+    @ParameterizedTest
+    @ValueSource(longs = {THRESHOLD, PART_SIZE, THRESHOLD - 1, PART_SIZE - 1})
+    public void uploadObject_doesNotExceedThresholdAndPartSize_shouldUploadInOneChunk(long contentLength) {
+        PutObjectRequest putObjectRequest = putObjectRequest(contentLength);
+        AsyncRequestBody asyncRequestBody = Mockito.mock(AsyncRequestBody.class);
+
+        CompletableFuture<PutObjectResponse> completedFuture =
+            CompletableFuture.completedFuture(PutObjectResponse.builder().build());
+        when(s3AsyncClient.putObject(putObjectRequest, asyncRequestBody)).thenReturn(completedFuture);
+        uploadHelper.uploadObject(putObjectRequest, asyncRequestBody).join();
+        Mockito.verify(s3AsyncClient).putObject(putObjectRequest, asyncRequestBody);
+    }
+
+    @Test
+    public void uploadObject_contentLengthExceedThresholdAndPartSize_shouldUseMPU() {
+        PutObjectRequest putObjectRequest = putObjectRequest(null);
+
+        MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient);
+        stubSuccessfulUploadPartCalls();
+        stubSuccessfulCompleteMultipartCall(BUCKET, KEY, s3AsyncClient);
+
+        uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)).join();
+        ArgumentCaptor<UploadPartRequest> requestArgumentCaptor = ArgumentCaptor.forClass(UploadPartRequest.class);
+        ArgumentCaptor<AsyncRequestBody> requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
+        verify(s3AsyncClient, times(4)).uploadPart(requestArgumentCaptor.capture(),
+                                                   requestBodyArgumentCaptor.capture());
+
+        List<UploadPartRequest> actualRequests = requestArgumentCaptor.getAllValues();
+        List<AsyncRequestBody> actualRequestBodies = requestBodyArgumentCaptor.getAllValues();
+        assertThat(actualRequestBodies).hasSize(4);
+        assertThat(actualRequests).hasSize(4);
+
+        for (int i = 0; i < actualRequests.size(); i++) {
+            UploadPartRequest request = actualRequests.get(i);
+            AsyncRequestBody requestBody = actualRequestBodies.get(i);
+            assertThat(request.partNumber()).isEqualTo(i + 1);
+            assertThat(request.bucket()).isEqualTo(BUCKET);
+            assertThat(request.key()).isEqualTo(KEY);
+
+            if (i == actualRequests.size() - 1) {
+                assertThat(requestBody.contentLength()).hasValue(1024L);
+            } else {
+                assertThat(requestBody.contentLength()).hasValue(PART_SIZE);
+            }
+        }
+    }
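The expected call counts and body sizes in the test above follow directly from the constants: a 25 KiB object split at an 8 KiB part size yields three full parts plus a 1 KiB remainder. A throwaway calculation, for reference:

long contentLength = 25 * 1024;                            // MPU_CONTENT_SIZE
long partSize = 8 * 1024;                                  // PART_SIZE
long fullParts = contentLength / partSize;                 // 3
long lastPartSize = contentLength % partSize;              // 1024
long totalParts = fullParts + (lastPartSize > 0 ? 1 : 0);  // 4 -> times(4), last body is 1024 bytes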
+ */ + @Test + void mpu_onePartFailed_shouldFailOtherPartsAndAbort() { + PutObjectRequest putObjectRequest = putObjectRequest(MPU_CONTENT_SIZE); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + CompletableFuture ongoingRequest = new CompletableFuture<>(); + + SdkClientException exception = SdkClientException.create("request failed"); + + OngoingStubbing> ongoingStubbing = + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn(ongoingRequest); + + stubFailedUploadPartCalls(ongoingStubbing, exception); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, + AsyncRequestBody.fromFile(testFile)); + + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart upload requests").hasRootCause(exception); + + verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); + verify(s3AsyncClient).abortMultipartUpload(argumentCaptor.capture()); + AbortMultipartUploadRequest actualRequest = argumentCaptor.getValue(); + assertThat(actualRequest.uploadId()).isEqualTo(UPLOAD_ID); + + assertThat(ongoingRequest).isCompletedExceptionally(); + } + + @Test + void upload_cancelResponseFuture_shouldPropagate() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + CompletableFuture createMultipartFuture = new CompletableFuture<>(); + + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createMultipartFuture); + + CompletableFuture future = + uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + + future.cancel(true); + + assertThat(createMultipartFuture).isCancelled(); + } + + @Test + public void uploadObject_completeMultipartFailed_shouldFailAndAbort() { + PutObjectRequest putObjectRequest = putObjectRequest(null); + + MpuTestUtils.stubSuccessfulCreateMultipartCall(UPLOAD_ID, s3AsyncClient); + stubSuccessfulUploadPartCalls(); + + SdkClientException exception = SdkClientException.create("CompleteMultipartUpload failed"); + + CompletableFuture completeMultipartUploadFuture = + CompletableFutureUtils.failedFuture(exception); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeMultipartUploadFuture); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(AbortMultipartUploadResponse.builder().build())); + + CompletableFuture future = uploadHelper.uploadObject(putObjectRequest, AsyncRequestBody.fromFile(testFile)); + assertThatThrownBy(future::join).hasMessageContaining("Failed to send multipart requests").hasRootCause(exception); + } + + private static PutObjectRequest putObjectRequest(Long contentLength) { + return PutObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .contentLength(contentLength) + .build(); + } + + private void stubSuccessfulUploadPartCalls() { + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenAnswer(new Answer>() { + int numberOfCalls = 0; + + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = 
invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); + + numberOfCalls++; + return CompletableFuture.completedFuture(UploadPartResponse.builder() + .checksumCRC32("crc" + numberOfCalls) + .build()); + } + }); + } + + private OngoingStubbing> stubFailedUploadPartCalls(OngoingStubbing> stubbing, Exception exception) { + return stubbing.thenAnswer(new Answer>() { + + @Override + public CompletableFuture answer(InvocationOnMock invocationOnMock) { + AsyncRequestBody AsyncRequestBody = invocationOnMock.getArgument(1); + // Draining the request body + AsyncRequestBody.subscribe(b -> {}); + + return CompletableFutureUtils.failedFuture(exception); + } + }); + } + +} diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java index 15bba8a0aaf1..11d029ee96c2 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/SimplePublisher.java @@ -382,7 +382,7 @@ public void request(long n) { @Override public void cancel() { - log.trace(() -> "Received cancel()"); + log.trace(() -> "Received cancel() from " + subscriber); // Create exception here instead of in supplier to preserve a more-useful stack trace. highPriorityQueue.add(new CancelQueueEntry<>());
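The SimplePublisher change above only adjusts a trace message, but including the subscriber makes it possible to tell which downstream consumer cancelled when several part bodies are in flight. A minimal, illustrative use of the publisher API (not taken from this diff; assumes org.reactivestreams.Subscriber/Subscription on the classpath):

SimplePublisher<String> publisher = new SimplePublisher<>();
publisher.subscribe(new Subscriber<String>() {
    @Override
    public void onSubscribe(Subscription s) {
        s.cancel(); // triggers the "Received cancel() from <subscriber>" trace
    }

    @Override
    public void onNext(String value) {
    }

    @Override
    public void onError(Throwable t) {
    }

    @Override
    public void onComplete() {
    }
});
publisher.send("value");   // the returned futures indicate whether delivery succeeded after the cancel
publisher.complete();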