Skip to content

Commit d998908

Browse files
authored
Fix null content length in SplittingPublisher (#4173)
1 parent 910b30f commit d998908

File tree

3 files changed

+138
-33
lines changed

3 files changed

+138
-33
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@
3333
/**
3434
* Splits an {@link SdkPublisher} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of the
3535
* original data.
36+
*
37+
* <p>If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized.
38+
* Otherwise, it is sent after the entire content for that chunk is buffered, so that the chunk's content length can be determined.
39+
*
3640
* // TODO: create a default method in AsyncRequestBody for this
37-
* // TODO: fix the case where content length is null
3841
*/
3942
@SdkInternalApi
4043
public class SplittingPublisher implements SdkPublisher<AsyncRequestBody> {
@@ -86,6 +89,7 @@ private class SplittingSubscriber implements Subscriber<ByteBuffer> {
8689
* A hint to determine whether we will exceed maxMemoryUsage by the next OnNext call.
8790
*/
8891
private int byteBufferSizeHint;
92+
private volatile boolean upstreamComplete;
8993

9094
SplittingSubscriber(Long upstreamSize) {
9195
this.upstreamSize = upstreamSize;
@@ -94,36 +98,49 @@ private class SplittingSubscriber implements Subscriber<ByteBuffer> {
9498
@Override
9599
public void onSubscribe(Subscription s) {
96100
this.upstreamSubscription = s;
97-
this.currentBody = new DownstreamBody(calculateChunkSize(), chunkNumber.get());
98-
sendCurrentBody();
101+
this.currentBody =
102+
initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize),
103+
chunkNumber.get());
99104
// We need to request subscription *after* we set currentBody because onNext could be invoked right away.
100105
upstreamSubscription.request(1);
101106
}
102107

108+
private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) {
109+
DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber);
110+
if (contentLengthKnown) {
111+
sendCurrentBody(body);
112+
}
113+
return body;
114+
}
115+
103116
@Override
104117
public void onNext(ByteBuffer byteBuffer) {
105118
hasOpenUpstreamDemand.set(false);
106119
byteBufferSizeHint = byteBuffer.remaining();
107120

108121
while (true) {
109-
int amountRemainingInPart = amountRemainingInPart();
110-
int finalAmountRemainingInPart = amountRemainingInPart;
111-
if (amountRemainingInPart == 0) {
112-
currentBody.complete();
113-
int currentChunk = chunkNumber.incrementAndGet();
114-
Long partSize = calculateChunkSize();
115-
currentBody = new DownstreamBody(partSize, currentChunk);
116-
sendCurrentBody();
122+
int amountRemainingInChunk = amountRemainingInChunk();
123+
124+
// If we have fulfilled this chunk,
125+
// we should create a new DownstreamBody if needed
126+
if (amountRemainingInChunk == 0) {
127+
completeCurrentBody();
128+
129+
if (shouldCreateNewDownstreamRequestBody(byteBuffer)) {
130+
int currentChunk = chunkNumber.incrementAndGet();
131+
long chunkSize = calculateChunkSize(totalDataRemaining());
132+
currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk);
133+
}
117134
}
118135

119-
amountRemainingInPart = amountRemainingInPart();
120-
if (amountRemainingInPart >= byteBuffer.remaining()) {
136+
amountRemainingInChunk = amountRemainingInChunk();
137+
if (amountRemainingInChunk >= byteBuffer.remaining()) {
121138
currentBody.send(byteBuffer.duplicate());
122139
break;
123140
}
124141

125142
ByteBuffer firstHalf = byteBuffer.duplicate();
126-
int newLimit = firstHalf.position() + amountRemainingInPart;
143+
int newLimit = firstHalf.position() + amountRemainingInChunk;
127144
firstHalf.limit(newLimit);
128145
byteBuffer.position(newLimit);
129146
currentBody.send(firstHalf);
@@ -132,33 +149,50 @@ public void onNext(ByteBuffer byteBuffer) {
132149
maybeRequestMoreUpstreamData();
133150
}
134151

135-
private int amountRemainingInPart() {
136-
return Math.toIntExact(currentBody.totalLength - currentBody.transferredLength);
152+
153+
/**
154+
* If content length is known, we should create a new DownstreamBody if there's remaining data.
155+
* If content length is unknown, we should create a new DownstreamBody if upstream is not completed yet.
156+
*/
157+
private boolean shouldCreateNewDownstreamRequestBody(ByteBuffer byteBuffer) {
158+
return !upstreamComplete || byteBuffer.remaining() > 0;
159+
}
160+
161+
private int amountRemainingInChunk() {
162+
return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength);
163+
}
164+
165+
private void completeCurrentBody() {
166+
currentBody.complete();
167+
if (upstreamSize == null) {
168+
sendCurrentBody(currentBody);
169+
}
137170
}
138171

139172
@Override
140173
public void onComplete() {
174+
upstreamComplete = true;
141175
log.trace(() -> "Received onComplete()");
176+
completeCurrentBody();
142177
downstreamPublisher.complete().thenRun(() -> future.complete(null));
143-
currentBody.complete();
144178
}
145179

146180
@Override
147181
public void onError(Throwable t) {
148182
currentBody.error(t);
149183
}
150184

151-
private void sendCurrentBody() {
152-
downstreamPublisher.send(currentBody).exceptionally(t -> {
185+
private void sendCurrentBody(AsyncRequestBody body) {
186+
downstreamPublisher.send(body).exceptionally(t -> {
153187
downstreamPublisher.error(t);
154188
return null;
155189
});
156190
}
157191

158-
private Long calculateChunkSize() {
159-
Long dataRemaining = dataRemaining();
192+
private long calculateChunkSize(Long dataRemaining) {
193+
// Use default chunk size if the content length is unknown
160194
if (dataRemaining == null) {
161-
return null;
195+
return chunkSizeInBytes;
162196
}
163197

164198
return Math.min(chunkSizeInBytes, dataRemaining);
@@ -177,27 +211,34 @@ private boolean shouldRequestMoreData(long buffered) {
177211
return buffered == 0 || buffered + byteBufferSizeHint < maxMemoryUsageInBytes;
178212
}
179213

180-
private Long dataRemaining() {
214+
private Long totalDataRemaining() {
181215
if (upstreamSize == null) {
182216
return null;
183217
}
184218
return upstreamSize - (chunkNumber.get() * chunkSizeInBytes);
185219
}
186220

187-
private class DownstreamBody implements AsyncRequestBody {
221+
private final class DownstreamBody implements AsyncRequestBody {
222+
223+
/**
224+
* The maximum length of the content this AsyncRequestBody can hold.
225+
* If the upstream content length is known, this is the same as totalLength.
226+
*/
227+
private final long maxLength;
188228
private final Long totalLength;
189229
private final SimplePublisher<ByteBuffer> delegate = new SimplePublisher<>();
190230
private final int chunkNumber;
191231
private volatile long transferredLength = 0;
192232

193-
private DownstreamBody(Long totalLength, int chunkNumber) {
194-
this.totalLength = totalLength;
233+
private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) {
234+
this.totalLength = contentLengthKnown ? maxLength : null;
235+
this.maxLength = maxLength;
195236
this.chunkNumber = chunkNumber;
196237
}
197238

198239
@Override
199240
public Optional<Long> contentLength() {
200-
return Optional.ofNullable(totalLength);
241+
return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength);
201242
}
202243

203244
public void send(ByteBuffer data) {
@@ -214,8 +255,12 @@ public void send(ByteBuffer data) {
214255
}
215256

216257
public void complete() {
217-
log.debug(() -> "Received complete() for chunk number: " + chunkNumber);
218-
delegate.complete();
258+
log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength);
259+
delegate.complete().whenComplete((r, t) -> {
260+
if (t != null) {
261+
error(t);
262+
}
263+
});
219264
}
220265

221266
public void error(Throwable error) {

core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
1919
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
2020

21+
import java.io.ByteArrayInputStream;
2122
import java.io.ByteArrayOutputStream;
2223
import java.io.FileInputStream;
2324
import java.io.IOException;
@@ -28,6 +29,7 @@
2829
import java.util.Optional;
2930
import java.util.concurrent.CompletableFuture;
3031
import java.util.concurrent.TimeUnit;
32+
import java.util.concurrent.atomic.AtomicInteger;
3133
import org.apache.commons.lang3.RandomStringUtils;
3234
import org.junit.jupiter.api.AfterAll;
3335
import org.junit.jupiter.api.BeforeAll;
@@ -44,6 +46,8 @@ public class SplittingPublisherTest {
4446
private static final int CHUNK_SIZE = 5;
4547

4648
private static final int CONTENT_SIZE = 101;
49+
private static final byte[] CONTENT =
50+
RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset());
4751

4852
private static final int NUM_OF_CHUNK = (int) Math.ceil(CONTENT_SIZE / (double) CHUNK_SIZE);
4953

@@ -123,9 +127,59 @@ void cancelFuture_shouldCancelUpstream() throws IOException {
123127
assertThat(downstreamSubscriber.asyncRequestBodies.size()).isEqualTo(1);
124128
}
125129

126-
private static final class TestAsyncRequestBody implements AsyncRequestBody {
127-
private static final byte[] CONTENT = RandomStringUtils.random(200).getBytes(Charset.defaultCharset());
128-
private boolean cancelled;
130+
@Test
131+
void contentLengthNotPresent_shouldHandle() throws Exception {
132+
CompletableFuture<Void> future = new CompletableFuture<>();
133+
TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() {
134+
@Override
135+
public Optional<Long> contentLength() {
136+
return Optional.empty();
137+
}
138+
};
139+
SplittingPublisher splittingPublisher = SplittingPublisher.builder()
140+
.resultFuture(future)
141+
.asyncRequestBody(asyncRequestBody)
142+
.chunkSizeInBytes((long) CHUNK_SIZE)
143+
.maxMemoryUsageInBytes(10L)
144+
.build();
145+
146+
147+
List<CompletableFuture<byte[]>> futures = new ArrayList<>();
148+
AtomicInteger index = new AtomicInteger(0);
149+
150+
splittingPublisher.subscribe(requestBody -> {
151+
CompletableFuture<byte[]> baosFuture = new CompletableFuture<>();
152+
BaosSubscriber subscriber = new BaosSubscriber(baosFuture);
153+
futures.add(baosFuture);
154+
requestBody.subscribe(subscriber);
155+
if (index.incrementAndGet() == NUM_OF_CHUNK) {
156+
assertThat(requestBody.contentLength()).hasValue(1L);
157+
} else {
158+
assertThat(requestBody.contentLength()).hasValue((long) CHUNK_SIZE);
159+
}
160+
}).get(5, TimeUnit.SECONDS);
161+
assertThat(futures.size()).isEqualTo(NUM_OF_CHUNK);
162+
163+
for (int i = 0; i < futures.size(); i++) {
164+
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(CONTENT)) {
165+
byte[] expected;
166+
if (i == futures.size() - 1) {
167+
expected = new byte[1];
168+
} else {
169+
expected = new byte[CHUNK_SIZE];
170+
}
171+
inputStream.skip(i * CHUNK_SIZE);
172+
inputStream.read(expected);
173+
byte[] actualBytes = futures.get(i).join();
174+
assertThat(actualBytes).isEqualTo(expected);
175+
};
176+
}
177+
178+
}
179+
180+
private static class TestAsyncRequestBody implements AsyncRequestBody {
181+
private volatile boolean cancelled;
182+
private volatile boolean isDone;
129183

130184
@Override
131185
public Optional<Long> contentLength() {
@@ -137,8 +191,13 @@ public void subscribe(Subscriber<? super ByteBuffer> s) {
137191
s.onSubscribe(new Subscription() {
138192
@Override
139193
public void request(long n) {
194+
if (isDone) {
195+
return;
196+
}
197+
isDone = true;
140198
s.onNext(ByteBuffer.wrap(CONTENT));
141199
s.onComplete();
200+
142201
}
143202

144203
@Override

services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ public CompletableFuture<PutObjectResponse> uploadObject(PutObjectRequest putObj
6969
AsyncRequestBody asyncRequestBody) {
7070
Long contentLength = asyncRequestBody.contentLength().orElseGet(putObjectRequest::contentLength);
7171

72-
// TODO: support null content length. Should be trivial to support it now
72+
// TODO: support null content length. Need to determine whether to use single object or MPU based on the first
73+
// AsyncRequestBody
7374
if (contentLength == null) {
7475
throw new IllegalArgumentException("Content-length is required");
7576
}

0 commit comments

Comments
 (0)