-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathUtils.java
326 lines (283 loc) · 11.2 KB
/
Utils.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package com.amazonaws.encryptionsdk.internal;
import java.io.Serializable;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.math.BigInteger;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.ArrayUtils;
import org.bouncycastle.util.encoders.Base64;
/**
 * Internal utility methods.
 *
 * <p>This class cannot be instantiated; every member is static.
 */
public final class Utils {
  // SecureRandom objects can both be expensive to initialize and incur synchronization costs.
  // This allows us to minimize both initializations and keep SecureRandom usage thread local
  // to avoid lock contention.
  private static final ThreadLocal<SecureRandom> LOCAL_RANDOM =
      ThreadLocal.withInitial(
          () -> {
            final SecureRandom rnd = new SecureRandom();
            rnd.nextBoolean(); // Force seeding
            return rnd;
          });

  private Utils() {
    // Prevent instantiation
  }

  /*
   * In some areas we need to be able to assign a total order over Java objects - generally with some primary sort,
   * but we need a fallback sort that always works in order to ensure that we don't falsely claim objects A and B
   * are equal just because the primary sort declares them to have equal rank.
   *
   * To do this, we define a fallback sort that assigns an arbitrary order to all objects. This order is
   * implemented by first comparing identity hashcodes. In the rare case where we are asked to compare two
   * distinct objects with equal identity hashcodes, we explicitly assign each an increasing index and sort
   * on that index.
   *
   * The index table must be keyed by object IDENTITY (==). A WeakHashMap is not suitable because it compares
   * keys with equals()/hashCode(): two distinct-but-equal objects would share an index, and keys with a mutable
   * hashCode() could lose their entry and be re-indexed. Instead we keep weak references bucketed by identity
   * hashcode and compare referents with ==. Entries are dropped once their referent is garbage collected.
   */
  private static final AtomicLong FALLBACK_COUNTER = new AtomicLong(0);

  /** Receives {@link FallbackIdRef}s cleared by the garbage collector, for later cleanup. */
  private static final ReferenceQueue<Object> FALLBACK_QUEUE = new ReferenceQueue<>();

  /**
   * Fallback IDs bucketed by {@link System#identityHashCode}. Only accessed from static
   * synchronized methods (guarded by the {@code Utils.class} monitor).
   */
  private static final Map<Integer, List<FallbackIdRef>> FALLBACK_IDS = new HashMap<>();

  /** Weak, identity-based association between an object and its assigned fallback ID. */
  private static final class FallbackIdRef extends WeakReference<Object> {
    // Cached so the bucket can still be located after the referent has been cleared.
    private final int identityHash;
    private final long id;

    FallbackIdRef(final Object referent, final long id) {
      super(referent, FALLBACK_QUEUE);
      this.identityHash = System.identityHashCode(referent);
      this.id = id;
    }
  }

  /**
   * Returns the fallback ID for {@code object}, assigning a new one if it has none. Distinct
   * objects (by {@code ==}) always receive distinct IDs, regardless of their {@code equals}.
   *
   * @param object a non-null object
   * @return the object's fallback ID, stable for as long as the object remains reachable
   */
  private static synchronized long getFallbackObjectId(Object object) {
    expungeStaleFallbackIds();
    final int identityHash = System.identityHashCode(object);
    final List<FallbackIdRef> bucket =
        FALLBACK_IDS.computeIfAbsent(identityHash, ignored -> new ArrayList<>(1));
    for (final FallbackIdRef ref : bucket) {
      // Reference comparison on purpose: never consult the object's equals()/hashCode().
      if (ref.get() == object) {
        return ref.id;
      }
    }
    final long id = FALLBACK_COUNTER.incrementAndGet();
    bucket.add(new FallbackIdRef(object, id));
    return id;
  }

  /**
   * Removes entries whose referents have been garbage collected. Must only be called while
   * holding the {@code Utils.class} monitor.
   */
  private static void expungeStaleFallbackIds() {
    for (Object polled = FALLBACK_QUEUE.poll(); polled != null; polled = FALLBACK_QUEUE.poll()) {
      final FallbackIdRef stale = (FallbackIdRef) polled;
      final List<FallbackIdRef> bucket = FALLBACK_IDS.get(stale.identityHash);
      if (bucket != null) {
        // WeakReference does not override equals, so this removes exactly this entry.
        bucket.remove(stale);
        if (bucket.isEmpty()) {
          FALLBACK_IDS.remove(stale.identityHash);
        }
      }
    }
  }

