diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-70dee34.json b/.changes/next-release/bugfix-AWSSDKforJavav2-70dee34.json new file mode 100644 index 000000000000..3d384dcfb83d --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-70dee34.json @@ -0,0 +1,5 @@ +{ + "category": "AWS SDK for Java v2", + "type": "bugfix", + "description": "ChecksumValidatingPublisher deals with any packetization of the incoming data. See https://github.com/aws/aws-sdk-java-v2/issues/965" +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java index 81646e566c42..400f3131f101 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java @@ -45,7 +45,11 @@ public ChecksumValidatingPublisher(Publisher publisher, @Override public void subscribe(Subscriber s) { - publisher.subscribe(new ChecksumValidatingSubscriber(s, sdkChecksum, contentLength)); + if (contentLength > 0) { + publisher.subscribe(new ChecksumValidatingSubscriber(s, sdkChecksum, contentLength)); + } else { + publisher.subscribe(new ChecksumSkippingSubscriber(s)); + } } private static class ChecksumValidatingSubscriber implements Subscriber { @@ -80,14 +84,34 @@ public void onNext(ByteBuffer byteBuffer) { int toUpdate = (int) Math.min(strippedLength - lengthRead, buf.length); sdkChecksum.update(buf, 0, toUpdate); - lengthRead += buf.length; } + lengthRead += buf.length; if (lengthRead >= strippedLength) { - int offset = toIntExact(lengthRead - strippedLength); - streamChecksum = Arrays.copyOfRange(buf, buf.length - offset, buf.length); - wrapped.onNext(ByteBuffer.wrap(Arrays.copyOfRange(buf, 0, buf.length - offset))); + // Incoming buffer contains at least a bit of the checksum + // Code below covers both cases of the incoming buffer relative to checksum border + // a) buffer starts before checksum border and extends into checksum + // |<------ data ------->|<--cksum-->| <--- original data + // |<---buffer--->| <--- incoming buffer + // |<------->| <--- checksum bytes so far + // |<-->| <--- bufChecksumOffset + // | <--- streamChecksumOffset + // b) buffer starts at or after checksum border + // |<------ data ------->|<--cksum-->| <--- original data + // |<-->| <--- incoming buffer + // |<------>| <--- checksum bytes so far + // | <--- bufChecksumOffset + // |<->| <--- streamChecksumOffset + int cksumBytesSoFar = toIntExact(lengthRead - strippedLength); + int bufChecksumOffset = (buf.length > cksumBytesSoFar) ? (buf.length - cksumBytesSoFar) : 0; + int streamChecksumOffset = (buf.length > cksumBytesSoFar) ? 0 : (cksumBytesSoFar - buf.length); + int cksumBytes = Math.min(cksumBytesSoFar, buf.length); + System.arraycopy(buf, bufChecksumOffset, streamChecksum, streamChecksumOffset, cksumBytes); + if (buf.length > cksumBytesSoFar) { + wrapped.onNext(ByteBuffer.wrap(Arrays.copyOfRange(buf, 0, buf.length - cksumBytesSoFar))); + } } else { + // Incoming buffer totally excludes the checksum wrapped.onNext(byteBuffer); } } @@ -111,4 +135,36 @@ public void onComplete() { wrapped.onComplete(); } } + + private static class ChecksumSkippingSubscriber implements Subscriber { + private static final int CHECKSUM_SIZE = 16; + + private final Subscriber wrapped; + + ChecksumSkippingSubscriber(Subscriber wrapped) { + this.wrapped = wrapped; + } + + @Override + public void onSubscribe(Subscription s) { + wrapped.onSubscribe(s); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + byte[] buf = BinaryUtils.copyBytesFrom(byteBuffer); + wrapped.onNext(ByteBuffer.wrap(Arrays.copyOfRange(buf, 0, buf.length - CHECKSUM_SIZE))); + } + + @Override + public void onError(Throwable t) { + wrapped.onError(t); + } + + @Override + public void onComplete() { + wrapped.onComplete(); + } + } + } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java new file mode 100644 index 000000000000..b80c49e78323 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java @@ -0,0 +1,189 @@ +/* + * Copyright 2010-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.checksums; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import software.amazon.awssdk.core.checksums.Md5Checksum; + +/** + * Unit test for ChecksumValidatingPublisher + */ +public class ChecksumValidatingPublisherTest { + private static int TEST_DATA_SIZE = 32; // size of the test data, in bytes + private static final int CHECKSUM_SIZE = 16; + private static byte[] testData; + + @BeforeClass + public static void populateData() { + testData = new byte[TEST_DATA_SIZE + CHECKSUM_SIZE]; + for (int i = 0; i < TEST_DATA_SIZE; i++) { + testData[i] = (byte)(i & 0x7f); + } + final Md5Checksum checksum = new Md5Checksum(); + checksum.update(testData, 0, TEST_DATA_SIZE); + byte[] checksumBytes = checksum.getChecksumBytes(); + for (int i = 0; i < CHECKSUM_SIZE; i++) { + testData[TEST_DATA_SIZE + i] = checksumBytes[i]; + } + } + + @Test + public void testSinglePacket() { + final TestPublisher driver = new TestPublisher(); + final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); + p.subscribe(s); + + driver.doOnNext(ByteBuffer.wrap(testData)); + driver.doOnComplete(); + + assertTrue(s.hasCompleted()); + } + + @Test + public void testTwoPackets() { + for (int i = 1; i < TEST_DATA_SIZE + CHECKSUM_SIZE - 1; i++) { + final TestPublisher driver = new TestPublisher(); + final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); + p.subscribe(s); + + driver.doOnNext(ByteBuffer.wrap(testData, 0, i)); + driver.doOnNext(ByteBuffer.wrap(testData, i, TEST_DATA_SIZE + CHECKSUM_SIZE - i)); + driver.doOnComplete(); + + assertTrue(s.hasCompleted()); + } + } + + @Test + public void testTinyPackets() { + for (int packetSize = 1; packetSize < CHECKSUM_SIZE; packetSize++) { + final TestPublisher driver = new TestPublisher(); + final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); + p.subscribe(s); + int currOffset = 0; + while (currOffset < TEST_DATA_SIZE + CHECKSUM_SIZE) { + final int toSend = Math.min(packetSize, TEST_DATA_SIZE + CHECKSUM_SIZE - currOffset); + driver.doOnNext(ByteBuffer.wrap(testData, currOffset, toSend)); + currOffset += toSend; + } + driver.doOnComplete(); + + assertTrue(s.hasCompleted()); + } + } + + @Test + public void testUnknownLength() { + // When the length is unknown, the last 16 bytes are treated as a checksum, but are later ignored when completing + final TestPublisher driver = new TestPublisher(); + final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), 0); + p.subscribe(s); + + byte[] randomChecksumData = new byte[testData.length]; + System.arraycopy(testData, 0, randomChecksumData, 0, TEST_DATA_SIZE); + for (int i = TEST_DATA_SIZE; i < randomChecksumData.length; i++) { + randomChecksumData[i] = (byte)((testData[i] + 1) & 0x7f); + } + + driver.doOnNext(ByteBuffer.wrap(randomChecksumData)); + driver.doOnComplete(); + + assertTrue(s.hasCompleted()); + } + + private class TestSubscriber implements Subscriber { + final byte[] expected; + final List received; + boolean completed; + + TestSubscriber(byte[] expected) { + this.expected = expected; + this.received = new ArrayList<>(); + this.completed = false; + } + + @Override + public void onSubscribe(Subscription s) { + fail("This method not expected to be invoked"); + throw new UnsupportedOperationException("!!!TODO: implement this"); + } + + @Override + public void onNext(ByteBuffer buffer) { + received.add(buffer); + } + + + @Override + public void onError(Throwable t) { + fail("Test failed"); + } + + + @Override + public void onComplete() { + int matchPos = 0; + for (ByteBuffer buffer : received) { + byte[] bufferData = new byte[buffer.limit() - buffer.position()]; + buffer.get(bufferData); + assertArrayEquals(Arrays.copyOfRange(expected, matchPos, matchPos + bufferData.length), bufferData); + matchPos += bufferData.length; + } + assertEquals(expected.length, matchPos); + completed = true; + } + + public boolean hasCompleted() { + return completed; + } + } + + private class TestPublisher implements Publisher { + Subscriber s; + + @Override + public void subscribe(Subscriber s) { + this.s = s; + } + + public void doOnNext(ByteBuffer b) { + s.onNext(b); + } + + public void doOnComplete() { + s.onComplete(); + } + } +}