-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathUtils.java
309 lines (265 loc) · 11.3 KB
/
Utils.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
/*
* Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
* in compliance with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.amazonaws.encryptionsdk.internal;
import java.io.Serializable;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.math.BigInteger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.bouncycastle.util.encoders.Base64;
/**
 * Internal utility methods.
 *
 * <p>This class cannot be instantiated; all members are static.
 */
public final class Utils {
    // SecureRandom objects can both be expensive to initialize and incur synchronization costs.
    // Keeping one instance per thread minimizes initializations and avoids lock contention.
    private static final ThreadLocal<SecureRandom> LOCAL_RANDOM = ThreadLocal.withInitial(() -> {
        final SecureRandom rnd = new SecureRandom();
        rnd.nextBoolean(); // Force seeding
        return rnd;
    });

    private Utils() {
        // Prevent instantiation
    }

    /*
     * In some areas we need to be able to assign a total order over Java objects - generally with some primary sort,
     * but we need a fallback sort that always works in order to ensure that we don't falsely claim objects A and B
     * are equal just because the primary sort declares them to have equal rank.
     *
     * To do this, we define a fallback sort that assigns an arbitrary order to all objects. This order is
     * implemented by first comparing identity hash codes. In the rare case where we are asked to compare two
     * distinct objects with equal identity hash codes, we explicitly assign an increasing index to each of them
     * and sort based on this index.
     *
     * The index registry must be keyed by reference identity. A WeakHashMap keys by equals()/hashCode(), so two
     * distinct-but-equal objects would share one index (breaking the total order), and objects with mutable or
     * throwing hashCode() implementations would misbehave. We therefore keep our own registry of weak references,
     * bucketed by identity hash code and matched with ==. Entries whose referents have been garbage collected are
     * purged via a ReferenceQueue.
     */
    private static final AtomicLong FALLBACK_COUNTER = new AtomicLong(0);
    private static final ReferenceQueue<Object> FALLBACK_QUEUE = new ReferenceQueue<>();
    // identityHashCode -> weakly-held objects with that identity hash and their assigned ids.
    // Only accessed while holding the Utils.class monitor.
    private static final Map<Integer, List<FallbackId>> FALLBACK_IDS = new HashMap<>();

    /**
     * Weak reference to an object that has been assigned a fallback ordering id.
     */
    private static final class FallbackId extends WeakReference<Object> {
        private final int identityHash;
        private final long id;

        FallbackId(final Object referent, final int identityHash, final long id) {
            super(referent, FALLBACK_QUEUE);
            this.identityHash = identityHash;
            this.id = id;
        }
    }

    /**
     * Returns the fallback ordering id for {@code object}, assigning a new, strictly increasing id on first use.
     * Ids are tracked by reference identity and held weakly, so the registry does not keep objects alive.
     *
     * @param object a non-null object
     * @return the id assigned to this exact object instance
     */
    private static synchronized long getFallbackObjectId(final Object object) {
        purgeCollectedFallbackIds();
        final int identityHash = System.identityHashCode(object);
        final List<FallbackId> bucket = FALLBACK_IDS.computeIfAbsent(identityHash, ignored -> new ArrayList<>(2));
        for (final FallbackId candidate : bucket) {
            // Reference comparison on purpose: equal-but-distinct objects must receive distinct ids.
            if (candidate.get() == object) {
                return candidate.id;
            }
        }
        final FallbackId assigned = new FallbackId(object, identityHash, FALLBACK_COUNTER.incrementAndGet());
        bucket.add(assigned);
        return assigned.id;
    }

    /**
     * Removes registry entries whose referents have been garbage collected. Only called from the synchronized
     * {@link #getFallbackObjectId(Object)}, so it runs under the Utils.class monitor.
     */
    private static void purgeCollectedFallbackIds() {
        FallbackId collected;
        while ((collected = (FallbackId) FALLBACK_QUEUE.poll()) != null) {
            final List<FallbackId> bucket = FALLBACK_IDS.get(collected.identityHash);
            if (bucket != null) {
                // WeakReference does not override equals, so this removes exactly this entry.
                bucket.remove(collected);
                if (bucket.isEmpty()) {
                    FALLBACK_IDS.remove(collected.identityHash);
                }
            }
        }
    }