  /**
   * Provides an <i>arbitrary</i> but consistent total ordering over all objects. This comparison
   * function will return 0 if and only if a == b, and otherwise will return arbitrarily either -1
   * or 1, but will do so in a way that results in a consistent total order.
   *
   * <p>{@code null} sorts before every non-null object. The objects' own {@code equals} and
   * {@code hashCode} methods are never called.
   *
   * @param a the first object; may be {@code null}
   * @param b the second object; may be {@code null}
   * @return -1 or 1 (consistently) if a != b; 0 if a == b.
   */
  public static int compareObjectIdentity(Object a, Object b) {
    if (a == b) {
      return 0;
    }
    if (a == null) {
      return -1;
    }
    if (b == null) {
      return 1;
    }
    int hashCompare = Integer.compare(System.identityHashCode(a), System.identityHashCode(b));
    if (hashCompare != 0) {
      return hashCompare;
    }
    // Unfortunately these objects have identical identity hashcodes, so we need to find some
    // other way to compare them. We map each one to an incrementing counter (keyed by identity)
    // and compare the assigned IDs instead.
    int fallbackCompare = Long.compare(getFallbackObjectId(a), getFallbackObjectId(b));
    if (fallbackCompare == 0) {
      throw new AssertionError("Failed to assign unique order to objects");
    }
    return fallbackCompare;
  }

  /**
   * Adds two longs, clamping instead of wrapping on overflow.
   *
   * @param a the first addend
   * @param b the second addend
   * @return {@code a + b}, or {@link Long#MAX_VALUE} on positive overflow, or {@link
   *     Long#MIN_VALUE} on negative overflow
   */
  public static long saturatingAdd(long a, long b) {
    long r = a + b;
    if (a > 0 && b > 0 && r < a) {
      return Long.MAX_VALUE;
    }
    if (a < 0 && b < 0 && r > a) {
      return Long.MIN_VALUE;
    }
    // If the signs between a and b differ, overflow is impossible.
    return r;
  }

  /**
   * Comparator that performs a lexicographical comparison of byte arrays, treating them as
   * unsigned.
   */
  public static class ComparingByteArrays implements Comparator<byte[]>, Serializable {
    // We don't really need to be serializable, but it doesn't hurt, and FindBugs gets annoyed if
    // we're not.
    private static final long serialVersionUID = 0xdf641037ffe509e2L;

    // ComparingByteBuffers is stateless, so a single shared instance suffices.
    private static final ComparingByteBuffers BUFFER_COMPARATOR = new ComparingByteBuffers();

    @Override
    public int compare(byte[] o1, byte[] o2) {
      return BUFFER_COMPARATOR.compare(ByteBuffer.wrap(o1), ByteBuffer.wrap(o2));
    }
  }

  /**
   * Comparator that lexicographically compares the remaining bytes (position to limit) of two
   * buffers, treating bytes as unsigned. A strict prefix sorts before the longer buffer. The
   * positions and limits of the argument buffers are not modified.
   */
  public static class ComparingByteBuffers implements Comparator<ByteBuffer>, Serializable {
    private static final long serialVersionUID = 0xa3c4a7300fbbf043L;

    @Override
    public int compare(ByteBuffer o1, ByteBuffer o2) {
      // Slices index from the callers' current positions without touching their state.
      final ByteBuffer left = o1.slice();
      final ByteBuffer right = o2.slice();
      final int commonLength = Math.min(left.remaining(), right.remaining());
      for (int i = 0; i < commonLength; i++) {
        // Perform zero-extension as we want to treat the bytes as unsigned
        final int v1 = left.get(i) & 0xFF;
        final int v2 = right.get(i) & 0xFF;
        if (v1 != v2) {
          return v1 - v2;
        }
      }
      // The longer buffer is bigger (0x00 comes after end-of-buffer)
      return left.remaining() - right.remaining();
    }
  }

  /**
   * Throws {@link NullPointerException} with message {@code paramName} if {@code object} is null.
   *
   * @param object value to be null-checked
   * @param paramName message for the potential {@link NullPointerException}
   * @return {@code object}
   * @throws NullPointerException if {@code object} is null
   */
  public static <T> T assertNonNull(final T object, final String paramName)
      throws NullPointerException {
    if (object == null) {
      throw new NullPointerException(paramName + " must not be null");
    }
    return object;
  }

  /**
   * Returns a possibly truncated version of {@code arr} which is guaranteed to be exactly {@code
   * len} elements long. If {@code arr} is already exactly {@code len} elements long, then {@code
   * arr} is returned without copy or modification. If {@code arr} is longer than {@code len}, then
   * a truncated copy is returned. If {@code arr} is shorter than {@code len} then this throws an
   * {@link IllegalArgumentException}.
   *
   * @param arr the array to truncate
   * @param len the required length
   * @return {@code arr} itself, or a truncated copy of it
   * @throws IllegalArgumentException if {@code arr.length < len}
   */
  public static byte[] truncate(final byte[] arr, final int len) throws IllegalArgumentException {
    if (arr.length == len) {
      return arr;
    } else if (arr.length > len) {
      return Arrays.copyOf(arr, len);
    } else {
      throw new IllegalArgumentException("arr is not at least " + len + " elements long");
    }
  }

  /**
   * Returns the calling thread's {@link SecureRandom}. It is created and seeded on first use in
   * each thread.
   *
   * @return a thread-confined {@link SecureRandom}
   */
  public static SecureRandom getSecureRandom() {
    return LOCAL_RANDOM.get();
  }

