diff --git a/src/main/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyring.java b/src/main/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyring.java index 9b8a7b453..0ca7923e5 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyring.java +++ b/src/main/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyring.java @@ -14,8 +14,10 @@ package com.amazonaws.encryptionsdk.keyrings; import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.internal.JceKeyCipher; import com.amazonaws.encryptionsdk.keyrings.RawRsaKeyringBuilder.RsaPaddingScheme; +import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import java.security.PrivateKey; import java.security.PublicKey; @@ -29,8 +31,20 @@ */ class RawRsaKeyring extends RawKeyring { + private final boolean validToEncrypt; + RawRsaKeyring(String keyNamespace, String keyName, PublicKey publicKey, PrivateKey privateKey, RsaPaddingScheme rsaPaddingScheme) { super(keyNamespace, keyName, JceKeyCipher.rsa(publicKey, privateKey, rsaPaddingScheme.getTransformation())); + validToEncrypt = publicKey != null; + } + + @Override + public EncryptionMaterials onEncrypt(EncryptionMaterials encryptionMaterials) { + if(!validToEncrypt) { + throw new AwsCryptoException("A public key is required to encrypt"); + } + + return super.onEncrypt(encryptionMaterials); } @Override diff --git a/src/test/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyringTest.java b/src/test/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyringTest.java index 59fff4da1..b79d88dfe 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyringTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/keyrings/RawRsaKeyringTest.java @@ -14,6 +14,7 @@ package com.amazonaws.encryptionsdk.keyrings; import com.amazonaws.encryptionsdk.EncryptedDataKey; +import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.keyrings.RawRsaKeyringBuilder.RsaPaddingScheme; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; @@ -34,6 +35,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; class RawRsaKeyringTest { @@ -134,4 +136,47 @@ void testEncryptDecryptGenerateDataKey() { assertTrue(decryptionMaterials.getKeyringTrace().getEntries().get(0).getFlags().contains(KeyringTraceFlag.DECRYPTED_DATA_KEY)); } + @Test + void testEncryptWithNoPublicKey() throws Exception { + final KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + final KeyPair keyPair = keyPairGenerator.generateKeyPair(); + + Keyring noPublicKey = new RawRsaKeyring(KEYNAMESPACE, KEYNAME, null, keyPair.getPrivate(), PADDING_SCHEME); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder() + .setAlgorithm(ALGORITHM) + .setCleartextDataKey(DATA_KEY) + .setEncryptionContext(ENCRYPTION_CONTEXT) + .build(); + + assertThrows(AwsCryptoException.class, () -> noPublicKey.onEncrypt(encryptionMaterials)); + } + + @Test + void testDecryptWithNoPrivateKey() throws Exception { + final KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + final KeyPair keyPair = keyPairGenerator.generateKeyPair(); + + Keyring noPrivateKey = new RawRsaKeyring(KEYNAMESPACE, KEYNAME, keyPair.getPublic(), null, PADDING_SCHEME); + + EncryptionMaterials encryptionMaterials = EncryptionMaterials.newBuilder() + .setAlgorithm(ALGORITHM) + .setCleartextDataKey(DATA_KEY) + .setEncryptionContext(ENCRYPTION_CONTEXT) + .build(); + + encryptionMaterials = noPrivateKey.onEncrypt(encryptionMaterials); + + DecryptionMaterials decryptionMaterials = DecryptionMaterials.newBuilder() + .setAlgorithm(ALGORITHM) + .setEncryptionContext(ENCRYPTION_CONTEXT) + .build(); + + DecryptionMaterials resultDecryptionMaterials = noPrivateKey.onDecrypt(decryptionMaterials, encryptionMaterials.getEncryptedDataKeys()); + + assertEquals(decryptionMaterials, resultDecryptionMaterials); + } + }