Skip to content

Commit 969d721

Browse files
authored
fix: determine effective contentLength, account for tagLength on decrypt (#463)
* only add tagLength on encrypt * add a simple, passing unit test, cleanup tests, add repro case * tweak onComplete logic, fix bug in unit tests
1 parent dec503b commit 969d721

File tree

2 files changed

+307
-5
lines changed

2 files changed

+307
-5
lines changed

src/main/java/software/amazon/encryption/s3/internal/CipherSubscriber.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class CipherSubscriber implements Subscriber<ByteBuffer> {
2121
private final Long contentLength;
2222
private final boolean isLastPart;
2323
private final int tagLength;
24+
private final boolean isEncrypt;
2425
private final AtomicBoolean finalBytesCalled = new AtomicBoolean(false);
2526

2627
private byte[] outputBuffer;
@@ -31,6 +32,7 @@ public class CipherSubscriber implements Subscriber<ByteBuffer> {
3132
this.cipher = materials.getCipher(iv);
3233
this.isLastPart = isLastPart;
3334
this.tagLength = materials.algorithmSuite().cipherTagLengthBytes();
35+
this.isEncrypt = (CipherMode.DECRYPT != materials.cipherMode());
3436
}
3537

3638
CipherSubscriber(Subscriber<? super ByteBuffer> wrappedSubscriber, Long contentLength, CryptographicMaterials materials, byte[] iv) {
@@ -56,7 +58,9 @@ public void onNext(ByteBuffer byteBuffer) {
5658
// Note that while the JCE Javadoc specifies that the outputBuffer is null in this case,
5759
// in practice SunJCE and ACCP return an empty buffer instead, hence checks for
5860
// null OR length == 0.
59-
if (contentRead.get() + tagLength >= contentLength) {
61+
62+
// tagLength should only be added on Encrypt
63+
if (contentRead.get() + (isEncrypt ? tagLength : 0) >= contentLength) {
6064
// All content has been read, so complete to get the final bytes
6165
finalBytes();
6266
return;
@@ -84,7 +88,7 @@ public void onNext(ByteBuffer byteBuffer) {
8488
Calling `wrappedSubscriber.onNext` more than once for `request(1)`
8589
violates the Reactive Streams specification and can cause exceptions downstream.
8690
*/
87-
if (contentRead.get() + tagLength >= contentLength) {
91+
if (contentRead.get() + (isEncrypt ? tagLength : 0) >= contentLength) {
8892
// All content has been read; complete the stream.
8993
finalBytes();
9094
} else {
@@ -125,9 +129,10 @@ public void onError(Throwable t) {
125129
public void onComplete() {
126130
// In rare cases, e.g. when the last part of a low-level MPU has 0 length,
127131
// onComplete will be called before onNext is called once.
128-
if (contentRead.get() + tagLength <= contentLength) {
129-
finalBytes();
130-
}
132+
// So, call finalBytes here just in case there's any unsent data left.
133+
// Most likely, finalBytes has already been called by the last onNext,
134+
// but finalBytes guards against multiple invocations so it's safe to call again.
135+
finalBytes();
131136
wrappedSubscriber.onComplete();
132137
}
133138

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package software.amazon.encryption.s3.internal;
4+
5+
import org.junit.jupiter.api.Test;
6+
import org.reactivestreams.Subscriber;
7+
import org.reactivestreams.Subscription;
8+
import software.amazon.encryption.s3.algorithms.AlgorithmSuite;
9+
import software.amazon.encryption.s3.materials.DecryptionMaterials;
10+
import software.amazon.encryption.s3.materials.EncryptionMaterials;
11+
12+
import javax.crypto.KeyGenerator;
13+
import javax.crypto.SecretKey;
14+
import java.nio.ByteBuffer;
15+
import java.nio.charset.StandardCharsets;
16+
import java.security.NoSuchAlgorithmException;
17+
import java.util.ArrayList;
18+
import java.util.LinkedList;
19+
import java.util.List;
20+
import java.util.concurrent.atomic.AtomicBoolean;
21+
import java.util.concurrent.atomic.AtomicLong;
22+
23+
import static org.junit.jupiter.api.Assertions.assertEquals;
24+
import static org.junit.jupiter.api.Assertions.assertTrue;
25+
26+
class CipherSubscriberTest {
27+
// Helper classes for testing
28+
class SimpleSubscriber implements Subscriber<ByteBuffer> {
29+
30+
public static final long DEFAULT_REQUEST_SIZE = 1;
31+
32+
private final AtomicBoolean isSubscribed = new AtomicBoolean(false);
33+
private final AtomicLong requestedItems = new AtomicLong(0);
34+
private final AtomicLong lengthOfData = new AtomicLong(0);
35+
private final LinkedList<ByteBuffer> buffersSeen = new LinkedList<>();
36+
private Subscription subscription;
37+
38+
@Override
39+
public void onSubscribe(Subscription s) {
40+
if (isSubscribed.compareAndSet(false, true)) {
41+
this.subscription = s;
42+
requestMore(DEFAULT_REQUEST_SIZE);
43+
} else {
44+
s.cancel();
45+
}
46+
}
47+
48+
@Override
49+
public void onNext(ByteBuffer item) {
50+
// Process the item here
51+
lengthOfData.addAndGet(item.capacity());
52+
buffersSeen.add(item);
53+
54+
// Request the next item
55+
requestMore(1);
56+
}
57+
58+
@Override
59+
public void onError(Throwable t) {
60+
System.err.println("Error occurred: " + t.getMessage());
61+
}
62+
63+
@Override
64+
public void onComplete() {
65+
// Do nothing.
66+
}
67+
68+
public void cancel() {
69+
if (isSubscribed.getAndSet(false)) {
70+
subscription.cancel();
71+
}
72+
}
73+
74+
private void requestMore(long n) {
75+
if (subscription != null) {
76+
requestedItems.addAndGet(n);
77+
subscription.request(n);
78+
}
79+
}
80+
81+
public List<ByteBuffer> getBuffersSeen() {
82+
return buffersSeen;
83+
}
84+
}
85+
86+
class TestPublisher<T> {
87+
private final List<Subscriber<T>> subscribers = new ArrayList<>(1);
88+
89+
public void subscribe(Subscriber<T> subscriber) {
90+
subscribers.add(subscriber);
91+
subscriber.onSubscribe(new TestSubscription());
92+
}
93+
94+
public void emit(T item) {
95+
subscribers.forEach(s -> s.onNext(item));
96+
}
97+
98+
public void complete() {
99+
subscribers.forEach(Subscriber::onComplete);
100+
}
101+
102+
public boolean isSubscribed() {
103+
return !subscribers.isEmpty();
104+
}
105+
106+
public int getSubscriberCount() {
107+
return subscribers.size();
108+
}
109+
}
110+
111+
class TestSubscription implements Subscription {
112+
private long requestCount = 0;
113+
private final AtomicBoolean canceled = new AtomicBoolean(false);
114+
115+
@Override
116+
public void request(long n) {
117+
if (!canceled.get()) {
118+
requestCount += n;
119+
} else {
120+
// Maybe do something more useful/correct eventually,
121+
// for now just throw an exception
122+
throw new RuntimeException("Subscription has been canceled!");
123+
}
124+
}
125+
126+
@Override
127+
public void cancel() {
128+
canceled.set(true);
129+
}
130+
131+
public long getRequestCount() {
132+
return requestCount;
133+
}
134+
}
135+
136+
private EncryptionMaterials getTestEncryptMaterials(String plaintext) {
137+
try {
138+
SecretKey AES_KEY;
139+
KeyGenerator keyGen = KeyGenerator.getInstance("AES");
140+
keyGen.init(256);
141+
AES_KEY = keyGen.generateKey();
142+
return EncryptionMaterials.builder()
143+
.plaintextDataKey(AES_KEY.getEncoded())
144+
.algorithmSuite(AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF)
145+
.plaintextLength(plaintext.getBytes(StandardCharsets.UTF_8).length)
146+
.build();
147+
} catch (NoSuchAlgorithmException exception) {
148+
// this should never happen
149+
throw new RuntimeException("AES doesn't exist");
150+
}
151+
}
152+
153+
private DecryptionMaterials getTestDecryptionMaterialsFromEncMats(EncryptionMaterials encMats) {
154+
return DecryptionMaterials.builder()
155+
.plaintextDataKey(encMats.plaintextDataKey())
156+
.algorithmSuite(AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF)
157+
.ciphertextLength(encMats.getCiphertextLength())
158+
.build();
159+
}
160+
161+
private byte[] getByteArrayFromFixedLengthByteBuffers(List<ByteBuffer> byteBuffers, long expectedLength) {
162+
if (expectedLength > Integer.MAX_VALUE) {
163+
throw new RuntimeException("Use a smaller expected length.");
164+
}
165+
return getByteArrayFromFixedLengthByteBuffers(byteBuffers, (int) expectedLength);
166+
}
167+
168+
private byte[] getByteArrayFromFixedLengthByteBuffers(List<ByteBuffer> byteBuffers, int expectedLength) {
169+
byte[] bytes = new byte[expectedLength];
170+
int offset = 0;
171+
for (ByteBuffer bb : byteBuffers) {
172+
int remaining = bb.remaining();
173+
bb.get(bytes, offset, remaining);
174+
offset += remaining;
175+
}
176+
return bytes;
177+
}
178+
179+
@Test
180+
public void testSubscriberBehaviorOneChunk() {
181+
AlgorithmSuite algorithmSuite = AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF;
182+
String plaintext = "unit test of cipher subscriber";
183+
EncryptionMaterials materials = getTestEncryptMaterials(plaintext);
184+
byte[] iv = new byte[materials.algorithmSuite().iVLengthBytes()];
185+
// we reject 0-ized IVs, so just do something
186+
iv[0] = 1;
187+
SimpleSubscriber wrappedSubscriber = new SimpleSubscriber();
188+
CipherSubscriber subscriber = new CipherSubscriber(wrappedSubscriber, materials.getCiphertextLength(), materials, iv);
189+
190+
// Act
191+
TestPublisher<ByteBuffer> publisher = new TestPublisher<>();
192+
publisher.subscribe(subscriber);
193+
194+
// Verify subscription behavior
195+
assertTrue(publisher.isSubscribed());
196+
assertEquals(1, publisher.getSubscriberCount());
197+
198+
ByteBuffer ptBb = ByteBuffer.wrap(plaintext.getBytes(StandardCharsets.UTF_8));
199+
publisher.emit(ptBb);
200+
201+
// Complete the stream
202+
publisher.complete();
203+
204+
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + algorithmSuite.cipherTagLengthBytes();
205+
assertEquals(expectedLength, wrappedSubscriber.lengthOfData.get());
206+
byte[] ctBytes = getByteArrayFromFixedLengthByteBuffers(wrappedSubscriber.getBuffersSeen(), expectedLength);
207+
208+
// Now decrypt.
209+
DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats(materials);
210+
SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber();
211+
CipherSubscriber decryptSubscriber = new CipherSubscriber(wrappedDecryptSubscriber, expectedLength, decryptionMaterials, iv);
212+
TestPublisher<ByteBuffer> decryptPublisher = new TestPublisher<>();
213+
decryptPublisher.subscribe(decryptSubscriber);
214+
215+
// Verify subscription behavior
216+
assertTrue(decryptPublisher.isSubscribed());
217+
assertEquals(1, decryptPublisher.getSubscriberCount());
218+
219+
// Simulate publishing items
220+
ByteBuffer ctBb = ByteBuffer.wrap(ctBytes);
221+
decryptPublisher.emit(ctBb);
222+
223+
// Complete the stream
224+
decryptPublisher.complete();
225+
226+
long expectedLengthPt = plaintext.getBytes(StandardCharsets.UTF_8).length;
227+
assertEquals(expectedLengthPt, wrappedDecryptSubscriber.lengthOfData.get());
228+
byte[] ptBytes = getByteArrayFromFixedLengthByteBuffers(wrappedDecryptSubscriber.getBuffersSeen(), expectedLengthPt);
229+
// Assert round trip encrypt/decrypt succeeds.
230+
assertEquals(plaintext, new String(ptBytes, StandardCharsets.UTF_8));
231+
}
232+
233+
@Test
234+
public void testSubscriberBehaviorTagLengthLastChunk() {
235+
AlgorithmSuite algorithmSuite = AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF;
236+
String plaintext = "unit test of cipher subscriber tag length last chunk";
237+
EncryptionMaterials materials = getTestEncryptMaterials(plaintext);
238+
byte[] iv = new byte[materials.algorithmSuite().iVLengthBytes()];
239+
// we reject 0-ized IVs, so just do something non-zero
240+
iv[0] = 1;
241+
SimpleSubscriber wrappedSubscriber = new SimpleSubscriber();
242+
CipherSubscriber subscriber = new CipherSubscriber(wrappedSubscriber, materials.getCiphertextLength(), materials, iv);
243+
244+
// Setup Publisher
245+
TestPublisher<ByteBuffer> publisher = new TestPublisher<>();
246+
publisher.subscribe(subscriber);
247+
248+
// Verify subscription behavior
249+
assertTrue(publisher.isSubscribed());
250+
assertEquals(1, publisher.getSubscriberCount());
251+
252+
// Send data to be encrypted
253+
ByteBuffer ptBb = ByteBuffer.wrap(plaintext.getBytes(StandardCharsets.UTF_8));
254+
publisher.emit(ptBb);
255+
publisher.complete();
256+
257+
// Convert to byte array for convenience
258+
long expectedLength = plaintext.getBytes(StandardCharsets.UTF_8).length + algorithmSuite.cipherTagLengthBytes();
259+
assertEquals(expectedLength, wrappedSubscriber.lengthOfData.get());
260+
byte[] ctBytes = getByteArrayFromFixedLengthByteBuffers(wrappedSubscriber.getBuffersSeen(), expectedLength);
261+
262+
// Now decrypt the ciphertext
263+
DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats(materials);
264+
SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber();
265+
CipherSubscriber decryptSubscriber = new CipherSubscriber(wrappedDecryptSubscriber, expectedLength, decryptionMaterials, iv);
266+
TestPublisher<ByteBuffer> decryptPublisher = new TestPublisher<>();
267+
decryptPublisher.subscribe(decryptSubscriber);
268+
269+
// Verify subscription behavior
270+
assertTrue(decryptPublisher.isSubscribed());
271+
assertEquals(1, decryptPublisher.getSubscriberCount());
272+
273+
int taglength = algorithmSuite.cipherTagLengthBytes();
274+
int ciphertextWithoutTagLength = ctBytes.length - taglength;
275+
276+
// Create the main ByteBuffer (all except last 16 bytes)
277+
ByteBuffer mainBuffer = ByteBuffer.allocate(ciphertextWithoutTagLength);
278+
mainBuffer.put(ctBytes, 0, ciphertextWithoutTagLength);
279+
mainBuffer.flip();
280+
281+
// Create the tag ByteBuffer (last 16 bytes)
282+
ByteBuffer tagBuffer = ByteBuffer.allocate(taglength);
283+
tagBuffer.put(ctBytes, ciphertextWithoutTagLength, taglength);
284+
tagBuffer.flip();
285+
286+
// Send the ciphertext, then the tag separately
287+
decryptPublisher.emit(mainBuffer);
288+
decryptPublisher.emit(tagBuffer);
289+
decryptPublisher.complete();
290+
291+
long expectedLengthPt = plaintext.getBytes(StandardCharsets.UTF_8).length;
292+
assertEquals(expectedLengthPt, wrappedDecryptSubscriber.lengthOfData.get());
293+
byte[] ptBytes = getByteArrayFromFixedLengthByteBuffers(wrappedDecryptSubscriber.getBuffersSeen(), expectedLengthPt);
294+
// Assert round trip encrypt/decrypt succeeds
295+
assertEquals(plaintext, new String(ptBytes, StandardCharsets.UTF_8));
296+
}
297+
}

0 commit comments

Comments
 (0)