Skip to content

Commit 820dd4f

Browse files
authored
Fixed an issue in ChecksumCalculatingAsyncRequestBody where the posit… (#4244)
* Fixed an issue in ChecksumCalculatingAsyncRequestBody where the position of the ByteBuffer was not honored. * Fix checkstyle * rename methods and variables * Add javadocs
1 parent 4420728 commit 820dd4f

File tree

4 files changed

+196
-120
lines changed

4 files changed

+196
-120
lines changed

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ private static final class SynchronousChunkBuffer {
239239
}
240240

241241
private Iterable<ByteBuffer> buffer(ByteBuffer bytes) {
242-
return chunkBuffer.bufferAndCreateChunks(bytes);
242+
return chunkBuffer.split(bytes);
243243
}
244244
}
245245

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java

Lines changed: 95 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919

2020
import java.nio.ByteBuffer;
2121
import java.util.ArrayList;
22+
import java.util.Collections;
2223
import java.util.List;
2324
import java.util.concurrent.atomic.AtomicLong;
2425
import software.amazon.awssdk.annotations.SdkInternalApi;
25-
import software.amazon.awssdk.utils.BinaryUtils;
26+
import software.amazon.awssdk.utils.Logger;
2627
import software.amazon.awssdk.utils.Validate;
2728
import software.amazon.awssdk.utils.builder.SdkBuilder;
2829

