Skip to content

Fixed an issue where checksum validation only considered the first 4 bytes of the 16 byte checksum... #2646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/bugfix-AmazonS3-421839e.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"category": "Amazon S3",
"contributor": "",
"type": "bugfix",
"description": "Fixed an issue where checksum validation only considered the first 4 bytes of the 16 byte checksum, creating the potential for corrupted downloads to go undetected."
}
5 changes: 2 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,8 @@
<excludes>
<exclude>*.internal.*</exclude>

<!-- Jackson removal -->
<exclude>software.amazon.awssdk.core.util.json.JacksonUtils</exclude>
<exclude>software.amazon.awssdk.protocols.json.*</exclude>
<!-- Checksum bug fix -->
<exclude>software.amazon.awssdk.services.s3.checksums.ChecksumValidatingInputStream</exclude>
</excludes>

<excludeModules>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.Abortable;
import software.amazon.awssdk.utils.BinaryUtils;

@SdkInternalApi
public class ChecksumValidatingInputStream extends InputStream implements Abortable {
Expand All @@ -34,7 +35,7 @@ public class ChecksumValidatingInputStream extends InputStream implements Aborta
private long lengthRead = 0;
// Preserve the computed checksum because some InputStream readers (e.g., java.util.Properties) read more than once at the
// end of the stream.
private Integer computedChecksum;
private byte[] computedChecksum;

/**
* Creates an input stream using the specified Checksum, input stream, and length.
Expand Down Expand Up @@ -162,26 +163,15 @@ public void close() throws IOException {
inputStream.close();
}

/**
* Gets the stream's checksum as an integer.
*
* @return checksum.
*/
public int getStreamChecksum() {
ByteBuffer bb = ByteBuffer.wrap(streamChecksum);
return bb.getInt();
}

private void validateAndThrow() {
int streamChecksumInt = getStreamChecksum();
if (computedChecksum == null) {
computedChecksum = ByteBuffer.wrap(checkSum.getChecksumBytes()).getInt();
computedChecksum = checkSum.getChecksumBytes();
}

if (streamChecksumInt != computedChecksum) {
if (!Arrays.equals(computedChecksum, streamChecksum)) {
throw SdkClientException.builder().message(
String.format("Data read has a different checksum than expected. Was %d, but expected %d",
computedChecksum, streamChecksumInt)).build();
String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s",
BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum))).build();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,11 @@ public void onError(Throwable t) {
@Override
public void onComplete() {
if (strippedLength > 0) {
int streamChecksumInt = ByteBuffer.wrap(streamChecksum).getInt();
int computedChecksumInt = ByteBuffer.wrap(sdkChecksum.getChecksumBytes()).getInt();
if (streamChecksumInt != computedChecksumInt) {
byte[] computedChecksum = sdkChecksum.getChecksumBytes();
if (!Arrays.equals(computedChecksum, streamChecksum)) {
onError(SdkClientException.create(
String.format("Data read has a different checksum than expected. Was %d, but expected %d",
computedChecksumInt, streamChecksumInt)));
String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s",
BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum))));
return; // Return after onError and not call onComplete below
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.checksums;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: copyright


import static org.junit.Assert.assertArrayEquals;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import software.amazon.awssdk.core.checksums.Md5Checksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.utils.IoUtils;

public class ChecksumValidatingInputStreamTest {
private static final int TEST_DATA_SIZE = 32;
private static final int CHECKSUM_SIZE = 16;

private static byte[] testData;
private static byte[] testDataWithoutChecksum;

@BeforeClass
public static void populateData() {
testData = new byte[TEST_DATA_SIZE + CHECKSUM_SIZE];
for (int i = 0; i < TEST_DATA_SIZE; i++) {
testData[i] = (byte)(i & 0x7f);
}

Md5Checksum checksum = new Md5Checksum();
checksum.update(testData, 0, TEST_DATA_SIZE);
byte[] checksumBytes = checksum.getChecksumBytes();

for (int i = 0; i < CHECKSUM_SIZE; i++) {
testData[TEST_DATA_SIZE + i] = checksumBytes[i];
}

testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
}

@Test
public void validChecksumSucceeds() throws IOException {
InputStream validatingInputStream = newValidatingStream(testData);
byte[] dataFromValidatingStream = IoUtils.toByteArray(validatingInputStream);

assertArrayEquals(testDataWithoutChecksum, dataFromValidatingStream);
}

@Test
public void invalidChecksumFails() throws IOException {
for (int i = 0; i < testData.length; i++) {
// Make sure that corruption of any byte in the test data causes a checksum validation failure.
byte[] corruptedChecksumData = Arrays.copyOf(testData, testData.length);
corruptedChecksumData[i] = (byte) ~corruptedChecksumData[i];

InputStream validatingInputStream = newValidatingStream(corruptedChecksumData);

try {
IoUtils.toByteArray(validatingInputStream);
Assert.fail("Corruption at byte " + i + " was not detected.");
} catch (SdkClientException e) {
// Expected
}
}
}

private InputStream newValidatingStream(byte[] dataFromS3) {
return new ChecksumValidatingInputStream(new ByteArrayInputStream(dataFromS3),
new Md5Checksum(),
TEST_DATA_SIZE + CHECKSUM_SIZE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
package software.amazon.awssdk.services.s3.checksums;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -31,6 +33,7 @@
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.checksums.Md5Checksum;
import software.amazon.awssdk.utils.BinaryUtils;

/**
* Unit test for ChecksumValidatingPublisher
Expand All @@ -39,6 +42,7 @@ public class ChecksumValidatingPublisherTest {
private static int TEST_DATA_SIZE = 32; // size of the test data, in bytes
private static final int CHECKSUM_SIZE = 16;
private static byte[] testData;
private static byte[] testDataWithoutChecksum;

@BeforeClass
public static void populateData() {
Expand All @@ -52,34 +56,55 @@ public static void populateData() {
for (int i = 0; i < CHECKSUM_SIZE; i++) {
testData[TEST_DATA_SIZE + i] = checksumBytes[i];
}

testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
}

@Test
public void testSinglePacket() {
final TestPublisher driver = new TestPublisher();
final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);

driver.doOnNext(ByteBuffer.wrap(testData));
driver.doOnComplete();

assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}

@Test
public void testLastChecksumByteCorrupted() {
TestPublisher driver = new TestPublisher();

TestSubscriber s = new TestSubscriber();
ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);

byte[] incorrectChecksumData = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
incorrectChecksumData[TEST_DATA_SIZE - 1] = (byte) ~incorrectChecksumData[TEST_DATA_SIZE - 1];
driver.doOnNext(ByteBuffer.wrap(incorrectChecksumData));
driver.doOnComplete();

assertFalse(s.hasCompleted());
assertTrue(s.isOnErrorCalled());
}

@Test
public void testTwoPackets() {
for (int i = 1; i < TEST_DATA_SIZE + CHECKSUM_SIZE - 1; i++) {
final TestPublisher driver = new TestPublisher();
final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);

driver.doOnNext(ByteBuffer.wrap(testData, 0, i));
driver.doOnNext(ByteBuffer.wrap(testData, i, TEST_DATA_SIZE + CHECKSUM_SIZE - i));
driver.doOnComplete();

assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
Expand All @@ -89,7 +114,7 @@ public void testTwoPackets() {
public void testTinyPackets() {
for (int packetSize = 1; packetSize < CHECKSUM_SIZE; packetSize++) {
final TestPublisher driver = new TestPublisher();
final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);
int currOffset = 0;
Expand All @@ -100,6 +125,7 @@ public void testTinyPackets() {
}
driver.doOnComplete();

assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
Expand All @@ -109,7 +135,7 @@ public void testTinyPackets() {
public void testUnknownLength() {
// When the length is unknown, the last 16 bytes are treated as a checksum, but are later ignored when completing
final TestPublisher driver = new TestPublisher();
final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), 0);
p.subscribe(s);

Expand All @@ -122,6 +148,7 @@ public void testUnknownLength() {
driver.doOnNext(ByteBuffer.wrap(randomChecksumData));
driver.doOnComplete();

assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
Expand All @@ -130,7 +157,7 @@ public void testUnknownLength() {
public void checksumValidationFailure_throwsSdkClientException_NotNPE() {
final byte[] incorrectData = new byte[0];
final TestPublisher driver = new TestPublisher();
final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(incorrectData, 0, TEST_DATA_SIZE));
final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);

Expand All @@ -142,13 +169,11 @@ public void checksumValidationFailure_throwsSdkClientException_NotNPE() {
}

private class TestSubscriber implements Subscriber<ByteBuffer> {
final byte[] expected;
final List<ByteBuffer> received;
boolean completed;
boolean onErrorCalled;

TestSubscriber(byte[] expected) {
this.expected = expected;
TestSubscriber() {
this.received = new ArrayList<>();
this.completed = false;
}
Expand All @@ -172,17 +197,21 @@ public void onError(Throwable t) {

@Override
public void onComplete() {
int matchPos = 0;
for (ByteBuffer buffer : received) {
byte[] bufferData = new byte[buffer.limit() - buffer.position()];
buffer.get(bufferData);
assertArrayEquals(Arrays.copyOfRange(expected, matchPos, matchPos + bufferData.length), bufferData);
matchPos += bufferData.length;
}
assertEquals(expected.length, matchPos);
completed = true;
}

public byte[] receivedData() {
try {
ByteArrayOutputStream os = new ByteArrayOutputStream();
for (ByteBuffer buffer : received) {
os.write(BinaryUtils.copyBytesFrom(buffer));
}
return os.toByteArray();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

public boolean hasCompleted() {
return completed;
}
Expand Down