    /**
     * Provides an <i>arbitrary</i> but consistent total ordering over all objects. This comparison function will
     * return 0 if and only if a == b, and otherwise will return a nonzero value whose sign is arbitrary but
     * consistent, so that the results form a total order.
     *
     * @param a first object (may be null; null sorts first)
     * @param b second object (may be null; null sorts first)
     * @return a nonzero value (consistently signed) if a != b; 0 if a == b.
     */
    public static int compareObjectIdentity(Object a, Object b) {
        if (a == b) {
            return 0;
        }
        if (a == null) {
            return -1;
        }
        if (b == null) {
            return 1;
        }
        int hashCompare = Integer.compare(System.identityHashCode(a), System.identityHashCode(b));
        if (hashCompare != 0) {
            return hashCompare;
        }
        // Unfortunately these objects have identical identity hashcodes, so we need some other way to compare them.
        // We do this by mapping them to an incrementing counter and comparing their assigned IDs instead.
        int fallbackCompare = Long.compare(getFallbackObjectId(a), getFallbackObjectId(b));
        if (fallbackCompare == 0) {
            throw new AssertionError("Failed to assign unique order to objects");
        }
        return fallbackCompare;
    }

    /**
     * Adds two longs, clamping to {@link Long#MAX_VALUE} or {@link Long#MIN_VALUE} instead of wrapping
     * around when the exact sum is out of range.
     *
     * @param a first addend
     * @param b second addend
     * @return {@code a + b}, saturated to the {@code long} range
     */
    public static long saturatingAdd(long a, long b) {
        long r = a + b;
        if (a > 0 && b > 0 && r < a) {
            return Long.MAX_VALUE;
        }
        if (a < 0 && b < 0 && r > a) {
            return Long.MIN_VALUE;
        }
        // If the signs between a and b differ, overflow is impossible.
        return r;
    }

    /**
     * Comparator that performs a lexicographical comparison of byte arrays, treating them as unsigned.
     */
    public static class ComparingByteArrays implements Comparator<byte[]>, Serializable {
        // We don't really need to be serializable, but it doesn't hurt, and FindBugs gets annoyed if we're not.
        private static final long serialVersionUID = 0xdf641037ffe509e2L;
        // Stateless, so a single shared instance suffices instead of allocating one per comparison.
        private static final ComparingByteBuffers BUFFER_COMPARATOR = new ComparingByteBuffers();

        @Override
        public int compare(byte[] o1, byte[] o2) {
            return BUFFER_COMPARATOR.compare(ByteBuffer.wrap(o1), ByteBuffer.wrap(o2));
        }
    }

    /**
     * Comparator that performs a lexicographical comparison of the remaining bytes (position to limit) of two
     * byte buffers, treating bytes as unsigned. A proper prefix sorts before the longer buffer. The positions
     * of the supplied buffers are not modified.
     */
    public static class ComparingByteBuffers implements Comparator<ByteBuffer>, Serializable {
        private static final long serialVersionUID = 0xa3c4a7300fbbf043L;

        @Override
        public int compare(ByteBuffer o1, ByteBuffer o2) {
            // Slices start at index 0 at each buffer's current position; absolute gets leave positions untouched.
            final ByteBuffer left = o1.slice();
            final ByteBuffer right = o2.slice();
            final int commonLength = Math.min(left.remaining(), right.remaining());
            for (int i = 0; i < commonLength; i++) {
                // Perform zero-extension as we want to treat the bytes as unsigned
                final int v1 = left.get(i) & 0xFF;
                final int v2 = right.get(i) & 0xFF;
                if (v1 != v2) {
                    return v1 - v2;
                }
            }
            // The longer buffer is bigger (0x00 comes after end-of-buffer)
            return left.remaining() - right.remaining();
        }
    }

    /**
     * Throws {@link NullPointerException} with message {@code paramName} if {@code object} is null.
     *
     * @param object
     *            value to be null-checked
     * @param paramName
     *            message for the potential {@link NullPointerException}
     * @return {@code object}
     * @throws NullPointerException
     *             if {@code object} is null
     */
    public static <T> T assertNonNull(final T object, final String paramName) throws NullPointerException {
        if (object == null) {
            throw new NullPointerException(paramName + " must not be null");
        }
        return object;
    }

    /**
     * Returns a possibly truncated version of {@code arr} which is guaranteed to be exactly
     * {@code len} elements long. If {@code arr} is already exactly {@code len} elements long, then
     * {@code arr} is returned without copy or modification. If {@code arr} is longer than
     * {@code len}, then a truncated copy is returned.
     *
     * @throws IllegalArgumentException if {@code len} is negative or {@code arr} is shorter than {@code len}
     */
    public static byte[] truncate(final byte[] arr, final int len) throws IllegalArgumentException {
        if (len < 0) {
            // Previously surfaced as NegativeArraySizeException from Arrays.copyOf.
            throw cannotBeNegative("len");
        }
        if (arr.length == len) {
            return arr;
        } else if (arr.length > len) {
            return Arrays.copyOf(arr, len);
        } else {
            throw new IllegalArgumentException("arr is not at least " + len + " elements long");
        }
    }

