From e5427fec1a0d6f91d139ad6c41081a02753967f7 Mon Sep 17 00:00:00 2001 From: Zoe Wang <33073555+zoewangg@users.noreply.github.com> Date: Mon, 10 Jul 2023 16:56:07 -0700 Subject: [PATCH] Fix null content length in SplittingPublisher --- .../internal/async/SplittingPublisher.java | 103 +++++++++++++----- .../async/SplittingPublisherTest.java | 65 ++++++++++- .../multipart/MultipartUploadHelper.java | 3 +- 3 files changed, 138 insertions(+), 33 deletions(-) diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 095d69ac5e7d..8152e13980a6 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -33,8 +33,11 @@ /** * Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the * original data. + * + *

If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized. + * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length. + * * // TODO: create a default method in AsyncRequestBody for this - * // TODO: fix the case where content length is null */ @SdkInternalApi public class SplittingPublisher implements SdkPublisher { @@ -86,6 +89,7 @@ private class SplittingSubscriber implements Subscriber { * A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call. */ private int byteBufferSizeHint; + private volatile boolean upstreamComplete; SplittingSubscriber(Long upstreamSize) { this.upstreamSize = upstreamSize; @@ -94,36 +98,49 @@ private class SplittingSubscriber implements Subscriber { @Override public void onSubscribe(Subscription s) { this.upstreamSubscription = s; - this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get()); - sendCurrentBody(); + this.currentBody = + initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize), + chunkNumber.get()); // We need to request subscription *after* we set currentBody because onNext could be invoked right away. 
upstreamSubscription.request(1); } + private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) { + DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber); + if (contentLengthKnown) { + sendCurrentBody(body); + } + return body; + } + @Override public void onNext(ByteBuffer byteBuffer) { hasOpenUpstreamDemand.set(false); byteBufferSizeHint = byteBuffer.remaining(); while (true) { - int amountRemainingInPart = amountRemainingInPart(); - int finalAmountRemainingInPart = amountRemainingInPart; - if (amountRemainingInPart == 0) { - currentBody.complete(); - int currentChunk = chunkNumber.incrementAndGet(); - Long partSize = calculateChunkSize(); - currentBody = new DownstreamBody(partSize, currentChunk); - sendCurrentBody(); + int amountRemainingInChunk = amountRemainingInChunk(); + + // If we have fulfilled this chunk, + // we should create a new DownstreamBody if needed + if (amountRemainingInChunk == 0) { + completeCurrentBody(); + + if (shouldCreateNewDownstreamRequestBody(byteBuffer)) { + int currentChunk = chunkNumber.incrementAndGet(); + long chunkSize = calculateChunkSize(totalDataRemaining()); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + } } - amountRemainingInPart = amountRemainingInPart(); - if (amountRemainingInPart >= byteBuffer.remaining()) { + amountRemainingInChunk = amountRemainingInChunk(); + if (amountRemainingInChunk >= byteBuffer.remaining()) { currentBody.send(byteBuffer.duplicate()); break; } ByteBuffer firstHalf = byteBuffer.duplicate(); - int newLimit = firstHalf.position() + amountRemainingInPart; + int newLimit = firstHalf.position() + amountRemainingInChunk; firstHalf.limit(newLimit); byteBuffer.position(newLimit); currentBody.send(firstHalf); @@ -132,15 +149,32 @@ public void onNext(ByteBuffer byteBuffer) { maybeRequestMoreUpstreamData(); } - private int amountRemainingInPart() { - return 
Math.toIntExact(currentBody.totalLength - currentBody.transferredLength); + + /** + * If content length is known, we should create new DownstreamRequestBody if there's remaining data. + * If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet. + */ + private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) { + return !upstreamComplete || byteBuffer.remaining() > 0; + } + + private int amountRemainingInChunk() { + return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength); + } + + private void completeCurrentBody() { + currentBody.complete(); + if (upstreamSize == null) { + sendCurrentBody(currentBody); + } } @Override public void onComplete() { + upstreamComplete = true; log.trace(() -> "Received onComplete()"); + completeCurrentBody(); downstreamPublisher.complete().thenRun(() -> future.complete(null)); - currentBody.complete(); } @Override @@ -148,17 +182,17 @@ public void onError(Throwable t) { currentBody.error(t); } - private void sendCurrentBody() { - downstreamPublisher.send(currentBody).exceptionally(t -> { + private void sendCurrentBody(AsyncRequestBody body) { + downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); return null; }); } - private Long calculateChunkSize() { - Long dataRemaining = dataRemaining(); + private long calculateChunkSize(Long dataRemaining) { + // Use default chunk size if the content length is unknown if (dataRemaining == null) { - return null; + return chunkSizeInBytes; } return Math.min(chunkSizeInBytes, dataRemaining); @@ -177,27 +211,34 @@ private boolean shouldRequestMoreData(long buffered) { return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes; } - private Long dataRemaining() { + private Long totalDataRemaining() { if (upstreamSize == null) { return null; } return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); } - private class DownstreamBody implements AsyncRequestBody { + 
private final class DownstreamBody implements AsyncRequestBody { + + /** + * The maximum length of the content this AsyncRequestBody can hold. + * If the upstream content length is known, this is the same as totalLength + */ + private final long maxLength; private final Long totalLength; private final SimplePublisher delegate = new SimplePublisher<>(); private final int chunkNumber; private volatile long transferredLength = 0; - private DownstreamBody(Long totalLength, int chunkNumber) { - this.totalLength = totalLength; + private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) { + this.totalLength = contentLengthKnown ? maxLength : null; + this.maxLength = maxLength; this.chunkNumber = chunkNumber; } @Override public Optional contentLength() { - return Optional.ofNullable(totalLength); + return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength); } public void send(ByteBuffer data) { @@ -214,8 +255,12 @@ public void send(ByteBuffer data) { } public void complete() { - log.debug(() -> "Received complete() for chunk number: " + chunkNumber); - delegate.complete(); + log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength); + delegate.complete().whenComplete((r, t) -> { + if (t != null) { + error(t); + } + }); } public void error(Throwable error) { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index df318190b92d..45938ea684c8 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; 
+import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.io.IOException; @@ -28,6 +29,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -44,6 +46,8 @@ public class SplittingPublisherTest { private static final int CHUNK_SIZE = 5; private static final int CONTENT_SIZE = 101; + private static final byte[] CONTENT = + RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset()); private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE); @@ -123,9 +127,59 @@ void cancelFuture_shouldCancelUpstream() throws IOException { assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1); } - private static final class TestAsyncRequestBody implements AsyncRequestBody { - private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset()); - private boolean cancelled; + @Test + void contentLengthNotPresent_shouldHandle() throws Exception { + CompletableFuture future = new CompletableFuture<>(); + TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.empty(); + } + }; + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .resultFuture(future) + .asyncRequestBody(asyncRequestBody) + .chunkSizeInBytes((long) CHUNK_SIZE) + .maxMemoryUsageInBytes(10L) + .build(); + + + List> futures = new ArrayList<>(); + AtomicInteger index = new AtomicInteger(0); + + splittingPublisher.subscribe(requestBody -> { + CompletableFuture baosFuture = new CompletableFuture<>(); + BaosSubscriber subscriber = new BaosSubscriber(baosFuture); + futures.add(baosFuture); + requestBody.subscribe(subscriber); + 
if (index.incrementAndGet() == NUM_OF_CHUNK) { + assertThat(requestBody.contentLength()).hasValue(1L); + } else { + assertThat(requestBody.contentLength()).hasValue((long) CHUNK_SIZE); + } + }).get(5, TimeUnit.SECONDS); + assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK); + + for (int i = 0; i < futures.size(); i++) { + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(CONTENT)) { + byte[] expected; + if (i == futures.size() - 1) { + expected = new byte[1]; + } else { + expected = new byte[CHUNK_SIZE]; + } + inputStream.skip(i * CHUNK_SIZE); + inputStream.read(expected); + byte[] actualBytes = futures.get(i).join(); + assertThat(actualBytes).isEqualTo(expected); + }; + } + + } + + private static class TestAsyncRequestBody implements AsyncRequestBody { + private volatile boolean cancelled; + private volatile boolean isDone; @Override public Optional contentLength() { @@ -137,8 +191,13 @@ public void subscribe(Subscriber s) { s.onSubscribe(new Subscription() { @Override public void request(long n) { + if (isDone) { + return; + } + isDone = true; s.onNext(ByteBuffer.wrap(CONTENT)); s.onComplete(); + } @Override diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index d043d88936c6..7502dd1a9743 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -70,7 +70,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj AsyncRequestBody asyncRequestBody) { Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength); - // TODO: support null content length. Should be trivial to support it now + // TODO: support null content length. 
Need to determine whether to use a single-object upload or + // a multipart upload (MPU) based on the first AsyncRequestBody if (contentLength == null) { throw new IllegalArgumentException("Content-length is required"); }