Skip to content

Support null content length in SplittingPublisher #4173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@
/**
* Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the
* original data.
*
* <p>If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized.
* Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length.
*
* // TODO: create a default method in AsyncRequestBody for this
* // TODO: fix the case where content length is null
*/
@SdkInternalApi
public class SplittingPublisher implements SdkPublisher<AsyncRequestBody> {
Expand Down Expand Up @@ -86,6 +89,7 @@ private class SplittingSubscriber implements Subscriber<ByteBuffer> {
* A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call.
*/
private int byteBufferSizeHint;
private volatile boolean upstreamComplete;

SplittingSubscriber(Long upstreamSize) {
this.upstreamSize = upstreamSize;
Expand All @@ -94,36 +98,49 @@ private class SplittingSubscriber implements Subscriber<ByteBuffer> {
@Override
public void onSubscribe(Subscription s) {
this.upstreamSubscription = s;
this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get());
sendCurrentBody();
this.currentBody =
initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize),
chunkNumber.get());
// We need to request subscription *after* we set currentBody because onNext could be invoked right away.
upstreamSubscription.request(1);
}

private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) {
DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber);
if (contentLengthKnown) {
sendCurrentBody(body);
}
return body;
}

@Override
public void onNext(ByteBuffer byteBuffer) {
hasOpenUpstreamDemand.set(false);
byteBufferSizeHint = byteBuffer.remaining();

while (true) {
int amountRemainingInPart = amountRemainingInPart();
int finalAmountRemainingInPart = amountRemainingInPart;
if (amountRemainingInPart == 0) {
currentBody.complete();
int currentChunk = chunkNumber.incrementAndGet();
Long partSize = calculateChunkSize();
currentBody = new DownstreamBody(partSize, currentChunk);
sendCurrentBody();
int amountRemainingInChunk = amountRemainingInChunk();

// If we have fulfilled this chunk,
// we should create a new DownstreamBody if needed
if (amountRemainingInChunk == 0) {
completeCurrentBody();

if (shouldCreateNewDownstreamRequestBody(byteBuffer)) {
int currentChunk = chunkNumber.incrementAndGet();
long chunkSize = calculateChunkSize(totalDataRemaining());
currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk);
}
}

amountRemainingInPart = amountRemainingInPart();
if (amountRemainingInPart >= byteBuffer.remaining()) {
amountRemainingInChunk = amountRemainingInChunk();
if (amountRemainingInChunk >= byteBuffer.remaining()) {
currentBody.send(byteBuffer.duplicate());
break;
}

ByteBuffer firstHalf = byteBuffer.duplicate();
int newLimit = firstHalf.position() + amountRemainingInPart;
int newLimit = firstHalf.position() + amountRemainingInChunk;
firstHalf.limit(newLimit);
byteBuffer.position(newLimit);
currentBody.send(firstHalf);
Expand All @@ -132,33 +149,50 @@ public void onNext(ByteBuffer byteBuffer) {
maybeRequestMoreUpstreamData();
}

private int amountRemainingInPart() {
return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength);

/**
* If content length is known, we should create new DownstreamRequestBody if there's remaining data.
* If content length is unknown, we should create new DownstreamRequestBody if upstream is not completed yet.
*/
private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) {
return !upstreamComplete || byteBuffer.remaining() > 0;
}

private int amountRemainingInChunk() {
return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength);
}

private void completeCurrentBody() {
currentBody.complete();
if (upstreamSize == null) {
sendCurrentBody(currentBody);
}
}

@Override
public void onComplete() {
upstreamComplete = true;
log.trace(() -> "Received onComplete()");
completeCurrentBody();
downstreamPublisher.complete().thenRun(() -> future.complete(null));
currentBody.complete();
}

@Override
public void onError(Throwable t) {
currentBody.error(t);
}

private void sendCurrentBody() {
downstreamPublisher.send(currentBody).exceptionally(t -> {
private void sendCurrentBody(AsyncRequestBody body) {
downstreamPublisher.send(body).exceptionally(t -> {
downstreamPublisher.error(t);
return null;
});
}

private Long calculateChunkSize() {
Long dataRemaining = dataRemaining();
private long calculateChunkSize(Long dataRemaining) {
// Use default chunk size if the content length is unknown
if (dataRemaining == null) {
return null;
return chunkSizeInBytes;
}

return Math.min(chunkSizeInBytes, dataRemaining);
Expand All @@ -177,27 +211,34 @@ private boolean shouldRequestMoreData(long buffered) {
return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes;
}

private Long dataRemaining() {
private Long totalDataRemaining() {
if (upstreamSize == null) {
return null;
}
return upstreamSize - (chunkNumber.get() * chunkSizeInBytes);
}

private class DownstreamBody implements AsyncRequestBody {
private final class DownstreamBody implements AsyncRequestBody {

/**
* The maximum length of the content this AsyncRequestBody can hold.
* If the upstream content length is known, this is the same as totalLength
*/
private final long maxLength;
private final Long totalLength;
private final SimplePublisher<ByteBuffer> delegate = new SimplePublisher<>();
private final int chunkNumber;
private volatile long transferredLength = 0;

private DownstreamBody(Long totalLength, int chunkNumber) {
this.totalLength = totalLength;
private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) {
this.totalLength = contentLengthKnown ? maxLength : null;
this.maxLength = maxLength;
this.chunkNumber = chunkNumber;
}

@Override
public Optional<Long> contentLength() {
return Optional.ofNullable(totalLength);
return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength);
}

public void send(ByteBuffer data) {
Expand All @@ -214,8 +255,12 @@ public void send(ByteBuffer data) {
}

public void complete() {
log.debug(() -> "Received complete() for chunk number: " + chunkNumber);
delegate.complete();
log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength);
delegate.complete().whenComplete((r, t) -> {
if (t != null) {
error(t);
}
});
}

public void error(Throwable error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
Expand All @@ -28,6 +29,7 @@
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -44,6 +46,8 @@ public class SplittingPublisherTest {
private static final int CHUNK_SIZE = 5;

private static final int CONTENT_SIZE = 101;
private static final byte[] CONTENT =
RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset());

private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE);

Expand Down Expand Up @@ -123,9 +127,59 @@ void cancelFuture_shouldCancelUpstream() throws IOException {
assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1);
}

private static final class TestAsyncRequestBody implements AsyncRequestBody {
private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset());
private boolean cancelled;
@Test
void contentLengthNotPresent_shouldHandle() throws Exception {
CompletableFuture<Void> future = new CompletableFuture<>();
TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() {
@Override
public Optional<Long> contentLength() {
return Optional.empty();
}
};
SplittingPublisher splittingPublisher = SplittingPublisher.builder()
.resultFuture(future)
.asyncRequestBody(asyncRequestBody)
.chunkSizeInBytes((long) CHUNK_SIZE)
.maxMemoryUsageInBytes(10L)
.build();


List<CompletableFuture<byte[]>> futures = new ArrayList<>();
AtomicInteger index = new AtomicInteger(0);

splittingPublisher.subscribe(requestBody -> {
CompletableFuture<byte[]> baosFuture = new CompletableFuture<>();
BaosSubscriber subscriber = new BaosSubscriber(baosFuture);
futures.add(baosFuture);
requestBody.subscribe(subscriber);
if (index.incrementAndGet() == NUM_OF_CHUNK) {
assertThat(requestBody.contentLength()).hasValue(1L);
} else {
assertThat(requestBody.contentLength()).hasValue((long) CHUNK_SIZE);
}
}).get(5, TimeUnit.SECONDS);
assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK);

for (int i = 0; i < futures.size(); i++) {
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(CONTENT)) {
byte[] expected;
if (i == futures.size() - 1) {
expected = new byte[1];
} else {
expected = new byte[CHUNK_SIZE];
}
inputStream.skip(i * CHUNK_SIZE);
inputStream.read(expected);
byte[] actualBytes = futures.get(i).join();
assertThat(actualBytes).isEqualTo(expected);
};
}

}

private static class TestAsyncRequestBody implements AsyncRequestBody {
private volatile boolean cancelled;
private volatile boolean isDone;

@Override
public Optional<Long> contentLength() {
Expand All @@ -137,8 +191,13 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
s.onSubscribe(new Subscription() {
@Override
public void request(long n) {
if (isDone) {
return;
}
isDone = true;
s.onNext(ByteBuffer.wrap(CONTENT));
s.onComplete();

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public CompletableFuture<PutObjectResponse> uploadObject(PutObjectRequest putObj
AsyncRequestBody asyncRequestBody) {
Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength);

// TODO: support null content length. Should be trivial to support it now
// TODO: support null content length. Need to determine whether to use single object or MPU based on the first
// AsyncRequestBody
if (contentLength == null) {
throw new IllegalArgumentException("Content-length is required");
}
Expand Down