    /**
     * Returns this thread's {@link SecureRandom} instance, creating and seeding it on first use by the thread.
     */
    public static SecureRandom getSecureRandom() {
        return LOCAL_RANDOM.get();
    }

    /**
     * Generate the AAD bytes to use when encrypting/decrypting content. The
     * generated AAD is a block of bytes containing the provided message
     * identifier, the string identifier (UTF-8), the sequence number, and the length of
     * the content, in that order (multi-byte values are big-endian, the ByteBuffer default).
     *
     * @param messageId
     *            the unique message identifier for the ciphertext.
     * @param idString
     *            the string describing the type of content processed.
     * @param seqNum
     *            the sequence number.
     * @param len
     *            the length of the content.
     * @return
     *         the bytes containing the generated AAD.
     */
    static byte[] generateContentAad(final byte[] messageId, final String idString, final int seqNum, final long len) {
        final byte[] idBytes = idString.getBytes(StandardCharsets.UTF_8);
        final int aadLen = messageId.length + idBytes.length + Integer.BYTES + Long.BYTES;
        final ByteBuffer aad = ByteBuffer.allocate(aadLen);
        aad.put(messageId);
        aad.put(idBytes);
        aad.putInt(seqNum);
        aad.putLong(len);
        return aad.array();
    }

    /**
     * Builds (but does not throw) the standard exception for a field that was given a negative value.
     */
    static IllegalArgumentException cannotBeNegative(String field) {
        return new IllegalArgumentException(field + " cannot be negative");
    }

    /**
     * Equivalent to calling {@link ByteBuffer#flip()} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer flip(final ByteBuffer buff) {
        ((Buffer) buff).flip();
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#clear()} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer clear(final ByteBuffer buff) {
        ((Buffer) buff).clear();
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#position(int)} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer position(final ByteBuffer buff, final int newPosition) {
        ((Buffer) buff).position(newPosition);
        return buff;
    }

    /**
     * Equivalent to calling {@link ByteBuffer#limit(int)} but in a manner which is
     * safe when compiled on Java 9 or newer but used on Java 8 or older.
     */
    public static ByteBuffer limit(final ByteBuffer buff, final int newLimit) {
        ((Buffer) buff).limit(newLimit);
        return buff;
    }

    /**
     * Takes a Base64-encoded String, decodes it, and returns contents as a byte array.
     *
     * @param encoded Base64 encoded String
     * @return decoded data as a byte array
     */
    public static byte[] decodeBase64String(final String encoded) {
        return Base64.decode(encoded);
    }

    /**
     * Takes data in a byte array, encodes them in Base64, and returns the result as a String.
     *
     * @param data The data to encode.
     * @return Base64 string that encodes the {@code data}.
     */
    public static String encodeBase64String(final byte[] data) {
        return Base64.toBase64String(data);
    }

    /**
     * Converts a non-negative BigInteger to a big-endian, unsigned byte array of exactly {@code length} bytes.
     * Removes the leading zero sign byte from {@link BigInteger#toByteArray()} (if present) and left pads with
     * zeroes. Zero yields an all-zero array.
     *
     * @param bigInteger The non-negative BigInteger to convert to a byte array
     * @param length The length of the byte array
     * @return The byte array
     * @throws IllegalArgumentException if {@code bigInteger} is negative, or if its magnitude does not fit
     *         in {@code length} bytes
     */
    public static byte[] bigIntegerToByteArray(final BigInteger bigInteger, final int length) {
        if (bigInteger.signum() < 0) {
            // Zero-padding a two's-complement negative encoding would silently produce a different value.
            throw cannotBeNegative("bigInteger");
        }
        final byte[] result = new byte[length];
        byte[] rawBytes = bigInteger.toByteArray();
        // Remove sign byte if one is present
        if (rawBytes[0] == 0) {
            rawBytes = Arrays.copyOfRange(rawBytes, 1, rawBytes.length);
        }
        if (length < rawBytes.length) {
            throw new IllegalArgumentException("Length must be as least as long as the BigInteger byte array " +
                    "without the sign byte");
        }
        System.arraycopy(rawBytes, 0, result, length - rawBytes.length, rawBytes.length);
        return result;
    }
}