From e5427fec1a0d6f91d139ad6c41081a02753967f7 Mon Sep 17 00:00:00 2001
From: Zoe Wang <33073555+zoewangg@users.noreply.github.com>
Date: Mon, 10 Jul 2023 16:56:07 -0700
Subject: [PATCH] Fix null content length in SplittingPublisher
---
.../internal/async/SplittingPublisher.java | 103 +++++++++++++-----
.../async/SplittingPublisherTest.java | 65 ++++++++++-
.../multipart/MultipartUploadHelper.java | 3 +-
3 files changed, 138 insertions(+), 33 deletions(-)
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java
index 095d69ac5e7d..8152e13980a6 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java
@@ -33,8 +33,11 @@
/**
* Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the
* original data.
+ *
+ * <p>If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized.
+ * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length.
+ *
* // TODO: create a default method in AsyncRequestBody for this
- * // TODO: fix the case where content length is null
*/
@SdkInternalApi
public class SplittingPublisher implements SdkPublisher {
@@ -86,6 +89,7 @@ private class SplittingSubscriber implements Subscriber {
* A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call.
*/
private int byteBufferSizeHint;
+ private volatile boolean upstreamComplete;
SplittingSubscriber(Long upstreamSize) {
this.upstreamSize = upstreamSize;
@@ -94,36 +98,49 @@ private class SplittingSubscriber implements Subscriber {
@Override
public void onSubscribe(Subscription s) {
this.upstreamSubscription = s;
- this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get());
- sendCurrentBody();
+ this.currentBody =
+ initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize),
+ chunkNumber.get());
// We need to request subscription *after* we set currentBody because onNext could be invoked right away.
upstreamSubscription.request(1);
}
+ private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) {
+ DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber);
+ if (contentLengthKnown) {
+ sendCurrentBody(body);
+ }
+ return body;
+ }
+
@Override
public void onNext(ByteBuffer byteBuffer) {
hasOpenUpstreamDemand.set(false);
byteBufferSizeHint = byteBuffer.remaining();
while (true) {
- int amountRemainingInPart = amountRemainingInPart();
- int finalAmountRemainingInPart = amountRemainingInPart;
- if (amountRemainingInPart == 0) {
- currentBody.complete();
- int currentChunk = chunkNumber.incrementAndGet();
- Long partSize = calculateChunkSize();
- currentBody = new DownstreamBody(partSize, currentChunk);
- sendCurrentBody();
+ int amountRemainingInChunk = amountRemainingInChunk();
+
+ // If we have fulfilled this chunk,
+ // we should create a new DownstreamBody if needed
+ if (amountRemainingInChunk == 0) {
+ completeCurrentBody();
+
+ if (shouldCreateNewDownstreamRequestBody(byteBuffer)) {
+ int currentChunk = chunkNumber.incrementAndGet();
+ long chunkSize = calculateChunkSize(totalDataRemaining());
+ currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk);
+ }
}
- amountRemainingInPart = amountRemainingInPart();
- if (amountRemainingInPart >= byteBuffer.remaining()) {
+ amountRemainingInChunk = amountRemainingInChunk();
+ if (amountRemainingInChunk >= byteBuffer.remaining()) {
currentBody.send(byteBuffer.duplicate());
break;
}
ByteBuffer firstHalf = byteBuffer.duplicate();
- int newLimit = firstHalf.position() + amountRemainingInPart;
+ int newLimit = firstHalf.position() + amountRemainingInChunk;
firstHalf.limit(newLimit);
byteBuffer.position(newLimit);
currentBody.send(firstHalf);
@@ -132,15 +149,32 @@ public void onNext(ByteBuffer byteBuffer) {
maybeRequestMoreUpstreamData();
}
- private int amountRemainingInPart() {
- return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength);
+
+ /**
+ * If content length is known, we should create new DownstreamRequestBody if there's remaining data.
+ * If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet.
+ */
+ private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) {
+ return !upstreamComplete || byteBuffer.remaining() > 0;
+ }
+
+ private int amountRemainingInChunk() {
+ return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength);
+ }
+
+ private void completeCurrentBody() {
+ currentBody.complete();
+ if (upstreamSize == null) {
+ sendCurrentBody(currentBody);
+ }
}
@Override
public void onComplete() {
+ upstreamComplete = true;
log.trace(() -> "Received onComplete()");
+ completeCurrentBody();
downstreamPublisher.complete().thenRun(() -> future.complete(null));
- currentBody.complete();
}
@Override
@@ -148,17 +182,17 @@ public void onError(Throwable t) {
currentBody.error(t);
}
- private void sendCurrentBody() {
- downstreamPublisher.send(currentBody).exceptionally(t -> {
+ private void sendCurrentBody(AsyncRequestBody body) {
+ downstreamPublisher.send(body).exceptionally(t -> {
downstreamPublisher.error(t);
return null;
});
}
- private Long calculateChunkSize() {
- Long dataRemaining = dataRemaining();
+ private long calculateChunkSize(Long dataRemaining) {
+ // Use default chunk size if the content length is unknown
if (dataRemaining == null) {
- return null;
+ return chunkSizeInBytes;
}
return Math.min(chunkSizeInBytes, dataRemaining);
@@ -177,27 +211,34 @@ private boolean shouldRequestMoreData(long buffered) {
return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes;
}
- private Long dataRemaining() {
+ private Long totalDataRemaining() {
if (upstreamSize == null) {
return null;
}
return upstreamSize - (chunkNumber.get() * chunkSizeInBytes);
}
- private class DownstreamBody implements AsyncRequestBody {
+ private final class DownstreamBody implements AsyncRequestBody {
+
+ /**
+ * The maximum length of the content this AsyncRequestBody can hold.
+ * If the upstream content length is known, this is the same as totalLength
+ */
+ private final long maxLength;
private final Long totalLength;
private final SimplePublisher delegate = new SimplePublisher<>();
private final int chunkNumber;
private volatile long transferredLength = 0;
- private DownstreamBody(Long totalLength, int chunkNumber) {
- this.totalLength = totalLength;
+ private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) {
+ this.totalLength = contentLengthKnown ? maxLength : null;
+ this.maxLength = maxLength;
this.chunkNumber = chunkNumber;
}
@Override
public Optional contentLength() {
- return Optional.ofNullable(totalLength);
+ return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength);
}
public void send(ByteBuffer data) {
@@ -214,8 +255,12 @@ public void send(ByteBuffer data) {
}
public void complete() {
- log.debug(() -> "Received complete() for chunk number: " + chunkNumber);
- delegate.complete();
+ log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength);
+ delegate.complete().whenComplete((r, t) -> {
+ if (t != null) {
+ error(t);
+ }
+ });
}
public void error(Throwable error) {
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java
index df318190b92d..45938ea684c8 100644
--- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java
@@ -18,6 +18,7 @@
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
+import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
@@ -28,6 +29,7 @@
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
@@ -44,6 +46,8 @@ public class SplittingPublisherTest {
private static final int CHUNK_SIZE = 5;
private static final int CONTENT_SIZE = 101;
+ private static final byte[] CONTENT =
+ RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset());
private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE);
@@ -123,9 +127,59 @@ void cancelFuture_shouldCancelUpstream() throws IOException {
assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1);
}
- private static final class TestAsyncRequestBody implements AsyncRequestBody {
- private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset());
- private boolean cancelled;
+ @Test
+ void contentLengthNotPresent_shouldHandle() throws Exception {
+ CompletableFuture future = new CompletableFuture<>();
+ TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() {
+ @Override
+ public Optional contentLength() {
+ return Optional.empty();
+ }
+ };
+ SplittingPublisher splittingPublisher = SplittingPublisher.builder()
+ .resultFuture(future)
+ .asyncRequestBody(asyncRequestBody)
+ .chunkSizeInBytes((long) CHUNK_SIZE)
+ .maxMemoryUsageInBytes(10L)
+ .build();
+
+
+ List<CompletableFuture<byte[]>> futures = new ArrayList<>();
+ AtomicInteger index = new AtomicInteger(0);
+
+ splittingPublisher.subscribe(requestBody -> {
+ CompletableFuture baosFuture = new CompletableFuture<>();
+ BaosSubscriber subscriber = new BaosSubscriber(baosFuture);
+ futures.add(baosFuture);
+ requestBody.subscribe(subscriber);
+ if (index.incrementAndGet() == NUM_OF_CHUNK) {
+ assertThat(requestBody.contentLength()).hasValue(1L);
+ } else {
+ assertThat(requestBody.contentLength()).hasValue((long) CHUNK_SIZE);
+ }
+ }).get(5, TimeUnit.SECONDS);
+ assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK);
+
+ for (int i = 0; i < futures.size(); i++) {
+ try (ByteArrayInputStream inputStream = new ByteArrayInputStream(CONTENT)) {
+ byte[] expected;
+ if (i == futures.size() - 1) {
+ expected = new byte[1];
+ } else {
+ expected = new byte[CHUNK_SIZE];
+ }
+ inputStream.skip(i * CHUNK_SIZE);
+ inputStream.read(expected);
+ byte[] actualBytes = futures.get(i).join();
+ assertThat(actualBytes).isEqualTo(expected);
+ };
+ }
+
+ }
+
+ private static class TestAsyncRequestBody implements AsyncRequestBody {
+ private volatile boolean cancelled;
+ private volatile boolean isDone;
@Override
public Optional contentLength() {
@@ -137,8 +191,13 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
s.onSubscribe(new Subscription() {
@Override
public void request(long n) {
+ if (isDone) {
+ return;
+ }
+ isDone = true;
s.onNext(ByteBuffer.wrap(CONTENT));
s.onComplete();
+
}
@Override
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
index d043d88936c6..7502dd1a9743 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java
@@ -70,7 +70,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj
AsyncRequestBody asyncRequestBody) {
Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength);
- // TODO: support null content length. Should be trivial to support it now
+ // TODO: support null content length. Need to determine whether to use single object or MPU based on the first
+ // AsyncRequestBody
if (contentLength == null) {
throw new IllegalArgumentException("Content-length is required");
}