From 513fb9d8e571b719a16a87af3f7f9cb32661e6be Mon Sep 17 00:00:00 2001 From: Matthew Miller Date: Thu, 5 Aug 2021 15:42:42 -0700 Subject: [PATCH] Fixed an issue where checksum validation only considered the first 4 bytes of the 16 byte checksum, creating the potential for corrupted downloads to go undetected. --- .../next-release/bugfix-AmazonS3-421839e.json | 6 ++ pom.xml | 5 +- .../ChecksumValidatingInputStream.java | 24 ++--- .../ChecksumValidatingPublisher.java | 9 +- .../ChecksumValidatingInputStreamTest.java | 87 +++++++++++++++++++ .../ChecksumValidatingPublisherTest.java | 63 ++++++++++---- 6 files changed, 152 insertions(+), 42 deletions(-) create mode 100644 .changes/next-release/bugfix-AmazonS3-421839e.json create mode 100644 services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java diff --git a/.changes/next-release/bugfix-AmazonS3-421839e.json b/.changes/next-release/bugfix-AmazonS3-421839e.json new file mode 100644 index 000000000000..117be49f54bb --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-421839e.json @@ -0,0 +1,6 @@ +{ + "category": "Amazon S3", + "contributor": "", + "type": "bugfix", + "description": "Fixed an issue where checksum validation only considered the first 4 bytes of the 16 byte checksum, creating the potential for corrupted downloads to go undetected." +} diff --git a/pom.xml b/pom.xml index 760e167ec03c..415f8741b8e1 100644 --- a/pom.xml +++ b/pom.xml @@ -511,9 +511,8 @@ *.internal.* - - software.amazon.awssdk.core.util.json.JacksonUtils - software.amazon.awssdk.protocols.json.* + + software.amazon.awssdk.services.s3.checksums.ChecksumValidatingInputStream diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java index ab089377e8cc..36fff298baf0 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java @@ -17,11 +17,12 @@ import java.io.IOException; import java.io.InputStream; -import java.nio.ByteBuffer; +import java.util.Arrays; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.checksums.SdkChecksum; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.http.Abortable; +import software.amazon.awssdk.utils.BinaryUtils; @SdkInternalApi public class ChecksumValidatingInputStream extends InputStream implements Abortable { @@ -34,7 +35,7 @@ public class ChecksumValidatingInputStream extends InputStream implements Aborta private long lengthRead = 0; // Preserve the computed checksum because some InputStream readers (e.g., java.util.Properties) read more than once at the // end of the stream. - private Integer computedChecksum; + private byte[] computedChecksum; /** * Creates an input stream using the specified Checksum, input stream, and length. @@ -162,26 +163,15 @@ public void close() throws IOException { inputStream.close(); } - /** - * Gets the stream's checksum as an integer. - * - * @return checksum. - */ - public int getStreamChecksum() { - ByteBuffer bb = ByteBuffer.wrap(streamChecksum); - return bb.getInt(); - } - private void validateAndThrow() { - int streamChecksumInt = getStreamChecksum(); if (computedChecksum == null) { - computedChecksum = ByteBuffer.wrap(checkSum.getChecksumBytes()).getInt(); + computedChecksum = checkSum.getChecksumBytes(); } - if (streamChecksumInt != computedChecksum) { + if (!Arrays.equals(computedChecksum, streamChecksum)) { throw SdkClientException.builder().message( - String.format("Data read has a different checksum than expected. Was %d, but expected %d", - computedChecksum, streamChecksumInt)).build(); + String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s", + BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum))).build(); } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java index 2c871470d84e..a3310331dd23 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java @@ -132,12 +132,11 @@ public void onError(Throwable t) { @Override public void onComplete() { if (strippedLength > 0) { - int streamChecksumInt = ByteBuffer.wrap(streamChecksum).getInt(); - int computedChecksumInt = ByteBuffer.wrap(sdkChecksum.getChecksumBytes()).getInt(); - if (streamChecksumInt != computedChecksumInt) { + byte[] computedChecksum = sdkChecksum.getChecksumBytes(); + if (!Arrays.equals(computedChecksum, streamChecksum)) { onError(SdkClientException.create( - String.format("Data read has a different checksum than expected. Was %d, but expected %d", - computedChecksumInt, streamChecksumInt))); + String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s", + BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum)))); return; // Return after onError and not call onComplete below } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java new file mode 100644 index 000000000000..a83bf45b4155 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java @@ -0,0 +1,87 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.checksums; + +import static org.junit.Assert.assertArrayEquals; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import software.amazon.awssdk.core.checksums.Md5Checksum; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.utils.IoUtils; + +public class ChecksumValidatingInputStreamTest { + private static final int TEST_DATA_SIZE = 32; + private static final int CHECKSUM_SIZE = 16; + + private static byte[] testData; + private static byte[] testDataWithoutChecksum; + + @BeforeClass + public static void populateData() { + testData = new byte[TEST_DATA_SIZE + CHECKSUM_SIZE]; + for (int i = 0; i < TEST_DATA_SIZE; i++) { + testData[i] = (byte)(i & 0x7f); + } + + Md5Checksum checksum = new Md5Checksum(); + checksum.update(testData, 0, TEST_DATA_SIZE); + byte[] checksumBytes = checksum.getChecksumBytes(); + + for (int i = 0; i < CHECKSUM_SIZE; i++) { + testData[TEST_DATA_SIZE + i] = checksumBytes[i]; + } + + testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE); + } + + @Test + public void validChecksumSucceeds() throws IOException { + InputStream validatingInputStream = newValidatingStream(testData); + byte[] dataFromValidatingStream = IoUtils.toByteArray(validatingInputStream); + + assertArrayEquals(testDataWithoutChecksum, dataFromValidatingStream); + } + + @Test + public void invalidChecksumFails() throws IOException { + for (int i = 0; i < testData.length; i++) { + // Make sure that corruption of any byte in the test data causes a checksum validation failure. + byte[] corruptedChecksumData = Arrays.copyOf(testData, testData.length); + corruptedChecksumData[i] = (byte) ~corruptedChecksumData[i]; + + InputStream validatingInputStream = newValidatingStream(corruptedChecksumData); + + try { + IoUtils.toByteArray(validatingInputStream); + Assert.fail("Corruption at byte " + i + " was not detected."); + } catch (SdkClientException e) { + // Expected + } + } + } + + private InputStream newValidatingStream(byte[] dataFromS3) { + return new ChecksumValidatingInputStream(new ByteArrayInputStream(dataFromS3), + new Md5Checksum(), + TEST_DATA_SIZE + CHECKSUM_SIZE); + } +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java index 935b656d8539..23027a2317fc 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java @@ -16,11 +16,13 @@ package software.amazon.awssdk.services.s3.checksums; import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -31,6 +33,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.checksums.Md5Checksum; +import software.amazon.awssdk.utils.BinaryUtils; /** * Unit test for ChecksumValidatingPublisher @@ -39,6 +42,7 @@ public class ChecksumValidatingPublisherTest { private static int TEST_DATA_SIZE = 32; // size of the test data, in bytes private static final int CHECKSUM_SIZE = 16; private static byte[] testData; + private static byte[] testDataWithoutChecksum; @BeforeClass public static void populateData() { @@ -52,27 +56,47 @@ public static void populateData() { for (int i = 0; i < CHECKSUM_SIZE; i++) { testData[TEST_DATA_SIZE + i] = checksumBytes[i]; } + + testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE); } @Test public void testSinglePacket() { final TestPublisher driver = new TestPublisher(); - final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final TestSubscriber s = new TestSubscriber(); final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); p.subscribe(s); driver.doOnNext(ByteBuffer.wrap(testData)); driver.doOnComplete(); + assertArrayEquals(testDataWithoutChecksum, s.receivedData()); assertTrue(s.hasCompleted()); assertFalse(s.isOnErrorCalled()); } + @Test + public void testLastChecksumByteCorrupted() { + TestPublisher driver = new TestPublisher(); + + TestSubscriber s = new TestSubscriber(); + ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); + p.subscribe(s); + + byte[] incorrectChecksumData = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE); + incorrectChecksumData[TEST_DATA_SIZE - 1] = (byte) ~incorrectChecksumData[TEST_DATA_SIZE - 1]; + driver.doOnNext(ByteBuffer.wrap(incorrectChecksumData)); + driver.doOnComplete(); + + assertFalse(s.hasCompleted()); + assertTrue(s.isOnErrorCalled()); + } + @Test public void testTwoPackets() { for (int i = 1; i < TEST_DATA_SIZE + CHECKSUM_SIZE - 1; i++) { final TestPublisher driver = new TestPublisher(); - final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final TestSubscriber s = new TestSubscriber(); final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); p.subscribe(s); @@ -80,6 +104,7 @@ public void testTwoPackets() { driver.doOnNext(ByteBuffer.wrap(testData, i, TEST_DATA_SIZE + CHECKSUM_SIZE - i)); driver.doOnComplete(); + assertArrayEquals(testDataWithoutChecksum, s.receivedData()); assertTrue(s.hasCompleted()); assertFalse(s.isOnErrorCalled()); } @@ -89,7 +114,7 @@ public void testTwoPackets() { public void testTinyPackets() { for (int packetSize = 1; packetSize < CHECKSUM_SIZE; packetSize++) { final TestPublisher driver = new TestPublisher(); - final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final TestSubscriber s = new TestSubscriber(); final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); p.subscribe(s); int currOffset = 0; @@ -100,6 +125,7 @@ public void testTinyPackets() { } driver.doOnComplete(); + assertArrayEquals(testDataWithoutChecksum, s.receivedData()); assertTrue(s.hasCompleted()); assertFalse(s.isOnErrorCalled()); } @@ -109,7 +135,7 @@ public void testTinyPackets() { public void testUnknownLength() { // When the length is unknown, the last 16 bytes are treated as a checksum, but are later ignored when completing final TestPublisher driver = new TestPublisher(); - final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE)); + final TestSubscriber s = new TestSubscriber(); final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), 0); p.subscribe(s); @@ -122,6 +148,7 @@ public void testUnknownLength() { driver.doOnNext(ByteBuffer.wrap(randomChecksumData)); driver.doOnComplete(); + assertArrayEquals(testDataWithoutChecksum, s.receivedData()); assertTrue(s.hasCompleted()); assertFalse(s.isOnErrorCalled()); } @@ -130,7 +157,7 @@ public void testUnknownLength() { public void checksumValidationFailure_throwsSdkClientException_NotNPE() { final byte[] incorrectData = new byte[0]; final TestPublisher driver = new TestPublisher(); - final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(incorrectData, 0, TEST_DATA_SIZE)); + final TestSubscriber s = new TestSubscriber(); final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE); p.subscribe(s); @@ -142,13 +169,11 @@ public void checksumValidationFailure_throwsSdkClientException_NotNPE() { } private class TestSubscriber implements Subscriber { - final byte[] expected; final List received; boolean completed; boolean onErrorCalled; - TestSubscriber(byte[] expected) { - this.expected = expected; + TestSubscriber() { this.received = new ArrayList<>(); this.completed = false; } @@ -172,17 +197,21 @@ public void onError(Throwable t) { @Override public void onComplete() { - int matchPos = 0; - for (ByteBuffer buffer : received) { - byte[] bufferData = new byte[buffer.limit() - buffer.position()]; - buffer.get(bufferData); - assertArrayEquals(Arrays.copyOfRange(expected, matchPos, matchPos + bufferData.length), bufferData); - matchPos += bufferData.length; - } - assertEquals(expected.length, matchPos); completed = true; } + public byte[] receivedData() { + try { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + for (ByteBuffer buffer : received) { + os.write(BinaryUtils.copyBytesFrom(buffer)); + } + return os.toByteArray(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + public boolean hasCompleted() { return completed; }