  /**
   * Generate the AAD bytes to use when encrypting/decrypting content. The generated AAD is a block
   * of bytes containing the provided message identifier, the string identifier, the sequence
   * number, and the length of the content.
   *
   * @param messageId the unique message identifier for the ciphertext.
   * @param idString the string describing the type of content processed.
   * @param seqNum the sequence number.
   * @param len the length of the content.
   * @return the bytes containing the generated AAD.
   */
  static byte[] generateContentAad(
      final byte[] messageId, final String idString, final int seqNum, final long len) {
    final byte[] idBytes = idString.getBytes(StandardCharsets.UTF_8);
    final int aadLen =
        messageId.length + idBytes.length + Integer.SIZE / Byte.SIZE + Long.SIZE / Byte.SIZE;
    final ByteBuffer aad = ByteBuffer.allocate(aadLen);
    aad.put(messageId);
    aad.put(idBytes);
    aad.putInt(seqNum);
    aad.putLong(len);
    return aad.array();
  }

  /**
   * Builds (but does not throw) the exception used when {@code field} is negative.
   *
   * @param field the name of the offending field
   * @return an {@link IllegalArgumentException} describing the problem
   */
  static IllegalArgumentException cannotBeNegative(String field) {
    return new IllegalArgumentException(field + " cannot be negative");
  }

  /**
   * Equivalent to calling {@link ByteBuffer#flip()} but in a manner which is safe when compiled on
   * Java 9 or newer but used on Java 8 or older.
   */
  public static ByteBuffer flip(final ByteBuffer buff) {
    ((Buffer) buff).flip();
    return buff;
  }

  /**
   * Equivalent to calling {@link ByteBuffer#clear()} but in a manner which is safe when compiled on
   * Java 9 or newer but used on Java 8 or older.
   */
  public static ByteBuffer clear(final ByteBuffer buff) {
    ((Buffer) buff).clear();
    return buff;
  }

  /**
   * Equivalent to calling {@link ByteBuffer#position(int)} but in a manner which is safe when
   * compiled on Java 9 or newer but used on Java 8 or older.
   */
  public static ByteBuffer position(final ByteBuffer buff, final int newPosition) {
    ((Buffer) buff).position(newPosition);
    return buff;
  }

  /**
   * Equivalent to calling {@link ByteBuffer#limit(int)} but in a manner which is safe when compiled
   * on Java 9 or newer but used on Java 8 or older.
   */
  public static ByteBuffer limit(final ByteBuffer buff, final int newLimit) {
    ((Buffer) buff).limit(newLimit);
    return buff;
  }

  /**
   * Takes a Base64-encoded String, decodes it, and returns contents as a byte array.
   *
   * @param encoded Base64 encoded String
   * @return decoded data as a byte array (a shared empty array for empty input)
   */
  public static byte[] decodeBase64String(final String encoded) {
    return encoded.isEmpty() ? ArrayUtils.EMPTY_BYTE_ARRAY : Base64.decode(encoded);
  }

  /**
   * Takes data in a byte array, encodes them in Base64, and returns the result as a String.
   *
   * @param data The data to encode.
   * @return Base64 string that encodes the {@code data}.
   */
  public static String encodeBase64String(final byte[] data) {
    return Base64.toBase64String(data);
  }

  /**
   * Removes the leading zero sign byte from the byte array representation of a BigInteger (if
   * present) and left pads with zeroes to produce a byte array of the given length.
   *
   * <p>NOTE(review): padding uses zero bytes, not sign extension, so this appears intended for
   * non-negative values only — confirm with callers.
   *
   * @param bigInteger The BigInteger to convert to a byte array
   * @param length The length of the byte array, must be at least as long as the BigInteger byte
   *     array without the sign byte
   * @return The byte array
   * @throws IllegalArgumentException if the value does not fit in {@code length} bytes
   */
  public static byte[] bigIntegerToByteArray(final BigInteger bigInteger, final int length) {
    byte[] rawBytes = bigInteger.toByteArray();
    // If rawBytes is already the correct length, return it.
    if (rawBytes.length == length) {
      return rawBytes;
    }
    // If we're exactly one byte too large, but we have a leading zero byte, remove it and return.
    if (rawBytes.length == length + 1 && rawBytes[0] == 0) {
      return Arrays.copyOfRange(rawBytes, 1, rawBytes.length);
    }
    if (rawBytes.length > length) {
      throw new IllegalArgumentException(
          "Length must be at least as long as the BigInteger byte array "
              + "without the sign byte");
    }
    final byte[] paddedResult = new byte[length];
    System.arraycopy(rawBytes, 0, paddedResult, length - rawBytes.length, rawBytes.length);
    return paddedResult;
  }

  /**
   * Returns true if the prefix of the given length for the input arrays are equal. This method will
   * return as soon as the first difference is found, and is thus not constant-time.
   *
   * @param a The first array.
   * @param b The second array.
   * @param length The length of the prefix to compare.
   * @return True if the prefixes are equal, false otherwise (including when either array is
   *     {@code null} or shorter than {@code length}).
   */
  public static boolean arrayPrefixEquals(final byte[] a, final byte[] b, final int length) {
    if (a == null || b == null || a.length < length || b.length < length) {
      return false;
    }
    for (int x = 0; x < length; x++) {
      if (a[x] != b[x]) {
        return false;
      }
    }
    return true;
  }
}