@@ -31,70 +32,118 @@
3132
*/
3233
@SdkInternalApi
3334
public final class ChunkBuffer {
34-
private final AtomicLong remainingBytes;
35+
private static final Logger log = Logger.loggerFor(ChunkBuffer.class);
36+
private final AtomicLong transferredBytes;
3537
private final ByteBuffer currentBuffer;
36-
private final int bufferSize;
38+
private final int chunkSize;
39+
private final long totalBytes;
3740

3841
private ChunkBuffer(Long totalBytes, Integer bufferSize) {
3942
Validate.notNull(totalBytes, "The totalBytes must not be null");
4043

4144
int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE;
42-
this.bufferSize = chunkSize;
45+
this.chunkSize = chunkSize;
4346
this.currentBuffer = ByteBuffer.allocate(chunkSize);
44-
this.remainingBytes = new AtomicLong(totalBytes);
47+
this.totalBytes = totalBytes;
48+
this.transferredBytes = new AtomicLong(0);
4549
}
4650

4751
public static Builder builder() {
4852
return new DefaultBuilder();
4953
}
5054

5155

52-
// currentBuffer and bufferedList can get over written if concurrent Threads calls this method at the same time.
53-
public synchronized Iterable<ByteBuffer> bufferAndCreateChunks(ByteBuffer buffer) {
54-
int startPosition = 0;
55-
List<ByteBuffer> bufferedList = new ArrayList<>();
56-
int currentBytesRead = buffer.remaining();
57-
do {
58-
int bufferedBytes = currentBuffer.position();
59-
int availableToRead = bufferSize - bufferedBytes;
60-
int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition);
56+
/**
57+
* Split the input {@link ByteBuffer} into multiple smaller {@link ByteBuffer}s, each of which contains {@link #chunkSize}
58+
* worth of bytes. If the last chunk of the input ByteBuffer contains less than {@link #chunkSize} data, the last chunk will
59+
* be buffered.
60+
*/
61+
public synchronized Iterable<ByteBuffer> split(ByteBuffer inputByteBuffer) {
6162

62-
byte[] bytes = BinaryUtils.copyAllBytesFrom(buffer);
63-
if (bufferedBytes == 0) {
64-
currentBuffer.put(bytes, startPosition, bytesToMove);
65-
} else {
66-
currentBuffer.put(bytes, 0, bytesToMove);
63+
if (!inputByteBuffer.hasRemaining()) {
64+
return Collections.singletonList(inputByteBuffer);
65+
}
66+
67+
List<ByteBuffer> byteBuffers = new ArrayList<>();
68+
69+
// If current buffer is not empty, fill the buffer first.
70+
if (currentBuffer.position() != 0) {
71+
fillCurrentBuffer(inputByteBuffer);
72+
73+
if (isCurrentBufferFull()) {
74+
addCurrentBufferToIterable(byteBuffers, chunkSize);
75+
}
76+
}
77+
78+
// If the input buffer is not empty, split the input buffer
79+
if (inputByteBuffer.hasRemaining()) {
80+
splitRemainingInputByteBuffer(inputByteBuffer, byteBuffers);
81+
}
82+
83+
// If this is the last chunk, add data buffered to the iterable
84+
if (isLastChunk()) {
85+
int remainingBytesInBuffer = currentBuffer.position();
86+
addCurrentBufferToIterable(byteBuffers, remainingBytesInBuffer);
87+
}
88+
return byteBuffers;
89+
}
90+
91+
private boolean isCurrentBufferFull() {
92+
return currentBuffer.position() == chunkSize;
93+
}
94+
95+
/**
96+
* Splits the input ByteBuffer to multiple chunks and add them to the iterable.
97+
*/
98+
private void splitRemainingInputByteBuffer(ByteBuffer inputByteBuffer, List<ByteBuffer> byteBuffers) {
99+
while (inputByteBuffer.hasRemaining()) {
100+
ByteBuffer inputByteBufferCopy = inputByteBuffer.asReadOnlyBuffer();
101+
if (inputByteBuffer.remaining() < chunkSize) {
102+
currentBuffer.put(inputByteBuffer);
103+
break;
67104
}
68105

69-
startPosition = startPosition + bytesToMove;
70-
71-
// Send the data once the buffer is full
72-
if (currentBuffer.position() == bufferSize) {
73-
currentBuffer.position(0);
74-
ByteBuffer bufferToSend = ByteBuffer.allocate(bufferSize);
75-
bufferToSend.put(currentBuffer.array(), 0, bufferSize);
76-
bufferToSend.clear();
77-
currentBuffer.clear();
78-
bufferedList.add(bufferToSend);
79-
remainingBytes.addAndGet(-bufferSize);
106+
int newLimit = inputByteBufferCopy.position() + chunkSize;
107+
inputByteBufferCopy.limit(newLimit);
108+
inputByteBuffer.position(newLimit);
109+
byteBuffers.add(inputByteBufferCopy);
110+
transferredBytes.addAndGet(chunkSize);
111+
}
112+
}
113+
114+
private boolean isLastChunk() {
115+
long remainingBytes = totalBytes - transferredBytes.get();
116+
return remainingBytes != 0 && remainingBytes == currentBuffer.position();
117+
}
118+
119+
private void addCurrentBufferToIterable(List<ByteBuffer> byteBuffers, int capacity) {
120+
ByteBuffer bufferedChunk = ByteBuffer.allocate(capacity);
121+
currentBuffer.flip();
122+
bufferedChunk.put(currentBuffer);
123+
bufferedChunk.flip();
124+
byteBuffers.add(bufferedChunk);
125+
transferredBytes.addAndGet(bufferedChunk.remaining());
126+
currentBuffer.clear();
127+
}
128+
129+
private void fillCurrentBuffer(ByteBuffer inputByteBuffer) {
130+
while (currentBuffer.position() < chunkSize) {
131+
if (!inputByteBuffer.hasRemaining()) {
132+
break;
133+
}
134+
135+
int remainingCapacity = chunkSize - currentBuffer.position();
136+
137+
if (inputByteBuffer.remaining() < remainingCapacity) {
138+
currentBuffer.put(inputByteBuffer);
139+
} else {
140+
ByteBuffer remainingChunk = inputByteBuffer.asReadOnlyBuffer();
141+
int newLimit = inputByteBuffer.position() + remainingCapacity;
142+
remainingChunk.limit(newLimit);
143+
inputByteBuffer.position(newLimit);
144+
currentBuffer.put(remainingChunk);
80145
}
81-
} while (startPosition < currentBytesRead);
82-
83-
int remainingBytesInBuffer = currentBuffer.position();
84-
85-
// Send the remaining buffer when
86-
// 1. remainingBytes in buffer are same as the last few bytes to be read.
87-
// 2. If it is a zero byte and the last byte to be read.
88-
if (remainingBytes.get() == remainingBytesInBuffer &&
89-
(buffer.remaining() == 0 || remainingBytesInBuffer > 0)) {
90-
currentBuffer.clear();
91-
ByteBuffer trimmedBuffer = ByteBuffer.allocate(remainingBytesInBuffer);
92-
trimmedBuffer.put(currentBuffer.array(), 0, remainingBytesInBuffer);
93-
trimmedBuffer.clear();
94-
bufferedList.add(trimmedBuffer);
95-
remainingBytes.addAndGet(-remainingBytesInBuffer);
96146
}
97-
return bufferedList;
98147
}
99148

100149
public interface Builder extends SdkBuilder<Builder, ChunkBuffer> {

core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import static org.assertj.core.api.Assertions.assertThat;
1919
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2020

21+
import java.io.ByteArrayInputStream;
22+
import java.io.IOException;
2123
import java.nio.ByteBuffer;
2224
import java.nio.charset.StandardCharsets;
2325
import java.util.ArrayList;
@@ -29,8 +31,12 @@
2931
import java.util.concurrent.atomic.AtomicInteger;
3032
import java.util.stream.Collectors;
3133
import java.util.stream.IntStream;
34+
import org.apache.commons.lang3.RandomStringUtils;
3235
import org.junit.jupiter.api.Test;
36+
import org.junit.jupiter.params.ParameterizedTest;
37+
import org.junit.jupiter.params.provider.ValueSource;
3338
import software.amazon.awssdk.core.internal.async.ChunkBuffer;
39+
import software.amazon.awssdk.utils.BinaryUtils;
3440
import software.amazon.awssdk.utils.StringUtils;
3541

3642
class ChunkBufferTest {
@@ -40,42 +46,38 @@ void builderWithNoTotalSize() {
4046
assertThatThrownBy(() -> ChunkBuffer.builder().build()).isInstanceOf(NullPointerException.class);
4147
}
4248

43-
@Test
44-
void numberOfChunkMultipleOfTotalBytes() {
45-
String inputString = StringUtils.repeat("*", 25);
46-
47-
ChunkBuffer chunkBuffer =
48-
ChunkBuffer.builder().bufferSize(5).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
49-
Iterable<ByteBuffer> byteBuffers =
50-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
51-
52-
AtomicInteger iteratedCounts = new AtomicInteger();
53-
byteBuffers.forEach(r -> {
54-
iteratedCounts.getAndIncrement();
55-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 5).getBytes(StandardCharsets.UTF_8));
56-
});
57-
assertThat(iteratedCounts.get()).isEqualTo(5);
58-
}
59-
60-
@Test
61-
void numberOfChunk_Not_MultipleOfTotalBytes() {
62-
int totalBytes = 23;
49+
@ParameterizedTest
50+
@ValueSource(ints = {1, 6, 10, 23, 25})
51+
void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) {
6352
int bufferSize = 5;
6453

65-
String inputString = StringUtils.repeat("*", totalBytes);
54+
String inputString = RandomStringUtils.randomAscii(totalBytes);
6655
ChunkBuffer chunkBuffer =
6756
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
6857
Iterable<ByteBuffer> byteBuffers =
69-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
58+
chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
59+
60+
AtomicInteger index = new AtomicInteger(0);
61+
int count = (int) Math.ceil(totalBytes / (double) bufferSize);
62+
int remainder = totalBytes % bufferSize;
7063

71-
AtomicInteger iteratedCounts = new AtomicInteger();
7264
byteBuffers.forEach(r -> {
73-
iteratedCounts.getAndIncrement();
74-
if (iteratedCounts.get() * bufferSize < totalBytes) {
75-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", bufferSize).getBytes(StandardCharsets.UTF_8));
76-
} else {
77-
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 3).getBytes(StandardCharsets.UTF_8));
65+
int i = index.get();
7866

67+
try (ByteArrayInputStream inputStream = new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))) {
68+
byte[] expected;
69+
if (i == count - 1 && remainder != 0) {
70+
expected = new byte[remainder];
71+
} else {
72+
expected = new byte[bufferSize];
73+
}
74+
inputStream.skip(i * bufferSize);
75+
inputStream.read(expected);
76+
byte[] actualBytes = BinaryUtils.copyBytesFrom(r);
77+
assertThat(actualBytes).isEqualTo(expected);
78+
index.incrementAndGet();
79+
} catch (IOException e) {
80+
throw new RuntimeException(e);
7981
}
8082
});
8183
}
@@ -86,7 +88,7 @@ void zeroTotalBytesAsInput_returnsZeroByte() {
8688
ChunkBuffer chunkBuffer =
8789
ChunkBuffer.builder().bufferSize(5).totalBytes(zeroByte.length).build();
8890
Iterable<ByteBuffer> byteBuffers =
89-
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(zeroByte));
91+
chunkBuffer.split(ByteBuffer.wrap(zeroByte));
9092

9193
AtomicInteger iteratedCounts = new AtomicInteger();
9294
byteBuffers.forEach(r -> {
@@ -104,16 +106,16 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() {
104106
ChunkBuffer chunkBuffer =
105107
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining()).build();
106108
Iterable<ByteBuffer> byteBuffers =
107-
chunkBuffer.bufferAndCreateChunks(wrap);
109+
chunkBuffer.split(wrap);
108110

109111
AtomicInteger iteratedCounts = new AtomicInteger();
110112
byteBuffers.forEach(r -> {
111113
iteratedCounts.getAndIncrement();
112114
if (iteratedCounts.get() * bufferSize < totalBytes) {
113115
// array of empty bytes
114-
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(bufferSize).array());
116+
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(bufferSize).array());
115117
} else {
116-
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
118+
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
117119
}
118120
});
119121
assertThat(iteratedCounts.get()).isEqualTo(4);
@@ -167,7 +169,7 @@ void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException,
167169

168170
futures = IntStream.range(0, threads).<Future<Iterable>>mapToObj(t -> service.submit(() -> {
169171
String inputString = StringUtils.repeat(Integer.toString(counter.incrementAndGet()), totalBytes);
170-
return chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
172+
return chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
171173
})).collect(Collectors.toCollection(() -> new ArrayList<>(threads)));
172174

173175
AtomicInteger filledBuffers = new AtomicInteger(0);

0 commit comments

Comments
 (0)