1
+ // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+ package software .amazon .encryption .s3 .internal ;
4
+
5
+ import org .junit .jupiter .api .Test ;
6
+ import org .reactivestreams .Subscriber ;
7
+ import org .reactivestreams .Subscription ;
8
+ import software .amazon .encryption .s3 .algorithms .AlgorithmSuite ;
9
+ import software .amazon .encryption .s3 .materials .DecryptionMaterials ;
10
+ import software .amazon .encryption .s3 .materials .EncryptionMaterials ;
11
+
12
+ import javax .crypto .KeyGenerator ;
13
+ import javax .crypto .SecretKey ;
14
+ import java .nio .ByteBuffer ;
15
+ import java .nio .charset .StandardCharsets ;
16
+ import java .security .NoSuchAlgorithmException ;
17
+ import java .util .ArrayList ;
18
+ import java .util .LinkedList ;
19
+ import java .util .List ;
20
+ import java .util .concurrent .atomic .AtomicBoolean ;
21
+ import java .util .concurrent .atomic .AtomicLong ;
22
+
23
+ import static org .junit .jupiter .api .Assertions .assertEquals ;
24
+ import static org .junit .jupiter .api .Assertions .assertTrue ;
25
+
26
+ class CipherSubscriberTest {
27
+ // Helper classes for testing
28
+ class SimpleSubscriber implements Subscriber <ByteBuffer > {
29
+
30
+ public static final long DEFAULT_REQUEST_SIZE = 1 ;
31
+
32
+ private final AtomicBoolean isSubscribed = new AtomicBoolean (false );
33
+ private final AtomicLong requestedItems = new AtomicLong (0 );
34
+ private final AtomicLong lengthOfData = new AtomicLong (0 );
35
+ private final LinkedList <ByteBuffer > buffersSeen = new LinkedList <>();
36
+ private Subscription subscription ;
37
+
38
+ @ Override
39
+ public void onSubscribe (Subscription s ) {
40
+ if (isSubscribed .compareAndSet (false , true )) {
41
+ this .subscription = s ;
42
+ requestMore (DEFAULT_REQUEST_SIZE );
43
+ } else {
44
+ s .cancel ();
45
+ }
46
+ }
47
+
48
+ @ Override
49
+ public void onNext (ByteBuffer item ) {
50
+ // Process the item here
51
+ lengthOfData .addAndGet (item .capacity ());
52
+ buffersSeen .add (item );
53
+
54
+ // Request the next item
55
+ requestMore (1 );
56
+ }
57
+
58
+ @ Override
59
+ public void onError (Throwable t ) {
60
+ System .err .println ("Error occurred: " + t .getMessage ());
61
+ }
62
+
63
+ @ Override
64
+ public void onComplete () {
65
+ // Do nothing.
66
+ }
67
+
68
+ public void cancel () {
69
+ if (isSubscribed .getAndSet (false )) {
70
+ subscription .cancel ();
71
+ }
72
+ }
73
+
74
+ private void requestMore (long n ) {
75
+ if (subscription != null ) {
76
+ requestedItems .addAndGet (n );
77
+ subscription .request (n );
78
+ }
79
+ }
80
+
81
+ public List <ByteBuffer > getBuffersSeen () {
82
+ return buffersSeen ;
83
+ }
84
+ }
85
+
86
+ class TestPublisher <T > {
87
+ private final List <Subscriber <T >> subscribers = new ArrayList <>(1 );
88
+
89
+ public void subscribe (Subscriber <T > subscriber ) {
90
+ subscribers .add (subscriber );
91
+ subscriber .onSubscribe (new TestSubscription ());
92
+ }
93
+
94
+ public void emit (T item ) {
95
+ subscribers .forEach (s -> s .onNext (item ));
96
+ }
97
+
98
+ public void complete () {
99
+ subscribers .forEach (Subscriber ::onComplete );
100
+ }
101
+
102
+ public boolean isSubscribed () {
103
+ return !subscribers .isEmpty ();
104
+ }
105
+
106
+ public int getSubscriberCount () {
107
+ return subscribers .size ();
108
+ }
109
+ }
110
+
111
+ class TestSubscription implements Subscription {
112
+ private long requestCount = 0 ;
113
+ private final AtomicBoolean canceled = new AtomicBoolean (false );
114
+
115
+ @ Override
116
+ public void request (long n ) {
117
+ if (!canceled .get ()) {
118
+ requestCount += n ;
119
+ } else {
120
+ // Maybe do something more useful/correct eventually,
121
+ // for now just throw an exception
122
+ throw new RuntimeException ("Subscription has been canceled!" );
123
+ }
124
+ }
125
+
126
+ @ Override
127
+ public void cancel () {
128
+ canceled .set (true );
129
+ }
130
+
131
+ public long getRequestCount () {
132
+ return requestCount ;
133
+ }
134
+ }
135
+
136
+ private EncryptionMaterials getTestEncryptMaterials (String plaintext ) {
137
+ try {
138
+ SecretKey AES_KEY ;
139
+ KeyGenerator keyGen = KeyGenerator .getInstance ("AES" );
140
+ keyGen .init (256 );
141
+ AES_KEY = keyGen .generateKey ();
142
+ return EncryptionMaterials .builder ()
143
+ .plaintextDataKey (AES_KEY .getEncoded ())
144
+ .algorithmSuite (AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF )
145
+ .plaintextLength (plaintext .getBytes (StandardCharsets .UTF_8 ).length )
146
+ .build ();
147
+ } catch (NoSuchAlgorithmException exception ) {
148
+ // this should never happen
149
+ throw new RuntimeException ("AES doesn't exist" );
150
+ }
151
+ }
152
+
153
+ private DecryptionMaterials getTestDecryptionMaterialsFromEncMats (EncryptionMaterials encMats ) {
154
+ return DecryptionMaterials .builder ()
155
+ .plaintextDataKey (encMats .plaintextDataKey ())
156
+ .algorithmSuite (AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF )
157
+ .ciphertextLength (encMats .getCiphertextLength ())
158
+ .build ();
159
+ }
160
+
161
+ private byte [] getByteArrayFromFixedLengthByteBuffers (List <ByteBuffer > byteBuffers , long expectedLength ) {
162
+ if (expectedLength > Integer .MAX_VALUE ) {
163
+ throw new RuntimeException ("Use a smaller expected length." );
164
+ }
165
+ return getByteArrayFromFixedLengthByteBuffers (byteBuffers , (int ) expectedLength );
166
+ }
167
+
168
+ private byte [] getByteArrayFromFixedLengthByteBuffers (List <ByteBuffer > byteBuffers , int expectedLength ) {
169
+ byte [] bytes = new byte [expectedLength ];
170
+ int offset = 0 ;
171
+ for (ByteBuffer bb : byteBuffers ) {
172
+ int remaining = bb .remaining ();
173
+ bb .get (bytes , offset , remaining );
174
+ offset += remaining ;
175
+ }
176
+ return bytes ;
177
+ }
178
+
179
+ @ Test
180
+ public void testSubscriberBehaviorOneChunk () {
181
+ AlgorithmSuite algorithmSuite = AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF ;
182
+ String plaintext = "unit test of cipher subscriber" ;
183
+ EncryptionMaterials materials = getTestEncryptMaterials (plaintext );
184
+ byte [] iv = new byte [materials .algorithmSuite ().iVLengthBytes ()];
185
+ // we reject 0-ized IVs, so just do something
186
+ iv [0 ] = 1 ;
187
+ SimpleSubscriber wrappedSubscriber = new SimpleSubscriber ();
188
+ CipherSubscriber subscriber = new CipherSubscriber (wrappedSubscriber , materials .getCiphertextLength (), materials , iv );
189
+
190
+ // Act
191
+ TestPublisher <ByteBuffer > publisher = new TestPublisher <>();
192
+ publisher .subscribe (subscriber );
193
+
194
+ // Verify subscription behavior
195
+ assertTrue (publisher .isSubscribed ());
196
+ assertEquals (1 , publisher .getSubscriberCount ());
197
+
198
+ ByteBuffer ptBb = ByteBuffer .wrap (plaintext .getBytes (StandardCharsets .UTF_8 ));
199
+ publisher .emit (ptBb );
200
+
201
+ // Complete the stream
202
+ publisher .complete ();
203
+
204
+ long expectedLength = plaintext .getBytes (StandardCharsets .UTF_8 ).length + algorithmSuite .cipherTagLengthBytes ();
205
+ assertEquals (expectedLength , wrappedSubscriber .lengthOfData .get ());
206
+ byte [] ctBytes = getByteArrayFromFixedLengthByteBuffers (wrappedSubscriber .getBuffersSeen (), expectedLength );
207
+
208
+ // Now decrypt.
209
+ DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats (materials );
210
+ SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber ();
211
+ CipherSubscriber decryptSubscriber = new CipherSubscriber (wrappedDecryptSubscriber , expectedLength , decryptionMaterials , iv );
212
+ TestPublisher <ByteBuffer > decryptPublisher = new TestPublisher <>();
213
+ decryptPublisher .subscribe (decryptSubscriber );
214
+
215
+ // Verify subscription behavior
216
+ assertTrue (decryptPublisher .isSubscribed ());
217
+ assertEquals (1 , decryptPublisher .getSubscriberCount ());
218
+
219
+ // Simulate publishing items
220
+ ByteBuffer ctBb = ByteBuffer .wrap (ctBytes );
221
+ decryptPublisher .emit (ctBb );
222
+
223
+ // Complete the stream
224
+ decryptPublisher .complete ();
225
+
226
+ long expectedLengthPt = plaintext .getBytes (StandardCharsets .UTF_8 ).length ;
227
+ assertEquals (expectedLengthPt , wrappedDecryptSubscriber .lengthOfData .get ());
228
+ byte [] ptBytes = getByteArrayFromFixedLengthByteBuffers (wrappedDecryptSubscriber .getBuffersSeen (), expectedLengthPt );
229
+ // Assert round trip encrypt/decrypt succeeds.
230
+ assertEquals (plaintext , new String (ptBytes , StandardCharsets .UTF_8 ));
231
+ }
232
+
233
+ @ Test
234
+ public void testSubscriberBehaviorTagLengthLastChunk () {
235
+ AlgorithmSuite algorithmSuite = AlgorithmSuite .ALG_AES_256_GCM_IV12_TAG16_NO_KDF ;
236
+ String plaintext = "unit test of cipher subscriber tag length last chunk" ;
237
+ EncryptionMaterials materials = getTestEncryptMaterials (plaintext );
238
+ byte [] iv = new byte [materials .algorithmSuite ().iVLengthBytes ()];
239
+ // we reject 0-ized IVs, so just do something non-zero
240
+ iv [0 ] = 1 ;
241
+ SimpleSubscriber wrappedSubscriber = new SimpleSubscriber ();
242
+ CipherSubscriber subscriber = new CipherSubscriber (wrappedSubscriber , materials .getCiphertextLength (), materials , iv );
243
+
244
+ // Setup Publisher
245
+ TestPublisher <ByteBuffer > publisher = new TestPublisher <>();
246
+ publisher .subscribe (subscriber );
247
+
248
+ // Verify subscription behavior
249
+ assertTrue (publisher .isSubscribed ());
250
+ assertEquals (1 , publisher .getSubscriberCount ());
251
+
252
+ // Send data to be encrypted
253
+ ByteBuffer ptBb = ByteBuffer .wrap (plaintext .getBytes (StandardCharsets .UTF_8 ));
254
+ publisher .emit (ptBb );
255
+ publisher .complete ();
256
+
257
+ // Convert to byte array for convenience
258
+ long expectedLength = plaintext .getBytes (StandardCharsets .UTF_8 ).length + algorithmSuite .cipherTagLengthBytes ();
259
+ assertEquals (expectedLength , wrappedSubscriber .lengthOfData .get ());
260
+ byte [] ctBytes = getByteArrayFromFixedLengthByteBuffers (wrappedSubscriber .getBuffersSeen (), expectedLength );
261
+
262
+ // Now decrypt the ciphertext
263
+ DecryptionMaterials decryptionMaterials = getTestDecryptionMaterialsFromEncMats (materials );
264
+ SimpleSubscriber wrappedDecryptSubscriber = new SimpleSubscriber ();
265
+ CipherSubscriber decryptSubscriber = new CipherSubscriber (wrappedDecryptSubscriber , expectedLength , decryptionMaterials , iv );
266
+ TestPublisher <ByteBuffer > decryptPublisher = new TestPublisher <>();
267
+ decryptPublisher .subscribe (decryptSubscriber );
268
+
269
+ // Verify subscription behavior
270
+ assertTrue (decryptPublisher .isSubscribed ());
271
+ assertEquals (1 , decryptPublisher .getSubscriberCount ());
272
+
273
+ int taglength = algorithmSuite .cipherTagLengthBytes ();
274
+ int ciphertextWithoutTagLength = ctBytes .length - taglength ;
275
+
276
+ // Create the main ByteBuffer (all except last 16 bytes)
277
+ ByteBuffer mainBuffer = ByteBuffer .allocate (ciphertextWithoutTagLength );
278
+ mainBuffer .put (ctBytes , 0 , ciphertextWithoutTagLength );
279
+ mainBuffer .flip ();
280
+
281
+ // Create the tag ByteBuffer (last 16 bytes)
282
+ ByteBuffer tagBuffer = ByteBuffer .allocate (taglength );
283
+ tagBuffer .put (ctBytes , ciphertextWithoutTagLength , taglength );
284
+ tagBuffer .flip ();
285
+
286
+ // Send the ciphertext, then the tag separately
287
+ decryptPublisher .emit (mainBuffer );
288
+ decryptPublisher .emit (tagBuffer );
289
+ decryptPublisher .complete ();
290
+
291
+ long expectedLengthPt = plaintext .getBytes (StandardCharsets .UTF_8 ).length ;
292
+ assertEquals (expectedLengthPt , wrappedDecryptSubscriber .lengthOfData .get ());
293
+ byte [] ptBytes = getByteArrayFromFixedLengthByteBuffers (wrappedDecryptSubscriber .getBuffersSeen (), expectedLengthPt );
294
+ // Assert round trip encrypt/decrypt succeeds
295
+ assertEquals (plaintext , new String (ptBytes , StandardCharsets .UTF_8 ));
296
+ }
297
+ }
0 commit comments