Skip to content

Commit 7d5ab6e

Browse files
Merge pull request #51 from bdonlan/bare-region
Fix bare aliases not using default region
2 parents 5af4b07 + 39a711e commit 7d5ab6e

File tree

3 files changed

+68
-11
lines changed

3 files changed

+68
-11
lines changed

src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java

+5
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,11 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro
450450
}
451451

452452
String regionName = parseRegionfromKeyArn(keyId);
453+
454+
if (regionName == null && defaultRegion_ != null) {
455+
regionName = defaultRegion_;
456+
}
457+
453458
AWSKMS kms = regionalClientSupplier_.getClient(regionName);
454459
if (kms == null) {
455460
throw new AwsCryptoException("Can't use keys from region " + regionName);

src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java

+47
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider;
44
import static com.amazonaws.regions.Region.getRegion;
5+
import static com.amazonaws.regions.Regions.DEFAULT_REGION;
56
import static com.amazonaws.regions.Regions.fromName;
67
import static java.util.Collections.singletonList;
78
import static org.junit.Assert.assertEquals;
89
import static org.junit.Assert.assertFalse;
910
import static org.junit.Assert.assertTrue;
1011
import static org.mockito.ArgumentMatchers.any;
12+
import static org.mockito.ArgumentMatchers.notNull;
1113
import static org.mockito.Mockito.atLeastOnce;
1214
import static org.mockito.Mockito.mock;
1315
import static org.mockito.Mockito.spy;
@@ -31,11 +33,56 @@
3133
import com.amazonaws.encryptionsdk.kms.KmsMasterKey;
3234
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
3335
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier;
36+
import com.amazonaws.regions.Region;
37+
import com.amazonaws.regions.Regions;
38+
import com.amazonaws.services.kms.model.CreateAliasRequest;
3439
import com.amazonaws.services.kms.model.DecryptRequest;
3540
import com.amazonaws.services.kms.model.EncryptRequest;
3641
import com.amazonaws.services.kms.model.GenerateDataKeyRequest;
3742

3843
public class KMSProviderBuilderMockTests {
44+
@Test
45+
public void testBareAliasMapping() {
46+
MockKMSClient client = spy(new MockKMSClient());
47+
48+
RegionalClientSupplier supplier = mock(RegionalClientSupplier.class);
49+
when(supplier.getClient(notNull())).thenReturn(client);
50+
51+
String key1 = client.createKey().getKeyMetadata().getKeyId();
52+
client.createAlias(new CreateAliasRequest()
53+
.withAliasName("foo")
54+
.withTargetKeyId(key1)
55+
);
56+
57+
KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder()
58+
.withKeysForEncryption("alias/foo")
59+
.withCustomClientFactory(supplier)
60+
.withDefaultRegion("us-west-2")
61+
.build();
62+
63+
new AwsCrypto().encryptData(mkp0, new byte[0]);
64+
}
65+
66+
@Test
67+
public void testBareAliasMapping_withLegacyCtor() {
68+
MockKMSClient client = spy(new MockKMSClient());
69+
70+
RegionalClientSupplier supplier = mock(RegionalClientSupplier.class);
71+
when(supplier.getClient(any())).thenReturn(client);
72+
73+
String key1 = client.createKey().getKeyMetadata().getKeyId();
74+
client.createAlias(new CreateAliasRequest()
75+
.withAliasName("foo")
76+
.withTargetKeyId(key1)
77+
);
78+
79+
KmsMasterKeyProvider mkp0 = new KmsMasterKeyProvider(
80+
client, Region.getRegion(Regions.DEFAULT_REGION), Arrays.asList("alias/foo")
81+
);
82+
83+
new AwsCrypto().encryptData(mkp0, new byte[0]);
84+
}
85+
3986
@Test
4087
public void testGrantTokenPassthrough_usingMKsetCall() throws Exception {
4188
MockKMSClient client = spy(new MockKMSClient());

src/test/java/com/amazonaws/services/kms/MockKMSClient.java

+16-11
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,20 @@ public class MockKMSClient extends AWSKMSClient {
8888
private static final SecureRandom rnd = new SecureRandom();
8989
private static final String ACCOUNT_ID = "01234567890";
9090
private final Map<DecryptMapKey, DecryptResult> results_ = new HashMap<>();
91-
private final Map<String, String> idToArnMap = new HashMap<>();
9291
private final Set<String> activeKeys = new HashSet<>();
92+
private final Map<String, String> keyAliases = new HashMap<>();
9393
private Region region_ = Region.getRegion(Regions.DEFAULT_REGION);
9494

9595
@Override
9696
public CreateAliasResult createAlias(CreateAliasRequest arg0) throws AmazonServiceException, AmazonClientException {
97-
throw new java.lang.UnsupportedOperationException();
97+
assertExists(arg0.getTargetKeyId());
98+
99+
keyAliases.put(
100+
"alias/" + arg0.getAliasName(),
101+
keyAliases.get(arg0.getTargetKeyId())
102+
);
103+
104+
return new CreateAliasResult();
98105
}
99106

100107
@Override
@@ -111,8 +118,9 @@ public CreateKeyResult createKey() throws AmazonServiceException, AmazonClientEx
111118
public CreateKeyResult createKey(CreateKeyRequest req) throws AmazonServiceException, AmazonClientException {
112119
String keyId = UUID.randomUUID().toString();
113120
String arn = "arn:aws:kms:" + region_.getName() + ":" + ACCOUNT_ID + ":key/" + keyId;
114-
idToArnMap.put(keyId, arn);
115121
activeKeys.add(arn);
122+
keyAliases.put(keyId, arn);
123+
keyAliases.put(arn, arn);
116124
CreateKeyResult result = new CreateKeyResult();
117125
result.setKeyMetadata(new KeyMetadata().withAWSAccountId(ACCOUNT_ID).withCreationDate(new Date())
118126
.withDescription(req.getDescription()).withEnabled(true).withKeyId(keyId)
@@ -183,7 +191,7 @@ private EncryptResult encrypt0(EncryptRequest req) throws AmazonServiceException
183191
final byte[] cipherText = new byte[512];
184192
rnd.nextBytes(cipherText);
185193
DecryptResult dec = new DecryptResult();
186-
dec.withKeyId(req.getKeyId()).withPlaintext(req.getPlaintext().asReadOnlyBuffer());
194+
dec.withKeyId(retrieveArn(req.getKeyId())).withPlaintext(req.getPlaintext().asReadOnlyBuffer());
187195
ByteBuffer ctBuff = ByteBuffer.wrap(cipherText);
188196

189197
results_.put(new DecryptMapKey(ctBuff, req.getEncryptionContext()), dec);
@@ -336,20 +344,17 @@ public void deleteKey(final String keyId) {
336344
}
337345

338346
private String retrieveArn(final String keyId) {
339-
String arn = keyId;
340-
if (keyId.contains("arn:") == false) {
341-
arn = idToArnMap.get(keyId);
342-
}
347+
String arn = keyAliases.get(keyId);
343348
assertExists(arn);
344349
return arn;
345350
}
346351

347352
private void assertExists(String keyId) {
348-
if (idToArnMap.containsKey(keyId)) {
349-
keyId = idToArnMap.get(keyId);
353+
if (keyAliases.containsKey(keyId)) {
354+
keyId = keyAliases.get(keyId);
350355
}
351356
if (keyId == null || !activeKeys.contains(keyId)) {
352-
throw new NotFoundException("Key doesn't exist");
357+
throw new NotFoundException("Key doesn't exist: " + keyId);
353358
}
354359
}
355360

0 commit comments

Comments
 (0)