From 882dd03b2f38581e250b856a6ce4fcce926d76b0 Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Mon, 5 Aug 2019 18:28:38 -0700 Subject: [PATCH 1/6] Fixed linking issue with algorithm in Sphinx. Black also made some small changes so formatting may be slightly different, but it should be the same functionality wise. I have a couple errors that I believe are not from my own changes (they were there on master when I pulled and pertain to there not being valid AWS credentials for tests for the most part), but they should not be important. --- src/aws_encryption_sdk/__init__.py | 26 +- src/aws_encryption_sdk/identifiers.py | 90 ++++++- .../test_f_aws_encryption_sdk_client.py | 249 +++++++++++++----- test/functional/test_f_crypto.py | 39 ++- test/functional/test_f_xcompat.py | 50 +++- .../test_i_aws_encrytion_sdk_client.py | 172 ++++++++---- test/integration/test_i_xcompat_kms.py | 8 +- tox.ini | 2 +- 8 files changed, 481 insertions(+), 155 deletions(-) diff --git a/src/aws_encryption_sdk/__init__.py b/src/aws_encryption_sdk/__init__.py index 3f6d86e2e..60fa8c570 100644 --- a/src/aws_encryption_sdk/__init__.py +++ b/src/aws_encryption_sdk/__init__.py @@ -14,10 +14,17 @@ # Below are imported for ease of use by implementors from aws_encryption_sdk.caches.local import LocalCryptoMaterialsCache # noqa from aws_encryption_sdk.caches.null import NullCryptoMaterialsCache # noqa -from aws_encryption_sdk.identifiers import Algorithm, __version__ # noqa -from aws_encryption_sdk.key_providers.kms import KMSMasterKeyProvider, KMSMasterKeyProviderConfig # noqa -from aws_encryption_sdk.materials_managers.caching import CachingCryptoMaterialsManager # noqa -from aws_encryption_sdk.materials_managers.default import DefaultCryptoMaterialsManager # noqa +from aws_encryption_sdk.identifiers import AlgorithmSuite, __version__ # noqa +from aws_encryption_sdk.key_providers.kms import ( + KMSMasterKeyProvider, + KMSMasterKeyProviderConfig, +) # noqa +from aws_encryption_sdk.materials_managers.caching import ( + CachingCryptoMaterialsManager, +) # noqa +from aws_encryption_sdk.materials_managers.default import ( + DefaultCryptoMaterialsManager, +) # noqa from aws_encryption_sdk.streaming_client import ( # noqa DecryptorConfig, EncryptorConfig, @@ -69,8 +76,8 @@ def encrypt(**kwargs): this is not enforced if a `key_provider` is provided. :param dict encryption_context: Dictionary defining encryption context - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: AlgorithmSuite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int frame_length: Frame length in bytes :returns: Tuple containing the encrypted ciphertext and the message header object :rtype: tuple of bytes and :class:`aws_encryption_sdk.structures.MessageHeader` @@ -177,7 +184,12 @@ def stream(**kwargs): :raises ValueError: if supplied with an unsupported mode value """ mode = kwargs.pop("mode") - _stream_map = {"e": StreamEncryptor, "encrypt": StreamEncryptor, "d": StreamDecryptor, "decrypt": StreamDecryptor} + _stream_map = { + "e": StreamEncryptor, + "encrypt": StreamEncryptor, + "d": StreamDecryptor, + "decrypt": StreamDecryptor, + } try: return _stream_map[mode.lower()](**kwargs) except KeyError: diff --git a/src/aws_encryption_sdk/identifiers.py b/src/aws_encryption_sdk/identifiers.py index 1bd9bb1f1..663884227 100644 --- a/src/aws_encryption_sdk/identifiers.py +++ b/src/aws_encryption_sdk/identifiers.py @@ -50,7 +50,15 @@ class EncryptionSuite(Enum): AES_192_GCM_IV12_TAG16 = (algorithms.AES, modes.GCM, 24, 12, 16) AES_256_GCM_IV12_TAG16 = (algorithms.AES, modes.GCM, 32, 12, 16) - def __init__(self, algorithm, mode, data_key_length, iv_length, auth_length, auth_key_length=0): + def __init__( + self, + algorithm, + mode, + data_key_length, + iv_length, + auth_length, + auth_key_length=0, + ): """Prepare a new EncryptionSuite.""" self.algorithm = algorithm self.mode = mode @@ -157,9 +165,21 @@ class AlgorithmSuite(Enum): # pylint: disable=too-many-instance-attributes AES_128_GCM_IV12_TAG16 = (0x0014, EncryptionSuite.AES_128_GCM_IV12_TAG16) AES_192_GCM_IV12_TAG16 = (0x0046, EncryptionSuite.AES_192_GCM_IV12_TAG16) AES_256_GCM_IV12_TAG16 = (0x0078, EncryptionSuite.AES_256_GCM_IV12_TAG16) - AES_128_GCM_IV12_TAG16_HKDF_SHA256 = (0x0114, EncryptionSuite.AES_128_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) - AES_192_GCM_IV12_TAG16_HKDF_SHA256 = (0x0146, EncryptionSuite.AES_192_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) - AES_256_GCM_IV12_TAG16_HKDF_SHA256 = (0x0178, EncryptionSuite.AES_256_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) + AES_128_GCM_IV12_TAG16_HKDF_SHA256 = ( + 0x0114, + EncryptionSuite.AES_128_GCM_IV12_TAG16, + KDFSuite.HKDF_SHA256, + ) + AES_192_GCM_IV12_TAG16_HKDF_SHA256 = ( + 0x0146, + EncryptionSuite.AES_192_GCM_IV12_TAG16, + KDFSuite.HKDF_SHA256, + ) + AES_256_GCM_IV12_TAG16_HKDF_SHA256 = ( + 0x0178, + EncryptionSuite.AES_256_GCM_IV12_TAG16, + KDFSuite.HKDF_SHA256, + ) AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 = ( 0x0214, EncryptionSuite.AES_128_GCM_IV12_TAG16, @@ -240,6 +260,8 @@ def safe_to_cache(self): return self.kdf is not KDFSuite.NONE +# algorithm is just an alias for AlgorithmSuite ... but Sphinx does not recognize this fact +# so we need to go through and fix the references Algorithm = AlgorithmSuite @@ -271,16 +293,60 @@ class WrappingAlgorithm(Enum): :type padding_mgf: """ - AES_128_GCM_IV12_TAG16_NO_PADDING = (EncryptionType.SYMMETRIC, Algorithm.AES_128_GCM_IV12_TAG16, None, None, None) - AES_192_GCM_IV12_TAG16_NO_PADDING = (EncryptionType.SYMMETRIC, Algorithm.AES_192_GCM_IV12_TAG16, None, None, None) - AES_256_GCM_IV12_TAG16_NO_PADDING = (EncryptionType.SYMMETRIC, Algorithm.AES_256_GCM_IV12_TAG16, None, None, None) + AES_128_GCM_IV12_TAG16_NO_PADDING = ( + EncryptionType.SYMMETRIC, + AlgorithmSuite.AES_128_GCM_IV12_TAG16, + None, + None, + None, + ) + AES_192_GCM_IV12_TAG16_NO_PADDING = ( + EncryptionType.SYMMETRIC, + AlgorithmSuite.AES_192_GCM_IV12_TAG16, + None, + None, + None, + ) + AES_256_GCM_IV12_TAG16_NO_PADDING = ( + EncryptionType.SYMMETRIC, + AlgorithmSuite.AES_256_GCM_IV12_TAG16, + None, + None, + None, + ) RSA_PKCS1 = (EncryptionType.ASYMMETRIC, rsa, padding.PKCS1v15, None, None) - RSA_OAEP_SHA1_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA1, padding.MGF1) - RSA_OAEP_SHA256_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA256, padding.MGF1) - RSA_OAEP_SHA384_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA384, padding.MGF1) - RSA_OAEP_SHA512_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA512, padding.MGF1) + RSA_OAEP_SHA1_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA1, + padding.MGF1, + ) + RSA_OAEP_SHA256_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA256, + padding.MGF1, + ) + RSA_OAEP_SHA384_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA384, + padding.MGF1, + ) + RSA_OAEP_SHA512_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA512, + padding.MGF1, + ) - def __init__(self, encryption_type, algorithm, padding_type, padding_algorithm, padding_mgf): + def __init__( + self, encryption_type, algorithm, padding_type, padding_algorithm, padding_mgf + ): """Prepares new WrappingAlgorithm.""" self.encryption_type = encryption_type self.algorithm = algorithm diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index fb19e868a..b0a367253 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -28,15 +28,27 @@ import aws_encryption_sdk from aws_encryption_sdk import KMSMasterKeyProvider -from aws_encryption_sdk.caches import build_decryption_materials_cache_key, build_encryption_materials_cache_key +from aws_encryption_sdk.caches import ( + build_decryption_materials_cache_key, + build_encryption_materials_cache_key, +) from aws_encryption_sdk.exceptions import CustomMaximumValueExceeded -from aws_encryption_sdk.identifiers import Algorithm, EncryptionKeyType, WrappingAlgorithm +from aws_encryption_sdk.identifiers import ( + AlgorithmSuite, + EncryptionKeyType, + WrappingAlgorithm, +) from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey from aws_encryption_sdk.internal.defaults import LINE_LENGTH -from aws_encryption_sdk.internal.formatting.encryption_context import serialize_encryption_context +from aws_encryption_sdk.internal.formatting.encryption_context import ( + serialize_encryption_context, +) from aws_encryption_sdk.key_providers.base import MasterKeyProviderConfig from aws_encryption_sdk.key_providers.raw import RawMasterKeyProvider -from aws_encryption_sdk.materials_managers import DecryptionMaterialsRequest, EncryptionMaterialsRequest +from aws_encryption_sdk.materials_managers import ( + DecryptionMaterialsRequest, + EncryptionMaterialsRequest, +) pytestmark = [pytest.mark.functional, pytest.mark.local] @@ -179,7 +191,9 @@ class FakeRawMasterKeyProvider(RawMasterKeyProvider): def _get_raw_key(self, key_id): wrapping_key = VALUES["raw"][key_id][self.config.encryption_key_type] if key_id == b"sym1": - wrapping_key = wrapping_key[: self.config.wrapping_algorithm.algorithm.data_key_len] + wrapping_key = wrapping_key[ + : self.config.wrapping_algorithm.algorithm.data_key_len + ] return WrappingKey( wrapping_algorithm=self.config.wrapping_algorithm, wrapping_key=wrapping_key, @@ -189,12 +203,16 @@ def _get_raw_key(self, key_id): def _mgf1_sha256_supported(): wk = serialization.load_pem_private_key( - data=VALUES["raw"][b"asym1"][EncryptionKeyType.PRIVATE], password=None, backend=default_backend() + data=VALUES["raw"][b"asym1"][EncryptionKeyType.PRIVATE], + password=None, + backend=default_backend(), ) try: wk.public_key().encrypt( plaintext=b"aosdjfoiajfoiaj;foijae;rogijaerg", - padding=padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None), + padding=padding.OAEP( + mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None + ), ) except cryptography.exceptions.UnsupportedAlgorithm: return False @@ -255,13 +273,18 @@ def test_no_infinite_encryption_cycle_on_empty_source(): def test_encrypt_load_header(): """Test that StreamEncryptor can extract header without reading plaintext.""" # Using a non-signed algorithm to simplify header size calculation - algorithm = aws_encryption_sdk.Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA256 + algorithm = aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA256 key_provider = fake_kms_key_provider(algorithm.kdf_input_len) header_length = len(serialize_encryption_context(VALUES["encryption_context"])) header_length += 34 header_length += algorithm.iv_len header_length += algorithm.auth_len - header_length += 6 + 7 + len(VALUES["arn"]) + len(VALUES["data_keys"][algorithm.kdf_input_len]["encrypted"]) + header_length += ( + 6 + + 7 + + len(VALUES["arn"]) + + len(VALUES["data_keys"][algorithm.kdf_input_len]["encrypted"]) + ) with aws_encryption_sdk.stream( mode="e", source=VALUES["plaintext_128"], @@ -283,11 +306,14 @@ def test_encrypt_decrypt_header_only(): key_provider=fake_kms_key_provider(), encryption_context=VALUES["encryption_context"], ) - with aws_encryption_sdk.stream(mode="d", source=ciphertext, key_provider=fake_kms_key_provider()) as decryptor: + with aws_encryption_sdk.stream( + mode="d", source=ciphertext, key_provider=fake_kms_key_provider() + ) as decryptor: decryptor_header = decryptor.header assert decryptor.output_buffer == b"" assert all( - pair in decryptor_header.encryption_context.items() for pair in encryptor_header.encryption_context.items() + pair in decryptor_header.encryption_context.items() + for pair in encryptor_header.encryption_context.items() ) @@ -296,7 +322,7 @@ def test_encrypt_decrypt_header_only(): [ [frame_length, algorithm_suite, encryption_context] for frame_length in VALUES["frame_lengths"] - for algorithm_suite in Algorithm + for algorithm_suite in AlgorithmSuite for encryption_context in [{}, VALUES["encryption_context"]] ], ) @@ -318,25 +344,53 @@ def test_encrypt_ciphertext_message(frame_length, algorithm, encryption_context) @pytest.mark.parametrize( "wrapping_algorithm, encryption_key_type, decryption_key_type", ( - (WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING, EncryptionKeyType.SYMMETRIC, EncryptionKeyType.SYMMETRIC), - (WrappingAlgorithm.RSA_PKCS1, EncryptionKeyType.PRIVATE, EncryptionKeyType.PRIVATE), - (WrappingAlgorithm.RSA_PKCS1, EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE), - (WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, EncryptionKeyType.PRIVATE, EncryptionKeyType.PRIVATE), - (WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE), + ( + WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING, + EncryptionKeyType.SYMMETRIC, + EncryptionKeyType.SYMMETRIC, + ), + ( + WrappingAlgorithm.RSA_PKCS1, + EncryptionKeyType.PRIVATE, + EncryptionKeyType.PRIVATE, + ), + ( + WrappingAlgorithm.RSA_PKCS1, + EncryptionKeyType.PUBLIC, + EncryptionKeyType.PRIVATE, + ), + ( + WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, + EncryptionKeyType.PRIVATE, + EncryptionKeyType.PRIVATE, + ), + ( + WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, + EncryptionKeyType.PUBLIC, + EncryptionKeyType.PRIVATE, + ), ), ) -def test_encryption_cycle_raw_mkp(caplog, wrapping_algorithm, encryption_key_type, decryption_key_type): +def test_encryption_cycle_raw_mkp( + caplog, wrapping_algorithm, encryption_key_type, decryption_key_type +): caplog.set_level(logging.DEBUG) - encrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, encryption_key_type) - decrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, decryption_key_type) + encrypting_key_provider = build_fake_raw_key_provider( + wrapping_algorithm, encryption_key_type + ) + decrypting_key_provider = build_fake_raw_key_provider( + wrapping_algorithm, decryption_key_type + ) ciphertext, _ = aws_encryption_sdk.encrypt( source=VALUES["plaintext_128"], key_provider=encrypting_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=decrypting_key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=decrypting_key_provider + ) assert plaintext == VALUES["plaintext_128"] for member in encrypting_key_provider._members: @@ -344,7 +398,8 @@ def test_encryption_cycle_raw_mkp(caplog, wrapping_algorithm, encryption_key_typ @pytest.mark.skipif( - not _mgf1_sha256_supported(), reason="MGF1-SHA2 not supported by this backend: OpenSSL required v1.0.2+" + not _mgf1_sha256_supported(), + reason="MGF1-SHA2 not supported by this backend: OpenSSL required v1.0.2+", ) @pytest.mark.parametrize( "wrapping_algorithm", @@ -354,18 +409,28 @@ def test_encryption_cycle_raw_mkp(caplog, wrapping_algorithm, encryption_key_typ WrappingAlgorithm.RSA_OAEP_SHA512_MGF1, ), ) -@pytest.mark.parametrize("encryption_key_type", (EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE)) -def test_encryption_cycle_raw_mkp_openssl_102_plus(wrapping_algorithm, encryption_key_type): +@pytest.mark.parametrize( + "encryption_key_type", (EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE) +) +def test_encryption_cycle_raw_mkp_openssl_102_plus( + wrapping_algorithm, encryption_key_type +): decryption_key_type = EncryptionKeyType.PRIVATE - encrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, encryption_key_type) - decrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, decryption_key_type) + encrypting_key_provider = build_fake_raw_key_provider( + wrapping_algorithm, encryption_key_type + ) + decrypting_key_provider = build_fake_raw_key_provider( + wrapping_algorithm, decryption_key_type + ) ciphertext, _ = aws_encryption_sdk.encrypt( source=VALUES["plaintext_128"], key_provider=encrypting_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=decrypting_key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=decrypting_key_provider + ) assert plaintext == VALUES["plaintext_128"] @@ -374,7 +439,7 @@ def test_encryption_cycle_raw_mkp_openssl_102_plus(wrapping_algorithm, encryptio [ [frame_length, algorithm_suite, encryption_context] for frame_length in VALUES["frame_lengths"] - for algorithm_suite in Algorithm + for algorithm_suite in AlgorithmSuite for encryption_context in [{}, VALUES["encryption_context"]] ], ) @@ -389,7 +454,9 @@ def test_encryption_cycle_oneshot_kms(frame_length, algorithm, encryption_contex encryption_context=encryption_context, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider + ) assert plaintext == VALUES["plaintext_128"] * 10 @@ -399,7 +466,7 @@ def test_encryption_cycle_oneshot_kms(frame_length, algorithm, encryption_contex [ [frame_length, algorithm_suite, encryption_context] for frame_length in VALUES["frame_lengths"] - for algorithm_suite in Algorithm + for algorithm_suite in AlgorithmSuite for encryption_context in [{}, VALUES["encryption_context"]] ], ) @@ -420,7 +487,9 @@ def test_encryption_cycle_stream_kms(frame_length, algorithm, encryption_context ciphertext = bytes(ciphertext) plaintext = bytearray() - with aws_encryption_sdk.stream(mode="d", source=io.BytesIO(ciphertext), key_provider=key_provider) as decryptor: + with aws_encryption_sdk.stream( + mode="d", source=io.BytesIO(ciphertext), key_provider=key_provider + ) as decryptor: for chunk in decryptor: plaintext.extend(chunk) plaintext = bytes(plaintext) @@ -433,7 +502,9 @@ def test_encryption_cycle_stream_kms(frame_length, algorithm, encryption_context def test_decrypt_legacy_provided_message(): """Tests backwards compatiblity against some legacy provided ciphertext.""" region = "us-west-2" - key_info = "arn:aws:kms:us-west-2:249645522726:key/d1720f4e-953b-44bb-b9dd-fc8b9d0baa5f" + key_info = ( + "arn:aws:kms:us-west-2:249645522726:key/d1720f4e-953b-44bb-b9dd-fc8b9d0baa5f" + ) mock_kms_client = fake_kms_client() mock_kms_client.decrypt.return_value = {"Plaintext": VALUES["provided"]["key"]} mock_kms_key_provider = fake_kms_key_provider() @@ -446,12 +517,15 @@ def test_decrypt_legacy_provided_message(): def test_encryption_cycle_with_caching(): - algorithm = Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 + algorithm = AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 frame_length = 1024 key_provider = fake_kms_key_provider(algorithm.kdf_input_len) cache = aws_encryption_sdk.LocalCryptoMaterialsCache(capacity=10) ccmm = aws_encryption_sdk.CachingCryptoMaterialsManager( - master_key_provider=key_provider, cache=cache, max_age=3600.0, max_messages_encrypted=5 + master_key_provider=key_provider, + cache=cache, + max_age=3600.0, + max_messages_encrypted=5, ) encrypt_kwargs = dict( source=VALUES["plaintext_128"], @@ -511,7 +585,9 @@ def test_encrypt_source_length_enforcement(): plaintext = io.BytesIO(VALUES["plaintext_128"]) with pytest.raises(CustomMaximumValueExceeded) as excinfo: aws_encryption_sdk.encrypt( - source=plaintext, materials_manager=cmm, source_length=int(len(VALUES["plaintext_128"]) / 2) + source=plaintext, + materials_manager=cmm, + source_length=int(len(VALUES["plaintext_128"]) / 2), ) excinfo.match(r"Bytes encrypted has exceeded stated source length estimate:*") @@ -524,7 +600,9 @@ def test_encrypt_source_length_enforcement_legacy_support(): # provider is provided. key_provider = fake_kms_key_provider() aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], key_provider=key_provider, source_length=int(len(VALUES["plaintext_128"]) / 2) + source=VALUES["plaintext_128"], + key_provider=key_provider, + source_length=int(len(VALUES["plaintext_128"]) / 2), ) @@ -547,11 +625,15 @@ def test_stream_encryptor_no_seek_input(): plaintext = NoSeekBytesIO(VALUES["plaintext_128"]) ciphertext = io.BytesIO() with aws_encryption_sdk.StreamEncryptor( - source=plaintext, key_provider=key_provider, encryption_context=VALUES["encryption_context"] + source=plaintext, + key_provider=key_provider, + encryption_context=VALUES["encryption_context"], ) as encryptor: for chunk in encryptor: ciphertext.write(chunk) - decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext.getvalue(), key_provider=key_provider) + decrypted, _header = aws_encryption_sdk.decrypt( + source=ciphertext.getvalue(), key_provider=key_provider + ) assert decrypted == VALUES["plaintext_128"] @@ -559,11 +641,15 @@ def test_stream_decryptor_no_seek_input(): """Test that StreamDecryptor can handle an input stream that is not seekable.""" key_provider = fake_kms_key_provider() ciphertext, _header = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], key_provider=key_provider, encryption_context=VALUES["encryption_context"] + source=VALUES["plaintext_128"], + key_provider=key_provider, + encryption_context=VALUES["encryption_context"], ) ciphertext_no_seek = NoSeekBytesIO(ciphertext) decrypted = io.BytesIO() - with aws_encryption_sdk.StreamDecryptor(source=ciphertext_no_seek, key_provider=key_provider) as decryptor: + with aws_encryption_sdk.StreamDecryptor( + source=ciphertext_no_seek, key_provider=key_provider + ) as decryptor: for chunk in decryptor: decrypted.write(chunk) assert decrypted.getvalue() == VALUES["plaintext_128"] @@ -574,9 +660,13 @@ def test_encrypt_oneshot_no_seek_input(): key_provider = fake_kms_key_provider() plaintext = NoSeekBytesIO(VALUES["plaintext_128"]) ciphertext, _header = aws_encryption_sdk.encrypt( - source=plaintext, key_provider=key_provider, encryption_context=VALUES["encryption_context"] + source=plaintext, + key_provider=key_provider, + encryption_context=VALUES["encryption_context"], + ) + decrypted, _header = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider ) - decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert decrypted == VALUES["plaintext_128"] @@ -584,10 +674,14 @@ def test_decrypt_oneshot_no_seek_input(): """Test that decrypt can handle an input stream that is not seekable.""" key_provider = fake_kms_key_provider() ciphertext, _header = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], key_provider=key_provider, encryption_context=VALUES["encryption_context"] + source=VALUES["plaintext_128"], + key_provider=key_provider, + encryption_context=VALUES["encryption_context"], ) ciphertext_no_seek = NoSeekBytesIO(ciphertext) - decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext_no_seek, key_provider=key_provider) + decrypted, _header = aws_encryption_sdk.decrypt( + source=ciphertext_no_seek, key_provider=key_provider + ) assert decrypted == VALUES["plaintext_128"] @@ -595,7 +689,9 @@ def test_stream_encryptor_readable(): """Verify that open StreamEncryptor instances report as readable.""" key_provider = fake_kms_key_provider() plaintext = io.BytesIO(VALUES["plaintext_128"]) - with aws_encryption_sdk.StreamEncryptor(source=plaintext, key_provider=key_provider) as handler: + with aws_encryption_sdk.StreamEncryptor( + source=plaintext, key_provider=key_provider + ) as handler: assert handler.readable() handler.read() assert not handler.readable() @@ -605,8 +701,12 @@ def test_stream_decryptor_readable(): """Verify that open StreamEncryptor instances report as readable.""" key_provider = fake_kms_key_provider() plaintext = io.BytesIO(VALUES["plaintext_128"]) - ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider) - with aws_encryption_sdk.StreamDecryptor(source=ciphertext, key_provider=key_provider) as handler: + ciphertext, _header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider + ) + with aws_encryption_sdk.StreamDecryptor( + source=ciphertext, key_provider=key_provider + ) as handler: assert handler.readable() handler.read() assert not handler.readable() @@ -667,7 +767,9 @@ def test_incomplete_read_stream_cycle(frame_length): decrypted = b"" cycle_count = 0 with aws_encryption_sdk.stream( - mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider + mode="decrypt", + source=SometimesIncompleteReaderIO(ciphertext), + key_provider=key_provider, ) as decryptor: while True: cycle_count += 1 @@ -715,9 +817,12 @@ def _error_check(capsys_instance): assert "Call stack:" not in stderr -@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2)) @pytest.mark.parametrize( - "plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2) + "frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2) +) +@pytest.mark.parametrize( + "plaintext_length", + (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2), ) def test_plaintext_logs_oneshot(caplog, capsys, plaintext_length, frame_size): plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length) @@ -730,16 +835,22 @@ def test_plaintext_logs_oneshot(caplog, capsys, plaintext_length, frame_size): _error_check(capsys) -@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2)) @pytest.mark.parametrize( - "plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2) + "frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2) +) +@pytest.mark.parametrize( + "plaintext_length", + (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2), ) def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size): plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length) ciphertext = b"" with aws_encryption_sdk.stream( - mode="encrypt", source=plaintext, key_provider=key_provider, frame_length=frame_size + mode="encrypt", + source=plaintext, + key_provider=key_provider, + frame_length=frame_size, ) as encryptor: for line in encryptor: ciphertext += line @@ -780,7 +891,9 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): source=plaintext, key_provider=key_provider, frame_length=frame_length ) ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider + ) assert raw_plaintext == decrypted @@ -793,7 +906,9 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( source=plaintext, key_provider=key_provider, frame_length=frame_length ) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider + ) assert raw_plaintext == decrypted @@ -806,11 +921,15 @@ def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class): source=plaintext, key_provider=key_provider, frame_length=frame_length ) ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider + ) assert plaintext == decrypted -def _assert_deprecated_but_not_yet_removed(logcap, instance, attribute_name, error_message, no_later_than): +def _assert_deprecated_but_not_yet_removed( + logcap, instance, attribute_name, error_message, no_later_than +): assert hasattr(instance, attribute_name) assert error_message in logcap.text assert aws_encryption_sdk.__version__ < no_later_than @@ -821,13 +940,19 @@ def _assert_decrypted_and_removed(instance, attribute_name, removed_in): assert aws_encryption_sdk.__version__ >= removed_in -@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0"))) +@pytest.mark.parametrize( + "attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0")) +) def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): caplog.set_level(logging.WARNING) plaintext = exact_length_plaintext(100) key_provider = fake_kms_key_provider() - ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0) - with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor: + ciphertext, _header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=0 + ) + with aws_encryption_sdk.stream( + mode="decrypt", source=ciphertext, key_provider=key_provider + ) as decryptor: decrypted = decryptor.read() assert decrypted == plaintext @@ -842,4 +967,6 @@ def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): no_later_than=no_later_than, ) else: - _assert_decrypted_and_removed(instance=decryptor, attribute_name=attribute, removed_in=no_later_than) + _assert_decrypted_and_removed( + instance=decryptor, attribute_name=attribute, removed_in=no_later_than + ) diff --git a/test/functional/test_f_crypto.py b/test/functional/test_f_crypto.py index 9242deedd..3c90b957d 100644 --- a/test/functional/test_f_crypto.py +++ b/test/functional/test_f_crypto.py @@ -18,7 +18,9 @@ import aws_encryption_sdk from aws_encryption_sdk.internal.crypto.authentication import Signer -from aws_encryption_sdk.internal.crypto.elliptic_curve import _ecc_static_length_signature +from aws_encryption_sdk.internal.crypto.elliptic_curve import ( + _ecc_static_length_signature, +) pytestmark = [pytest.mark.functional, pytest.mark.local] @@ -26,27 +28,46 @@ # Run several of each type to make get a high probability of forcing signature length correction @pytest.mark.parametrize( "algorithm", - [aws_encryption_sdk.Algorithm.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 for i in range(10)] - + [aws_encryption_sdk.Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 for i in range(10)], + [ + aws_encryption_sdk.AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 + for i in range(10) + ] + + [ + aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 + for i in range(10) + ], ) def test_ecc_static_length_signature(algorithm): - private_key = ec.generate_private_key(curve=algorithm.signing_algorithm_info(), backend=default_backend()) + private_key = ec.generate_private_key( + curve=algorithm.signing_algorithm_info(), backend=default_backend() + ) hasher = hashes.Hash(algorithm.signing_hash_type(), backend=default_backend()) data = b"aifuhaw9fe48haw9e8cnavwp9e8fhaw9438fnhjzsudfvhnsa89w74fhp90se8rhgfi" hasher.update(data) digest = hasher.finalize() - signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=digest) + signature = _ecc_static_length_signature( + key=private_key, algorithm=algorithm, digest=digest + ) assert len(signature) == algorithm.signature_len private_key.public_key().verify( - signature=signature, data=data, signature_algorithm=ec.ECDSA(algorithm.signing_hash_type()) + signature=signature, + data=data, + signature_algorithm=ec.ECDSA(algorithm.signing_hash_type()), ) def test_signer_key_bytes_cycle(): key = ec.generate_private_key(curve=ec.SECP384R1, backend=default_backend()) - signer = Signer(algorithm=aws_encryption_sdk.Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, key=key) + signer = Signer( + algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + key=key, + ) key_bytes = signer.key_bytes() new_signer = Signer.from_key_bytes( - algorithm=aws_encryption_sdk.Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, key_bytes=key_bytes + algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + key_bytes=key_bytes, + ) + assert ( + new_signer.key.private_numbers().private_value + == signer.key.private_numbers().private_value ) - assert new_signer.key.private_numbers().private_value == signer.key.private_numbers().private_value diff --git a/test/functional/test_f_xcompat.py b/test/functional/test_f_xcompat.py index e87082503..231fe5d26 100644 --- a/test/functional/test_f_xcompat.py +++ b/test/functional/test_f_xcompat.py @@ -63,7 +63,10 @@ def _file_root(): } ), } -_KEY_TYPES_MAP = {b"AES": EncryptionKeyType.SYMMETRIC, b"RSA": EncryptionKeyType.PRIVATE} +_KEY_TYPES_MAP = { + b"AES": EncryptionKeyType.SYMMETRIC, + b"RSA": EncryptionKeyType.PRIVATE, +} _STATIC_KEYS = defaultdict(dict) @@ -75,13 +78,19 @@ class StaticStoredMasterKeyProvider(RawMasterKeyProvider): def _get_raw_key(self, key_id): """Finds a loaded raw key.""" try: - algorithm, key_bits, padding_algorithm, padding_hash = key_id.upper().split(b".", 3) + algorithm, key_bits, padding_algorithm, padding_hash = key_id.upper().split( + b".", 3 + ) key_bits = int(key_bits) key_type = _KEY_TYPES_MAP[algorithm] - wrapping_algorithm = _WRAPPING_ALGORITHM_MAP[algorithm][key_bits][padding_algorithm][padding_hash] + wrapping_algorithm = _WRAPPING_ALGORITHM_MAP[algorithm][key_bits][ + padding_algorithm + ][padding_hash] static_key = _STATIC_KEYS[algorithm][key_bits] return WrappingKey( - wrapping_algorithm=wrapping_algorithm, wrapping_key=static_key, wrapping_key_type=key_type + wrapping_algorithm=wrapping_algorithm, + wrapping_key=static_key, + wrapping_key_type=key_type, ) except KeyError: _LOGGER.exception("Unknown Key ID: %s", key_id) @@ -92,7 +101,9 @@ def _get_raw_key(self, key_id): class RawKeyDescription(object): """Customer raw key descriptor used by StaticStoredMasterKeyProvider.""" - encryption_algorithm = attr.ib(validator=attr.validators.instance_of(six.string_types)) + encryption_algorithm = attr.ib( + validator=attr.validators.instance_of(six.string_types) + ) key_bits = attr.ib(validator=attr.validators.instance_of(int)) padding_algorithm = attr.ib(validator=attr.validators.instance_of(six.string_types)) padding_hash = attr.ib(validator=attr.validators.instance_of(six.string_types)) @@ -100,15 +111,26 @@ class RawKeyDescription(object): @property def key_id(self): """Build a key ID from instance parameters.""" - return ".".join([self.encryption_algorithm, str(self.key_bits), self.padding_algorithm, self.padding_hash]) + return ".".join( + [ + self.encryption_algorithm, + str(self.key_bits), + self.padding_algorithm, + self.padding_hash, + ] + ) @attr.s class Scenario(object): """Scenario details.""" - plaintext_filename = attr.ib(validator=attr.validators.instance_of(six.string_types)) - ciphertext_filename = attr.ib(validator=attr.validators.instance_of(six.string_types)) + plaintext_filename = attr.ib( + validator=attr.validators.instance_of(six.string_types) + ) + ciphertext_filename = attr.ib( + validator=attr.validators.instance_of(six.string_types) + ) key_ids = attr.ib(validator=attr.validators.instance_of(list)) @@ -120,7 +142,9 @@ def _generate_test_cases(): # noqa=C901 if not os.path.isdir(root_dir): root_dir = os.getcwd() base_dir = os.path.join(root_dir, "aws_encryption_sdk_resources") - ciphertext_manifest_path = os.path.join(base_dir, "manifests", "ciphertext.manifest") + ciphertext_manifest_path = os.path.join( + base_dir, "manifests", "ciphertext.manifest" + ) if not os.path.isfile(ciphertext_manifest_path): # Make no test cases if the ciphertext file is not found @@ -147,7 +171,9 @@ def _generate_test_cases(): # noqa=C901 # Collect test cases from ciphertext manifest for test_case in ciphertext_manifest["test_cases"]: key_ids = [] - algorithm = aws_encryption_sdk.Algorithm.get_by_id(int(test_case["algorithm"], 16)) + algorithm = aws_encryption_sdk.AlgorithmSuite.get_by_id( + int(test_case["algorithm"], 16) + ) for key in test_case["master_keys"]: sys.stderr.write("XC:: " + json.dumps(key) + "\n") if key["provider_id"] == StaticStoredMasterKeyProvider.provider_id: @@ -179,5 +205,7 @@ def test_decrypt_from_file(scenario): plaintext = infile.read() key_provider = StaticStoredMasterKeyProvider() key_provider.add_master_keys_from_list(scenario.key_ids) - decrypted_ciphertext, _header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + decrypted_ciphertext, _header = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=key_provider + ) assert decrypted_ciphertext == plaintext diff --git a/test/integration/test_i_aws_encrytion_sdk_client.py b/test/integration/test_i_aws_encrytion_sdk_client.py index 56b0536fd..18d541ddf 100644 --- a/test/integration/test_i_aws_encrytion_sdk_client.py +++ b/test/integration/test_i_aws_encrytion_sdk_client.py @@ -18,7 +18,7 @@ from botocore.exceptions import BotoCoreError import aws_encryption_sdk -from aws_encryption_sdk.identifiers import USER_AGENT_SUFFIX, Algorithm +from aws_encryption_sdk.identifiers import USER_AGENT_SUFFIX, AlgorithmSuite from aws_encryption_sdk.key_providers.kms import KMSMasterKey, KMSMasterKeyProvider from .integration_test_utils import get_cmk_arn, setup_kms_master_key_provider @@ -44,7 +44,10 @@ def test_encrypt_verify_user_agent_kms_master_key_provider(caplog): mkp = setup_kms_master_key_provider() mk = mkp.master_key(get_cmk_arn()) - mk.generate_data_key(algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, encryption_context={}) + mk.generate_data_key( + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + encryption_context={}, + ) assert USER_AGENT_SUFFIX in caplog.text @@ -53,7 +56,10 @@ def test_encrypt_verify_user_agent_kms_master_key(caplog): caplog.set_level(level=logging.DEBUG) mk = KMSMasterKey(key_id=get_cmk_arn()) - mk.generate_data_key(algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, encryption_context={}) + mk.generate_data_key( + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + encryption_context={}, + ) assert USER_AGENT_SUFFIX in caplog.text @@ -86,7 +92,9 @@ def test_encryption_cycle_default_algorithm_framed_stream(self): ciphertext = encryptor.read() header_1 = encryptor.header with aws_encryption_sdk.stream( - source=io.BytesIO(ciphertext), key_provider=self.kms_master_key_provider, mode="d" + source=io.BytesIO(ciphertext), + key_provider=self.kms_master_key_provider, + mode="d", ) as decryptor: plaintext = decryptor.read() header_2 = decryptor.header @@ -110,7 +118,9 @@ def test_encryption_cycle_default_algorithm_framed_stream_many_lines(self): header_1 = encryptor.header plaintext = b"" with aws_encryption_sdk.stream( - source=io.BytesIO(ciphertext), key_provider=self.kms_master_key_provider, mode="d" + source=io.BytesIO(ciphertext), + key_provider=self.kms_master_key_provider, + mode="d", ) as decryptor: for chunk in decryptor: plaintext += chunk @@ -128,7 +138,9 @@ def test_encryption_cycle_default_algorithm_non_framed(self): encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider + ) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_non_framed_no_encryption_context(self): @@ -136,9 +148,13 @@ def test_encryption_cycle_default_algorithm_non_framed_no_encryption_context(sel for a non-framed message using the default algorithm. """ ciphertext, _ = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], key_provider=self.kms_master_key_provider, frame_length=0 + source=VALUES["plaintext_128"], + key_provider=self.kms_master_key_provider, + frame_length=0, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_single_frame(self): @@ -151,7 +167,9 @@ def test_encryption_cycle_default_algorithm_single_frame(self): encryption_context=VALUES["encryption_context"], frame_length=1024, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider + ) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_multiple_frames(self): @@ -165,7 +183,9 @@ def test_encryption_cycle_default_algorithm_multiple_frames(self): encryption_context=VALUES["encryption_context"], frame_length=1024, ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider + ) assert plaintext == VALUES["plaintext_128"] * 100 def test_encryption_cycle_aes_128_gcm_iv12_tag16_single_frame(self): @@ -178,9 +198,11 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_non_framed(self): @@ -193,9 +215,11 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_single_frame(self): @@ -208,9 +232,11 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_non_framed(self): @@ -223,9 +249,11 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_single_frame(self): @@ -238,9 +266,11 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_non_framed(self): @@ -253,9 +283,11 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -268,9 +300,11 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -283,9 +317,11 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -298,9 +334,11 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -313,9 +351,11 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -328,9 +368,11 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_single_frame(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -343,12 +385,16 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_non_framed(self): key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA256, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_frame(self): + def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_frame( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256 algorithm. @@ -358,12 +404,16 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_f key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_framed(self): + def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_framed( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256 algorithm. @@ -373,12 +423,16 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_fram key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame(self): + def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -388,12 +442,16 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_f key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed(self): + def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -403,12 +461,16 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_fram key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame(self): + def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -418,12 +480,16 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_f key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=1024, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed(self): + def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed( + self + ): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -433,7 +499,9 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_fram key_provider=self.kms_master_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, - algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, + ) + plaintext, _ = aws_encryption_sdk.decrypt( + source=ciphertext, key_provider=self.kms_master_key_provider ) - plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] diff --git a/test/integration/test_i_xcompat_kms.py b/test/integration/test_i_xcompat_kms.py index b6f22a3ff..aef83716d 100644 --- a/test/integration/test_i_xcompat_kms.py +++ b/test/integration/test_i_xcompat_kms.py @@ -42,7 +42,9 @@ def _generate_test_cases(): if not os.path.isdir(root_dir): root_dir = os.getcwd() base_dir = os.path.join(root_dir, "aws_encryption_sdk_resources") - ciphertext_manifest_path = os.path.join(base_dir, "manifests", "ciphertext.manifest") + ciphertext_manifest_path = os.path.join( + base_dir, "manifests", "ciphertext.manifest" + ) if not os.path.isfile(ciphertext_manifest_path): # Make no test cases if the ciphertext file is not found @@ -66,7 +68,9 @@ def _generate_test_cases(): return _test_cases -@pytest.mark.parametrize("plaintext_filename, ciphertext_filename", _generate_test_cases()) +@pytest.mark.parametrize( + "plaintext_filename, ciphertext_filename", _generate_test_cases() +) def test_decrypt_from_file(plaintext_filename, ciphertext_filename): """Tests decrypt from known good files.""" with open(ciphertext_filename, "rb") as infile: diff --git a/tox.ini b/tox.ini index 06564ef6a..47bf379d5 100644 --- a/tox.ini +++ b/tox.ini @@ -252,7 +252,7 @@ commands = python setup.py check -r -s [testenv:bandit] basepython = python3 -deps = +deps = bandit>=1.5.1 commands = bandit -r src/aws_encryption_sdk/ From 25e8fd35653b3995432321be1355db50aa1bd76a Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Mon, 5 Aug 2019 22:43:41 -0700 Subject: [PATCH 2/6] Went through and fixed all instances where Algorithm was referenced instead of AlgorithmSuite. Changed messages in the docs accordingly. Now if you go to RTD you will see everything links together nicely. What I did not do: change the variable names from "algorithm" to i.e. "algorithm_suite". Mainly because this is a rabbit hole and not immediately important I think. To do it you need to change every single thing that calls another thing that calls another thing that has "algorithm" instead of "algorithm_suite" and I wanted to keep it simple. I also found what I think was an incorrect test. In /test/integration/test_i_aws_encryption_sdk_client.py in function test_remove_bad_client() there was an assertion error for my tox env tests that the dict containing the regional clients was not empty when it should have been supposedly. I believe that it was not meant for it to be empty, but for the bad client "us-fakey-12" to be removed, and it was assumed that in __init__ for KMSMasterKeyProvider() no regional clients were added, but this is false because there is a default added depending on some features of the botocore session etc... So when it was checking to see if it was empty it wanted to see that a bad client was added, and then was rightfully removed -> back to an empty dict, but if the dict did not start out empty the test would fail. So I made a one line change of this test to make it test that specifically the dict did not contain the exact bad client that the test was using. Please get back to me on this final change because you guys know this SDK way better than me and I don't want to break anything. Thanks! --- src/aws_encryption_sdk/__init__.py | 2 +- .../internal/crypto/authentication.py | 55 +++-- .../internal/crypto/data_keys.py | 6 +- .../internal/crypto/elliptic_curve.py | 39 +++- .../internal/crypto/encryption.py | 28 ++- src/aws_encryption_sdk/internal/crypto/iv.py | 12 +- src/aws_encryption_sdk/internal/defaults.py | 4 +- .../internal/formatting/serialize.py | 73 ++++-- .../internal/utils/__init__.py | 46 ++-- src/aws_encryption_sdk/key_providers/base.py | 107 ++++++--- src/aws_encryption_sdk/key_providers/kms.py | 84 +++++-- src/aws_encryption_sdk/key_providers/raw.py | 45 ++-- .../materials_managers/__init__.py | 41 ++-- .../materials_managers/default.py | 57 +++-- src/aws_encryption_sdk/streaming_client.py | 219 +++++++++++++----- src/aws_encryption_sdk/structures.py | 82 +++++-- .../test_i_aws_encrytion_sdk_client.py | 11 +- test/unit/test_crypto_elliptic_curve.py | 165 ++++++++++--- 18 files changed, 791 insertions(+), 285 deletions(-) diff --git a/src/aws_encryption_sdk/__init__.py b/src/aws_encryption_sdk/__init__.py index 60fa8c570..85e02884b 100644 --- a/src/aws_encryption_sdk/__init__.py +++ b/src/aws_encryption_sdk/__init__.py @@ -76,7 +76,7 @@ def encrypt(**kwargs): this is not enforced if a `key_provider` is provided. :param dict encryption_context: Dictionary defining encryption context - :param algorithm: AlgorithmSuite to use for encryption + :param algorithm: Algorithm suite to use for encryption :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int frame_length: Frame length in bytes :returns: Tuple containing the encrypted ciphertext and the message header object diff --git a/src/aws_encryption_sdk/internal/crypto/authentication.py b/src/aws_encryption_sdk/internal/crypto/authentication.py index 560fdb2a2..67b938d6e 100644 --- a/src/aws_encryption_sdk/internal/crypto/authentication.py +++ b/src/aws_encryption_sdk/internal/crypto/authentication.py @@ -33,8 +33,8 @@ class _PrehashingAuthenticator(object): """Parent class for Signer/Verifier. Provides common behavior and interface. - :param algorithm: Algorithm on which to base authenticator - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base authenticator + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param key: Key with which to build authenticator """ @@ -46,7 +46,7 @@ def __init__(self, algorithm, key): self._hasher = self._build_hasher() def _set_signature_type(self): - """Ensures that the algorithm signature type is a known type and sets a reference value.""" + """Ensures that the algorithm (suite) signature type is a known type and sets a reference value.""" try: verify_interface(ec.EllipticCurve, self.algorithm.signing_algorithm_info) return ec.EllipticCurve @@ -58,14 +58,16 @@ def _build_hasher(self): :returns: Hasher object """ - return hashes.Hash(self.algorithm.signing_hash_type(), backend=default_backend()) + return hashes.Hash( + self.algorithm.signing_hash_type(), backend=default_backend() + ) class Signer(_PrehashingAuthenticator): """Abstract signing handler. - :param algorithm: Algorithm on which to base signer - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base signer + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param key: Private key from which a signer can be generated :type key: currently only Elliptic Curve Private Keys are supported """ @@ -74,12 +76,14 @@ class Signer(_PrehashingAuthenticator): def from_key_bytes(cls, algorithm, key_bytes): """Builds a `Signer` from an algorithm suite and a raw signing key. - :param algorithm: Algorithm on which to base signer - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base signer + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes key_bytes: Raw signing key :rtype: aws_encryption_sdk.internal.crypto.Signer """ - key = serialization.load_der_private_key(data=key_bytes, password=None, backend=default_backend()) + key = serialization.load_der_private_key( + data=key_bytes, password=None, backend=default_backend() + ) return cls(algorithm, key) def key_bytes(self): @@ -118,7 +122,9 @@ def finalize(self): :rtype: bytes """ prehashed_digest = self._hasher.finalize() - return _ecc_static_length_signature(key=self.key, algorithm=self.algorithm, digest=prehashed_digest) + return _ecc_static_length_signature( + key=self.key, algorithm=self.algorithm, digest=prehashed_digest + ) class Verifier(_PrehashingAuthenticator): @@ -127,18 +133,18 @@ class Verifier(_PrehashingAuthenticator): .. note:: For ECC curves, the signature must be DER encoded as specified in RFC 3279. - :param algorithm: Algorithm on which to base verifier - :type algorithm: aws_encryption_sdk.identifiers.Algorithm - :param public_key: Appropriate public key object for algorithm + :param algorithm: Algorithm suite on which to base verifier + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite + :param public_key: Appropriate public key object for algorithm suite :type public_key: may vary """ @classmethod def from_encoded_point(cls, algorithm, encoded_point): - """Creates a Verifier object based on the supplied algorithm and encoded compressed ECC curve point. + """Creates a Verifier object based on the supplied algorithm suite and encoded compressed ECC curve point. - :param algorithm: Algorithm on which to base verifier - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base verifier + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes encoded_point: ECC public point compressed and encoded with _ecc_encode_compressed_point :returns: Instance of Verifier generated from encoded point :rtype: aws_encryption_sdk.internal.crypto.Verifier @@ -146,22 +152,26 @@ def from_encoded_point(cls, algorithm, encoded_point): return cls( algorithm=algorithm, key=_ecc_public_numbers_from_compressed_point( - curve=algorithm.signing_algorithm_info(), compressed_point=base64.b64decode(encoded_point) + curve=algorithm.signing_algorithm_info(), + compressed_point=base64.b64decode(encoded_point), ).public_key(default_backend()), ) @classmethod def from_key_bytes(cls, algorithm, key_bytes): - """Creates a `Verifier` object based on the supplied algorithm and raw verification key. + """Creates a `Verifier` object based on the supplied algorithm suite and raw verification key. - :param algorithm: Algorithm on which to base verifier - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base verifier + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes encoded_point: Raw verification key :returns: Instance of Verifier generated from encoded point :rtype: aws_encryption_sdk.internal.crypto.Verifier """ return cls( - algorithm=algorithm, key=serialization.load_der_public_key(data=key_bytes, backend=default_backend()) + algorithm=algorithm, + key=serialization.load_der_public_key( + data=key_bytes, backend=default_backend() + ), ) def key_bytes(self): @@ -170,7 +180,8 @@ def key_bytes(self): :rtype: bytes """ return self.key.public_bytes( - encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) def update(self, data): diff --git a/src/aws_encryption_sdk/internal/crypto/data_keys.py b/src/aws_encryption_sdk/internal/crypto/data_keys.py index f16873106..4d773f9ec 100644 --- a/src/aws_encryption_sdk/internal/crypto/data_keys.py +++ b/src/aws_encryption_sdk/internal/crypto/data_keys.py @@ -20,11 +20,11 @@ def derive_data_encryption_key(source_key, algorithm, message_id): - """Derives the data encryption key using the defined algorithm. + """Derives the data encryption key using the defined algorithm suite. :param bytes source_key: Raw source key - :param algorithm: Algorithm used to encrypt this body - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite used to encrypt this body + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes message_id: Message ID :returns: Derived data encryption key :rtype: bytes diff --git a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py index 47af50b8c..b35ae1c22 100644 --- a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py +++ b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py @@ -17,8 +17,17 @@ import six from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric.utils import Prehashed, decode_dss_signature, encode_dss_signature -from cryptography.utils import InterfaceNotImplemented, int_from_bytes, int_to_bytes, verify_interface +from cryptography.hazmat.primitives.asymmetric.utils import ( + Prehashed, + decode_dss_signature, + encode_dss_signature, +) +from cryptography.utils import ( + InterfaceNotImplemented, + int_from_bytes, + int_to_bytes, + verify_interface, +) from ...exceptions import NotSupportedError from ..str_ops import to_bytes @@ -57,8 +66,8 @@ def _ecc_static_length_signature(key, algorithm, digest): :param key: Elliptic curve private key :type key: cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey - :param algorithm: Master algorithm to use - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Master algorithm suite to use + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes digest: Pre-calculated hash digest :returns: Signature with required length :rtype: bytes @@ -67,14 +76,18 @@ def _ecc_static_length_signature(key, algorithm, digest): signature = b"" while len(signature) != algorithm.signature_len: _LOGGER.debug( - "Signature length %d is not desired length %d. Recalculating.", len(signature), algorithm.signature_len + "Signature length %d is not desired length %d. Recalculating.", + len(signature), + algorithm.signature_len, ) signature = key.sign(digest, pre_hashed_algorithm) if len(signature) != algorithm.signature_len: # Most of the time, a signature of the wrong length can be fixed # by negating s in the signature relative to the group order. _LOGGER.debug( - "Signature length %d is not desired length %d. Negating s.", len(signature), algorithm.signature_len + "Signature length %d is not desired length %d. Negating s.", + len(signature), + algorithm.signature_len, ) r, s = decode_dss_signature(signature) s = _ECC_CURVE_PARAMETERS[algorithm.signing_algorithm_info.name].order - s @@ -136,7 +149,9 @@ def _ecc_decode_compressed_point(curve, compressed_point): try: params = _ECC_CURVE_PARAMETERS[curve.name] except KeyError: - raise NotSupportedError("Curve {name} is not supported at this time".format(name=curve.name)) + raise NotSupportedError( + "Curve {name} is not supported at this time".format(name=curve.name) + ) alpha = (pow(x, 3, params.p) + (params.a * x % params.p) + params.b) % params.p # Only works for p % 4 == 3 at this time. # This is the case for all currently supported algorithms. @@ -177,13 +192,15 @@ def _ecc_public_numbers_from_compressed_point(curve, compressed_point): def generate_ecc_signing_key(algorithm): """Returns an ECC signing key. - :param algorithm: Algorithm object which determines what signature to generate - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite object which determines what signature to generate + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :returns: Generated signing key - :raises NotSupportedError: if signing algorithm is not supported on this platform + :raises NotSupportedError: if signing algorithm suite is not supported on this platform """ try: verify_interface(ec.EllipticCurve, algorithm.signing_algorithm_info) - return ec.generate_private_key(curve=algorithm.signing_algorithm_info(), backend=default_backend()) + return ec.generate_private_key( + curve=algorithm.signing_algorithm_info(), backend=default_backend() + ) except InterfaceNotImplemented: raise NotSupportedError("Unsupported signing algorithm info") diff --git a/src/aws_encryption_sdk/internal/crypto/encryption.py b/src/aws_encryption_sdk/internal/crypto/encryption.py index 1e5523826..f549b483c 100644 --- a/src/aws_encryption_sdk/internal/crypto/encryption.py +++ b/src/aws_encryption_sdk/internal/crypto/encryption.py @@ -24,8 +24,8 @@ class Encryptor(object): """Abstract encryption handler. - :param algorithm: Algorithm used to encrypt this body - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite used to encrypt this body + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes key: Encryption key :param bytes associated_data: Associated Data to send to encryption subsystem :param bytes iv: IV to use when encrypting message @@ -39,7 +39,9 @@ def __init__(self, algorithm, key, associated_data, iv): # This is intentionally generic to leave an option for non-Cipher encryptor types in the future. self.iv = iv self._encryptor = Cipher( - algorithm.encryption_algorithm(key), algorithm.encryption_mode(self.iv), backend=default_backend() + algorithm.encryption_algorithm(key), + algorithm.encryption_mode(self.iv), + backend=default_backend(), ).encryptor() # associated_data will be authenticated but not encrypted, @@ -76,8 +78,8 @@ def tag(self): def encrypt(algorithm, key, plaintext, associated_data, iv): """Encrypts a frame body. - :param algorithm: Algorithm used to encrypt this body - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite used to encrypt this body + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes key: Encryption key :param bytes plaintext: Body plaintext :param bytes associated_data: Body AAD Data @@ -93,8 +95,8 @@ def encrypt(algorithm, key, plaintext, associated_data, iv): class Decryptor(object): """Abstract decryption handler. - :param algorithm: Algorithm used to encrypt this body - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite used to encrypt this body + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes key: Raw source key :param bytes associated_data: Associated Data to send to decryption subsystem :param bytes iv: IV value with which to initialize decryption subsystem @@ -108,7 +110,9 @@ def __init__(self, algorithm, key, associated_data, iv, tag): # Construct a decryptor object with the given key and a provided IV. # This is intentionally generic to leave an option for non-Cipher decryptor types in the future. self._decryptor = Cipher( - algorithm.encryption_algorithm(key), algorithm.encryption_mode(iv, tag), backend=default_backend() + algorithm.encryption_algorithm(key), + algorithm.encryption_mode(iv, tag), + backend=default_backend(), ).decryptor() # Put associated_data back in or the tag will fail to verify when the _decryptor is finalized. @@ -135,8 +139,8 @@ def finalize(self): def decrypt(algorithm, key, encrypted_data, associated_data): """Decrypts a frame body. - :param algorithm: Algorithm used to encrypt this body - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite used to encrypt this body + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes key: Plaintext data key :param encrypted_data: EncryptedData containing body data :type encrypted_data: :class:`aws_encryption_sdk.internal.structures.EncryptedData`, @@ -147,5 +151,7 @@ def decrypt(algorithm, key, encrypted_data, associated_data): :returns: Plaintext of body :rtype: bytes """ - decryptor = Decryptor(algorithm, key, associated_data, encrypted_data.iv, encrypted_data.tag) + decryptor = Decryptor( + algorithm, key, associated_data, encrypted_data.iv, encrypted_data.tag + ) return decryptor.update(encrypted_data.ciphertext) + decryptor.finalize() diff --git a/src/aws_encryption_sdk/internal/crypto/iv.py b/src/aws_encryption_sdk/internal/crypto/iv.py index e5424057b..d6515df7c 100644 --- a/src/aws_encryption_sdk/internal/crypto/iv.py +++ b/src/aws_encryption_sdk/internal/crypto/iv.py @@ -46,8 +46,8 @@ def frame_iv(algorithm, sequence_number): """Builds the deterministic IV for a body frame. - :param algorithm: Algorithm for which to build IV - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite for which to build IV + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int sequence_number: Frame sequence number :returns: Generated IV :rtype: bytes @@ -67,8 +67,8 @@ def frame_iv(algorithm, sequence_number): def non_framed_body_iv(algorithm): """Builds the deterministic IV for a non-framed body. - :param algorithm: Algorithm for which to build IV - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite for which to build IV + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :returns: Generated IV :rtype: bytes """ @@ -78,8 +78,8 @@ def non_framed_body_iv(algorithm): def header_auth_iv(algorithm): """Builds the deterministic IV for header authentication. - :param algorithm: Algorithm for which to build IV - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite for which to build IV + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :returns: Generated IV :rtype: bytes """ diff --git a/src/aws_encryption_sdk/internal/defaults.py b/src/aws_encryption_sdk/internal/defaults.py index f63a42d61..5e4c905bd 100644 --- a/src/aws_encryption_sdk/internal/defaults.py +++ b/src/aws_encryption_sdk/internal/defaults.py @@ -29,7 +29,9 @@ #: Default message structure Type as defined in specification TYPE = aws_encryption_sdk.identifiers.ObjectType.CUSTOMER_AE_DATA #: Default algorithm as defined in specification -ALGORITHM = aws_encryption_sdk.identifiers.Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 +ALGORITHM = ( + aws_encryption_sdk.identifiers.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 +) #: Key to add encoded signing key to encryption context dictionary as defined in specification ENCODED_SIGNER_KEY = "aws-crypto-public-key" diff --git a/src/aws_encryption_sdk/internal/formatting/serialize.py b/src/aws_encryption_sdk/internal/formatting/serialize.py index e7c86a0cb..57bcdd39b 100644 --- a/src/aws_encryption_sdk/internal/formatting/serialize.py +++ b/src/aws_encryption_sdk/internal/formatting/serialize.py @@ -17,7 +17,11 @@ import aws_encryption_sdk.internal.defaults import aws_encryption_sdk.internal.formatting.encryption_context from aws_encryption_sdk.exceptions import SerializationError -from aws_encryption_sdk.identifiers import ContentAADString, EncryptionType, SequenceIdentifier +from aws_encryption_sdk.identifiers import ( + ContentAADString, + EncryptionType, + SequenceIdentifier, +) from aws_encryption_sdk.internal.crypto.encryption import encrypt from aws_encryption_sdk.internal.crypto.iv import frame_iv, header_auth_iv from aws_encryption_sdk.internal.str_ops import to_bytes @@ -110,7 +114,12 @@ def serialize_header(header, signer=None): "I" # frame length ) header_bytes.extend( - struct.pack(header_close_format, header.content_type.value, header.algorithm.iv_len, header.frame_length) + struct.pack( + header_close_format, + header.content_type.value, + header.algorithm.iv_len, + header.frame_length, + ) ) output = bytes(header_bytes) if signer is not None: @@ -121,8 +130,8 @@ def serialize_header(header, signer=None): def serialize_header_auth(algorithm, header, data_encryption_key, signer=None): """Creates serialized header authentication data. - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes header: Serialized message header :param bytes data_encryption_key: Data key with which to encrypt message :param signer: Cryptographic signer object (optional) @@ -138,7 +147,9 @@ def serialize_header_auth(algorithm, header, data_encryption_key, signer=None): iv=header_auth_iv(algorithm), ) output = struct.pack( - ">{iv_len}s{tag_len}s".format(iv_len=algorithm.iv_len, tag_len=algorithm.tag_len), + ">{iv_len}s{tag_len}s".format( + iv_len=algorithm.iv_len, tag_len=algorithm.tag_len + ), header_auth.iv, header_auth.tag, ) @@ -150,8 +161,8 @@ def serialize_header_auth(algorithm, header, data_encryption_key, signer=None): def serialize_non_framed_open(algorithm, iv, plaintext_length, signer=None): """Serializes the opening block for a non-framed message body. - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes iv: IV value used to encrypt body :param int plaintext_length: Length of plaintext (and thus ciphertext) in body :param signer: Cryptographic signer object (optional) @@ -159,7 +170,9 @@ def serialize_non_framed_open(algorithm, iv, plaintext_length, signer=None): :returns: Serialized body start block :rtype: bytes """ - body_start_format = (">" "{iv_length}s" "Q").format(iv_length=algorithm.iv_len) # nonce (IV) # content length + body_start_format = (">" "{iv_length}s" "Q").format( + iv_length=algorithm.iv_len + ) # nonce (IV) # content length body_start = struct.pack(body_start_format, iv, plaintext_length) if signer: signer.update(body_start) @@ -182,13 +195,20 @@ def serialize_non_framed_close(tag, signer=None): def serialize_frame( - algorithm, plaintext, message_id, data_encryption_key, frame_length, sequence_number, is_final_frame, signer=None + algorithm, + plaintext, + message_id, + data_encryption_key, + frame_length, + sequence_number, + is_final_frame, + signer=None, ): """Receives a message plaintext, breaks off a frame, encrypts and serializes the frame, and returns the encrypted frame and the remaining plaintext. - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes plaintext: Source plaintext to encrypt and serialize :param bytes message_id: Message ID :param bytes data_encryption_key: Data key with which to encrypt message @@ -227,7 +247,9 @@ def serialize_frame( _LOGGER.debug("Serializing final frame") packed_frame = struct.pack( ">II{iv_len}sI{content_len}s{auth_len}s".format( - iv_len=algorithm.iv_len, content_len=len(frame_ciphertext.ciphertext), auth_len=algorithm.auth_len + iv_len=algorithm.iv_len, + content_len=len(frame_ciphertext.ciphertext), + auth_len=algorithm.auth_len, ), SequenceIdentifier.SEQUENCE_NUMBER_END.value, sequence_number, @@ -240,7 +262,9 @@ def serialize_frame( _LOGGER.debug("Serializing frame") packed_frame = struct.pack( ">I{iv_len}s{content_len}s{auth_len}s".format( - iv_len=algorithm.iv_len, content_len=frame_length, auth_len=algorithm.auth_len + iv_len=algorithm.iv_len, + content_len=frame_length, + auth_len=algorithm.auth_len, ), sequence_number, frame_ciphertext.iv, @@ -264,7 +288,9 @@ def serialize_footer(signer): footer = b"" if signer is not None: signature = signer.finalize() - footer = struct.pack(">H{sig_len}s".format(sig_len=len(signature)), len(signature), signature) + footer = struct.pack( + ">H{sig_len}s".format(sig_len=len(signature)), len(signature), signature + ) return footer @@ -277,7 +303,10 @@ def serialize_raw_master_key_prefix(raw_master_key): :returns: Serialized key_info prefix :rtype: bytes """ - if raw_master_key.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.ASYMMETRIC: + if ( + raw_master_key.config.wrapping_key.wrapping_algorithm.encryption_type + is EncryptionType.ASYMMETRIC + ): return to_bytes(raw_master_key.key_id) return struct.pack( ">{}sII".format(len(raw_master_key.key_id)), @@ -288,7 +317,9 @@ def serialize_raw_master_key_prefix(raw_master_key): ) -def serialize_wrapped_key(key_provider, wrapping_algorithm, wrapping_key_id, encrypted_wrapped_key): +def serialize_wrapped_key( + key_provider, wrapping_algorithm, wrapping_key_id, encrypted_wrapped_key +): """Serializes EncryptedData into a Wrapped EncryptedDataKey. :param key_provider: Info for Wrapping MasterKey @@ -307,15 +338,19 @@ def serialize_wrapped_key(key_provider, wrapping_algorithm, wrapping_key_id, enc else: key_info = struct.pack( ">{key_id_len}sII{iv_len}s".format( - key_id_len=len(wrapping_key_id), iv_len=wrapping_algorithm.algorithm.iv_len + key_id_len=len(wrapping_key_id), + iv_len=wrapping_algorithm.algorithm.iv_len, ), to_bytes(wrapping_key_id), - len(encrypted_wrapped_key.tag) * 8, # Tag Length is stored in bits, not bytes + len(encrypted_wrapped_key.tag) + * 8, # Tag Length is stored in bits, not bytes wrapping_algorithm.algorithm.iv_len, encrypted_wrapped_key.iv, ) key_ciphertext = encrypted_wrapped_key.ciphertext + encrypted_wrapped_key.tag return EncryptedDataKey( - key_provider=MasterKeyInfo(provider_id=key_provider.provider_id, key_info=key_info), + key_provider=MasterKeyInfo( + provider_id=key_provider.provider_id, key_info=key_info + ), encrypted_data_key=key_ciphertext, ) diff --git a/src/aws_encryption_sdk/internal/utils/__init__.py b/src/aws_encryption_sdk/internal/utils/__init__.py index 1e7400c3a..d3087e24e 100644 --- a/src/aws_encryption_sdk/internal/utils/__init__.py +++ b/src/aws_encryption_sdk/internal/utils/__init__.py @@ -18,7 +18,11 @@ import six import aws_encryption_sdk.internal.defaults -from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError +from aws_encryption_sdk.exceptions import ( + InvalidDataKeyError, + SerializationError, + UnknownIdentityError, +) from aws_encryption_sdk.identifiers import ContentAADString, ContentType from aws_encryption_sdk.internal.str_ops import to_bytes from aws_encryption_sdk.structures import EncryptedDataKey @@ -45,12 +49,15 @@ def validate_frame_length(frame_length, algorithm): """Validates that frame length is within the defined limits and is compatible with the selected algorithm. :param int frame_length: Frame size in bytes - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :raises SerializationError: if frame size is negative or not a multiple of the algorithm block size :raises SerializationError: if frame size is larger than the maximum allowed frame size """ - if frame_length < 0 or frame_length % algorithm.encryption_algorithm.block_size != 0: + if ( + frame_length < 0 + or frame_length % algorithm.encryption_algorithm.block_size != 0 + ): raise SerializationError( "Frame size must be a non-negative multiple of the block size of the crypto algorithm: {block_size}".format( block_size=algorithm.encryption_algorithm.block_size @@ -59,7 +66,8 @@ def validate_frame_length(frame_length, algorithm): if frame_length > aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE: raise SerializationError( "Frame size too large: {frame} > {max}".format( - frame=frame_length, max=aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE + frame=frame_length, + max=aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE, ) ) @@ -103,29 +111,39 @@ def prepare_data_keys(primary_master_key, master_keys, algorithm, encryption_con :type primary_master_key: aws_encryption_sdk.key_providers.base.MasterKey :param master_keys: All master keys with which to encrypt data keys :type master_keys: list of :class:`aws_encryption_sdk.key_providers.base.MasterKey` - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to use when generating data key :rtype: tuple containing :class:`aws_encryption_sdk.structures.DataKey` and set of :class:`aws_encryption_sdk.structures.EncryptedDataKey` """ encrypted_data_keys = set() encrypted_data_encryption_key = None - data_encryption_key = primary_master_key.generate_data_key(algorithm, encryption_context) - _LOGGER.debug("encryption data generated with master key: %s", data_encryption_key.key_provider) + data_encryption_key = primary_master_key.generate_data_key( + algorithm, encryption_context + ) + _LOGGER.debug( + "encryption data generated with master key: %s", + data_encryption_key.key_provider, + ) for master_key in master_keys: # Don't re-encrypt the encryption data key; we already have the ciphertext if master_key is primary_master_key: encrypted_data_encryption_key = EncryptedDataKey( - key_provider=data_encryption_key.key_provider, encrypted_data_key=data_encryption_key.encrypted_data_key + key_provider=data_encryption_key.key_provider, + encrypted_data_key=data_encryption_key.encrypted_data_key, ) encrypted_data_keys.add(encrypted_data_encryption_key) continue encrypted_key = master_key.encrypt_data_key( - data_key=data_encryption_key, algorithm=algorithm, encryption_context=encryption_context + data_key=data_encryption_key, + algorithm=algorithm, + encryption_context=encryption_context, ) encrypted_data_keys.add(encrypted_key) - _LOGGER.debug("encryption key encrypted with master key: %s", master_key.key_provider) + _LOGGER.debug( + "encryption key encrypted with master key: %s", master_key.key_provider + ) return data_encryption_key, encrypted_data_keys @@ -151,8 +169,8 @@ def source_data_key_length_check(source_data_key, algorithm): :param source_data_key: Source data key object received from MasterKey decrypt or generate data_key methods :type source_data_key: :class:`aws_encryption_sdk.structures.RawDataKey` or :class:`aws_encryption_sdk.structures.DataKey` - :param algorithm: Algorithm object which directs how this data key will be used - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite object which directs how this data key will be used + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :raises InvalidDataKeyError: if data key length does not match required kdf input length """ if len(source_data_key.data_key) != algorithm.kdf_input_len: diff --git a/src/aws_encryption_sdk/key_providers/base.py b/src/aws_encryption_sdk/key_providers/base.py index 3112cba6d..b272d8836 100644 --- a/src/aws_encryption_sdk/key_providers/base.py +++ b/src/aws_encryption_sdk/key_providers/base.py @@ -71,8 +71,12 @@ def __new__(cls, **kwargs): """ instance = super(MasterKeyProvider, cls).__new__(cls) config = kwargs.pop("config", None) - if not isinstance(config, instance._config_class): # pylint: disable=protected-access - config = instance._config_class(**kwargs) # pylint: disable=protected-access + if not isinstance( + config, instance._config_class + ): # pylint: disable=protected-access + config = instance._config_class( + **kwargs + ) # pylint: disable=protected-access instance.config = config #: Index matching key IDs to existing MasterKey objects. instance._encrypt_key_index = {} # pylint: disable=protected-access @@ -88,11 +92,15 @@ def __repr__(self): name=self.__class__.__name__, kwargs=", ".join( "{key}={value}".format(key=key, value=value) - for key, value in sorted(attr.asdict(self.config, recurse=True).items(), key=lambda x: x[0]) + for key, value in sorted( + attr.asdict(self.config, recurse=True).items(), key=lambda x: x[0] + ) ), ) - def master_keys_for_encryption(self, encryption_context, plaintext_rostream, plaintext_length=None): + def master_keys_for_encryption( + self, encryption_context, plaintext_rostream, plaintext_length=None + ): """Returns a set containing all Master Keys added to this Provider, or any member Providers, which should be used to encrypt data keys for the specified data. @@ -125,7 +133,9 @@ def master_keys_for_encryption(self, encryption_context, plaintext_rostream, pla primary = _primary master_keys.extend(_master_keys) if not master_keys: - raise MasterKeyProviderError("No Master Keys available from Master Key Provider") + raise MasterKeyProviderError( + "No Master Keys available from Master Key Provider" + ) return primary, master_keys @abc.abstractmethod @@ -218,8 +228,8 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): :param encrypted_data_key: Encrypted data key to decrypt :type encrypted_data_key: aws_encryption_sdk.structures.EncryptedDataKey - :param algorithm: Algorithm object which directs how this Master Key will encrypt the data key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite object which directs how this Master Key will encrypt the data key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to use in encryption :returns: Decrypted data key :rtype: aws_encryption_sdk.structures.DataKey @@ -230,26 +240,38 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): _LOGGER.debug("starting decrypt data key attempt") for member in [self] + self._members: if member.provider_id == encrypted_data_key.key_provider.provider_id: - _LOGGER.debug("attempting to locate master key from key provider: %s", member.provider_id) + _LOGGER.debug( + "attempting to locate master key from key provider: %s", + member.provider_id, + ) if isinstance(member, MasterKey): _LOGGER.debug("using existing master key") master_key = member elif self.vend_masterkey_on_decrypt: try: - _LOGGER.debug("attempting to add master key: %s", encrypted_data_key.key_provider.key_info) - master_key = member.master_key_for_decrypt(encrypted_data_key.key_provider.key_info) + _LOGGER.debug( + "attempting to add master key: %s", + encrypted_data_key.key_provider.key_info, + ) + master_key = member.master_key_for_decrypt( + encrypted_data_key.key_provider.key_info + ) except InvalidKeyIdError: _LOGGER.debug( - "master key %s not available in provider", encrypted_data_key.key_provider.key_info + "master key %s not available in provider", + encrypted_data_key.key_provider.key_info, ) continue else: continue try: _LOGGER.debug( - "attempting to decrypt data key with provider %s", encrypted_data_key.key_provider.key_info + "attempting to decrypt data key with provider %s", + encrypted_data_key.key_provider.key_info, + ) + data_key = master_key.decrypt_data_key( + encrypted_data_key, algorithm, encryption_context ) - data_key = master_key.decrypt_data_key(encrypted_data_key, algorithm, encryption_context) except (IncorrectMasterKeyError, DecryptKeyError) as error: _LOGGER.debug( "%s raised when attempting to decrypt data key with master key %s", @@ -262,7 +284,9 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): raise DecryptKeyError("Unable to decrypt data key") return data_key - def decrypt_data_key_from_list(self, encrypted_data_keys, algorithm, encryption_context): + def decrypt_data_key_from_list( + self, encrypted_data_keys, algorithm, encryption_context + ): """Receives a list of encrypted data keys and returns the first one which this provider is able to decrypt. :param encrypted_data_keys: List of encrypted data keys @@ -277,7 +301,9 @@ def decrypt_data_key_from_list(self, encrypted_data_keys, algorithm, encryption_ data_key = None for encrypted_data_key in encrypted_data_keys: try: - data_key = self.decrypt_data_key(encrypted_data_key, algorithm, encryption_context) + data_key = self.decrypt_data_key( + encrypted_data_key, algorithm, encryption_context + ) # MasterKeyProvider.decrypt_data_key throws DecryptKeyError # but MasterKey.decrypt_data_key throws IncorrectMasterKeyError except (DecryptKeyError, IncorrectMasterKeyError): @@ -296,12 +322,18 @@ class MasterKeyConfig(object): :param bytes key_id: Key ID for Master Key """ - key_id = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_bytes) + key_id = attr.ib( + hash=True, + validator=attr.validators.instance_of((six.string_types, bytes)), + converter=to_bytes, + ) def __attrs_post_init__(self): """Verify that children of this class define a "provider_id" attribute.""" if not hasattr(self, "provider_id"): - raise TypeError('Instances of MasterKeyConfig must have a "provider_id" attribute defined.') + raise TypeError( + 'Instances of MasterKeyConfig must have a "provider_id" attribute defined.' + ) @six.add_metaclass(abc.ABCMeta) @@ -318,7 +350,9 @@ def __new__(cls, **kwargs): instance = super(MasterKey, cls).__new__(cls, **kwargs) if not hasattr(instance.config, "provider_id"): - raise TypeError('MasterKey config classes must have a "provider_id" attribute defined.') + raise TypeError( + 'MasterKey config classes must have a "provider_id" attribute defined.' + ) if instance.config.provider_id is not None: # Only allow override if provider_id is NOT set to non-None for the class @@ -327,11 +361,14 @@ def __new__(cls, **kwargs): elif instance.provider_id != instance.config.provider_id: raise ConfigMismatchError( "Config provider_id does not match MasterKey provider_id: {config} != {instance}".format( - config=instance.config.provider_id, instance=instance.provider_id + config=instance.config.provider_id, + instance=instance.provider_id, ) ) instance.key_id = instance.config.key_id - instance._encrypt_key_index = {instance.key_id: instance} # pylint: disable=protected-access + instance._encrypt_key_index = { + instance.key_id: instance + } # pylint: disable=protected-access # We cannot make any general statements about key_info, so specifically enforce that decrypt index is empty. instance._decrypt_key_index = {} # pylint: disable=protected-access instance._members = [instance] # pylint: disable=protected-access @@ -360,7 +397,9 @@ def owns_data_key(self, data_key): return True return False - def master_keys_for_encryption(self, encryption_context, plaintext_rostream, plaintext_length=None): + def master_keys_for_encryption( + self, encryption_context, plaintext_rostream, plaintext_length=None + ): """Returns self and a list containing self, to match the format of output for a Master Key Provider. .. warning:: @@ -416,8 +455,12 @@ def generate_data_key(self, algorithm, encryption_context): :returns: Generated data key :rtype: aws_encryption_sdk.structures.DataKey """ - _LOGGER.info("generating data key with encryption context: %s", encryption_context) - generated_data_key = self._generate_data_key(algorithm=algorithm, encryption_context=encryption_context) + _LOGGER.info( + "generating data key with encryption context: %s", encryption_context + ) + generated_data_key = self._generate_data_key( + algorithm=algorithm, encryption_context=encryption_context + ) aws_encryption_sdk.internal.utils.source_data_key_length_check( source_data_key=generated_data_key, algorithm=algorithm ) @@ -450,8 +493,14 @@ def encrypt_data_key(self, data_key, algorithm, encryption_context): :rtype: aws_encryption_sdk.structures.EncryptedDataKey :raises IncorrectMasterKeyError: if Data Key's key provider does not match this Master Key """ - _LOGGER.info("encrypting data key with encryption context: %s", encryption_context) - return self._encrypt_data_key(data_key=data_key, algorithm=algorithm, encryption_context=encryption_context) + _LOGGER.info( + "encrypting data key with encryption context: %s", encryption_context + ) + return self._encrypt_data_key( + data_key=data_key, + algorithm=algorithm, + encryption_context=encryption_context, + ) @abc.abstractmethod def _encrypt_data_key(self, data_key, algorithm, encryption_context): @@ -483,10 +532,14 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): :rtype: aws_encryption_sdk.structures.DataKey :raises IncorrectMasterKeyError: if Data Key's key provider does not match this Master Key """ - _LOGGER.info("decrypting data key with encryption context: %s", encryption_context) + _LOGGER.info( + "decrypting data key with encryption context: %s", encryption_context + ) self._key_check(encrypted_data_key) decrypted_data_key = self._decrypt_data_key( - encrypted_data_key=encrypted_data_key, algorithm=algorithm, encryption_context=encryption_context + encrypted_data_key=encrypted_data_key, + algorithm=algorithm, + encryption_context=encryption_context, ) aws_encryption_sdk.internal.utils.source_data_key_length_check( source_data_key=decrypted_data_key, algorithm=algorithm diff --git a/src/aws_encryption_sdk/key_providers/kms.py b/src/aws_encryption_sdk/key_providers/kms.py index df089e3b1..3855a329f 100644 --- a/src/aws_encryption_sdk/key_providers/kms.py +++ b/src/aws_encryption_sdk/key_providers/kms.py @@ -21,10 +21,20 @@ import botocore.session from botocore.exceptions import ClientError -from aws_encryption_sdk.exceptions import DecryptKeyError, EncryptKeyError, GenerateKeyError, UnknownRegionError +from aws_encryption_sdk.exceptions import ( + DecryptKeyError, + EncryptKeyError, + GenerateKeyError, + UnknownRegionError, +) from aws_encryption_sdk.identifiers import USER_AGENT_SUFFIX from aws_encryption_sdk.internal.str_ops import to_str -from aws_encryption_sdk.key_providers.base import MasterKey, MasterKeyConfig, MasterKeyProvider, MasterKeyProviderConfig +from aws_encryption_sdk.key_providers.base import ( + MasterKey, + MasterKeyConfig, + MasterKeyProvider, + MasterKeyProviderConfig, +) from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo _LOGGER = logging.getLogger(__name__) @@ -46,7 +56,9 @@ def _region_from_key_id(key_id, default_region=None): except IndexError: if default_region is None: raise UnknownRegionError( - "No default region found and no region determinable from key id: {}".format(key_id) + "No default region found and no region determinable from key id: {}".format( + key_id + ) ) region_name = default_region return region_name @@ -68,10 +80,16 @@ class KMSMasterKeyProviderConfig(MasterKeyProviderConfig): validator=attr.validators.instance_of(botocore.session.Session), ) key_ids = attr.ib( - hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple + hash=True, + default=attr.Factory(tuple), + validator=attr.validators.instance_of(tuple), + converter=tuple, ) region_names = attr.ib( - hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple + hash=True, + default=attr.Factory(tuple), + validator=attr.validators.instance_of(tuple), + converter=tuple, ) @@ -114,16 +132,18 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument def _process_config(self): """Traverses the config and adds master keys and regional clients as needed.""" - self._user_agent_adding_config = botocore.config.Config(user_agent_extra=USER_AGENT_SUFFIX) - + self._user_agent_adding_config = botocore.config.Config( + user_agent_extra=USER_AGENT_SUFFIX + ) if self.config.region_names: self.add_regional_clients_from_list(self.config.region_names) self.default_region = self.config.region_names[0] else: - self.default_region = self.config.botocore_session.get_config_variable("region") + self.default_region = self.config.botocore_session.get_config_variable( + "region" + ) if self.default_region is not None: self.add_regional_client(self.default_region) - if self.config.key_ids: self.add_master_keys_from_list(self.config.key_ids) @@ -140,7 +160,9 @@ def _wrap_client(self, region_name, method, *args, **kwargs): except botocore.exceptions.BotoCoreError: self._regional_clients.pop(region_name) _LOGGER.error( - 'Removing regional client "%s" from cache due to BotoCoreError on %s call', region_name, method.__name__ + 'Removing regional client "%s" from cache due to BotoCoreError on %s call', + region_name, + method.__name__, ) raise @@ -161,7 +183,9 @@ def add_regional_client(self, region_name): :param str region_name: AWS Region ID (ex: us-east-1) """ if region_name not in self._regional_clients: - session = boto3.session.Session(region_name=region_name, botocore_session=self.config.botocore_session) + session = boto3.session.Session( + region_name=region_name, botocore_session=self.config.botocore_session + ) client = session.client("kms", config=self._user_agent_adding_config) self._register_client(client, region_name) self._regional_clients[region_name] = client @@ -192,7 +216,9 @@ def _new_master_key(self, key_id): :raises InvalidKeyIdError: if key_id is not a valid KMS CMK ID to which this key provider has access """ _key_id = to_str(key_id) # KMS client requires str, not bytes - return KMSMasterKey(config=KMSMasterKeyConfig(key_id=key_id, client=self._client(_key_id))) + return KMSMasterKey( + config=KMSMasterKeyConfig(key_id=key_id, client=self._client(_key_id)) + ) @attr.s(hash=True) @@ -206,9 +232,14 @@ class KMSMasterKeyConfig(MasterKeyConfig): """ provider_id = _PROVIDER_ID - client = attr.ib(hash=True, validator=attr.validators.instance_of(botocore.client.BaseClient)) + client = attr.ib( + hash=True, validator=attr.validators.instance_of(botocore.client.BaseClient) + ) grant_tokens = attr.ib( - hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple + hash=True, + default=attr.Factory(tuple), + validator=attr.validators.instance_of(tuple), + converter=tuple, ) @client.default @@ -244,8 +275,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument def _generate_data_key(self, algorithm, encryption_context=None): """Generates data key and returns plaintext and ciphertext of key. - :param algorithm: Algorithm on which to base data key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base data key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to pass to KMS :returns: Generated data key :rtype: aws_encryption_sdk.structures.DataKey @@ -262,7 +293,9 @@ def _generate_data_key(self, algorithm, encryption_context=None): ciphertext = response["CiphertextBlob"] key_id = response["KeyId"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to generate data key".format(key_id=self._key_id) + error_message = "Master Key {key_id} unable to generate data key".format( + key_id=self._key_id + ) _LOGGER.exception(error_message) raise GenerateKeyError(error_message) return DataKey( @@ -294,11 +327,14 @@ def _encrypt_data_key(self, data_key, algorithm, encryption_context=None): ciphertext = response["CiphertextBlob"] key_id = response["KeyId"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to encrypt data key".format(key_id=self._key_id) + error_message = "Master Key {key_id} unable to encrypt data key".format( + key_id=self._key_id + ) _LOGGER.exception(error_message) raise EncryptKeyError(error_message) return EncryptedDataKey( - key_provider=MasterKeyInfo(provider_id=self.provider_id, key_info=key_id), encrypted_data_key=ciphertext + key_provider=MasterKeyInfo(provider_id=self.provider_id, key_info=key_id), + encrypted_data_key=ciphertext, ) def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context=None): @@ -306,7 +342,7 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context=No :param data_key: Encrypted data key :type data_key: aws_encryption_sdk.structures.EncryptedDataKey - :type algorithm: `aws_encryption_sdk.identifiers.Algorithm` (not used for KMS) + :type algorithm: `aws_encryption_sdk.identifiers.AlgorithmSuite` (not used for KMS) :param dict encryption_context: Encryption context to use in decryption :returns: Decrypted data key :rtype: aws_encryption_sdk.structures.DataKey @@ -322,9 +358,13 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context=No response = self.config.client.decrypt(**kms_params) plaintext = response["Plaintext"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to decrypt data key".format(key_id=self._key_id) + error_message = "Master Key {key_id} unable to decrypt data key".format( + key_id=self._key_id + ) _LOGGER.exception(error_message) raise DecryptKeyError(error_message) return DataKey( - key_provider=self.key_provider, data_key=plaintext, encrypted_data_key=encrypted_data_key.encrypted_data_key + key_provider=self.key_provider, + data_key=plaintext, + encrypted_data_key=encrypted_data_key.encrypted_data_key, ) diff --git a/src/aws_encryption_sdk/key_providers/raw.py b/src/aws_encryption_sdk/key_providers/raw.py index 57a1d5edf..e8769457a 100644 --- a/src/aws_encryption_sdk/key_providers/raw.py +++ b/src/aws_encryption_sdk/key_providers/raw.py @@ -22,7 +22,12 @@ import aws_encryption_sdk.internal.formatting.serialize from aws_encryption_sdk.identifiers import EncryptionType from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey -from aws_encryption_sdk.key_providers.base import MasterKey, MasterKeyConfig, MasterKeyProvider, MasterKeyProviderConfig +from aws_encryption_sdk.key_providers.base import ( + MasterKey, + MasterKeyConfig, + MasterKeyProvider, + MasterKeyProviderConfig, +) from aws_encryption_sdk.structures import DataKey, RawDataKey _LOGGER = logging.getLogger(__name__) @@ -43,7 +48,9 @@ class RawMasterKeyConfig(MasterKeyConfig): validator=attr.validators.instance_of((six.string_types, bytes)), converter=aws_encryption_sdk.internal.str_ops.to_str, ) - wrapping_key = attr.ib(hash=True, validator=attr.validators.instance_of(WrappingKey)) + wrapping_key = attr.ib( + hash=True, validator=attr.validators.instance_of(WrappingKey) + ) class RawMasterKey(MasterKey): @@ -84,13 +91,18 @@ def owns_data_key(self, data_key): """ expected_key_info_len = -1 if ( - self.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.ASYMMETRIC + self.config.wrapping_key.wrapping_algorithm.encryption_type + is EncryptionType.ASYMMETRIC and data_key.key_provider == self.key_provider ): return True - elif self.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.SYMMETRIC: + elif ( + self.config.wrapping_key.wrapping_algorithm.encryption_type + is EncryptionType.SYMMETRIC + ): expected_key_info_len = ( - len(self._key_info_prefix) + self.config.wrapping_key.wrapping_algorithm.algorithm.iv_len + len(self._key_info_prefix) + + self.config.wrapping_key.wrapping_algorithm.algorithm.iv_len ) if ( data_key.key_provider.provider_id == self.provider_id @@ -115,15 +127,17 @@ def owns_data_key(self, data_key): def _generate_data_key(self, algorithm, encryption_context): """Generates data key and returns :class:`aws_encryption_sdk.structures.DataKey`. - :param algorithm: Algorithm on which to base data key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite on which to base data key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to use in encryption :returns: Generated data key :rtype: aws_encryption_sdk.structures.DataKey """ plaintext_data_key = os.urandom(algorithm.kdf_input_len) encrypted_data_key = self._encrypt_data_key( - data_key=RawDataKey(key_provider=self.key_provider, data_key=plaintext_data_key), + data_key=RawDataKey( + key_provider=self.key_provider, data_key=plaintext_data_key + ), algorithm=algorithm, encryption_context=encryption_context, ) @@ -139,8 +153,8 @@ def _encrypt_data_key(self, data_key, algorithm, encryption_context): :param data_key: Unencrypted data key :type data_key: :class:`aws_encryption_sdk.structures.RawDataKey` or :class:`aws_encryption_sdk.structures.DataKey` - :param algorithm: Algorithm object which directs how this Master Key will encrypt the data key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite object which directs how this Master Key will encrypt the data key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to use in encryption :returns: Decrypted data key :rtype: aws_encryption_sdk.structures.EncryptedDataKey @@ -163,8 +177,8 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): :param data_key: Encrypted data key :type data_key: aws_encryption_sdk.structures.EncryptedDataKey - :param algorithm: Algorithm object which directs how this Master Key will encrypt the data key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite object which directs how this Master Key will encrypt the data key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context to use in decryption :returns: Data key containing decrypted data key :rtype: aws_encryption_sdk.structures.DataKey @@ -178,7 +192,8 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): ) # EncryptedData to raw key string plaintext_data_key = self.config.wrapping_key.decrypt( - encrypted_wrapped_data_key=encrypted_wrapped_key, encryption_context=encryption_context + encrypted_wrapped_data_key=encrypted_wrapped_key, + encryption_context=encryption_context, ) # Raw key string to DataKey return DataKey( @@ -222,5 +237,7 @@ def _new_master_key(self, key_id): _LOGGER.debug("Retrieving wrapping key with id: %s", key_id) wrapping_key = self._get_raw_key(key_id) return self._master_key_class( - config=RawMasterKeyConfig(key_id=key_id, provider_id=self.provider_id, wrapping_key=wrapping_key) + config=RawMasterKeyConfig( + key_id=key_id, provider_id=self.provider_id, wrapping_key=wrapping_key + ) ) diff --git a/src/aws_encryption_sdk/materials_managers/__init__.py b/src/aws_encryption_sdk/materials_managers/__init__.py index bc5230c51..daad892ff 100644 --- a/src/aws_encryption_sdk/materials_managers/__init__.py +++ b/src/aws_encryption_sdk/materials_managers/__init__.py @@ -17,7 +17,7 @@ import attr import six -from ..identifiers import Algorithm +from ..identifiers import AlgorithmSuite from ..internal.utils.streams import ROStream from ..structures import DataKey @@ -35,19 +35,26 @@ class EncryptionMaterialsRequest(object): :param int frame_length: Frame length to be used while encrypting stream :param plaintext_rostream: Source plaintext read-only stream (optional) :type plaintext_rostream: aws_encryption_sdk.internal.utils.streams.ROStream - :param algorithm: Algorithm passed to underlying master key provider and master keys (optional) - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite passed to underlying master key provider and master keys (optional) + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int plaintext_length: Length of source plaintext (optional) """ encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) frame_length = attr.ib(validator=attr.validators.instance_of(six.integer_types)) plaintext_rostream = attr.ib( - default=None, validator=attr.validators.optional(attr.validators.instance_of(ROStream)) + default=None, + validator=attr.validators.optional(attr.validators.instance_of(ROStream)), + ) + algorithm = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite)), ) - algorithm = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(Algorithm))) plaintext_length = attr.ib( - default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) + default=None, + validator=attr.validators.optional( + attr.validators.instance_of(six.integer_types) + ), ) @@ -57,8 +64,8 @@ class EncryptionMaterials(object): .. versionadded:: 1.3.0 - :param algorithm: Algorithm to use for encrypting message - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encrypting message + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param data_encryption_key: Plaintext data key to use for encrypting message :type data_encryption_key: aws_encryption_sdk.structures.DataKey :param encrypted_data_keys: List of encrypted data keys @@ -67,11 +74,14 @@ class EncryptionMaterials(object): :param bytes signing_key: Encoded signing key """ - algorithm = attr.ib(validator=attr.validators.instance_of(Algorithm)) + algorithm = attr.ib(validator=attr.validators.instance_of(AlgorithmSuite)) data_encryption_key = attr.ib(validator=attr.validators.instance_of(DataKey)) encrypted_data_keys = attr.ib(validator=attr.validators.instance_of(set)) encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) - signing_key = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(bytes))) + signing_key = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(bytes)), + ) @attr.s(hash=False) @@ -80,14 +90,14 @@ class DecryptionMaterialsRequest(object): .. versionadded:: 1.3.0 - :param algorithm: Algorithm to provide to master keys for underlying decrypt requests - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to provide to master keys for underlying decrypt requests + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param encrypted_data_keys: Set of encrypted data keys :type encrypted_data_keys: set of `aws_encryption_sdk.structures.EncryptedDataKey` :param dict encryption_context: Encryption context to provide to master keys for underlying decrypt requests """ - algorithm = attr.ib(validator=attr.validators.instance_of(Algorithm)) + algorithm = attr.ib(validator=attr.validators.instance_of(AlgorithmSuite)) encrypted_data_keys = attr.ib(validator=attr.validators.instance_of(set)) encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) @@ -104,4 +114,7 @@ class DecryptionMaterials(object): """ data_key = attr.ib(validator=attr.validators.instance_of(DataKey)) - verification_key = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(bytes))) + verification_key = attr.ib( + default=None, + validator=attr.validators.optional(attr.validators.instance_of(bytes)), + ) diff --git a/src/aws_encryption_sdk/materials_managers/default.py b/src/aws_encryption_sdk/materials_managers/default.py index 6d10465a9..42d5ee7a6 100644 --- a/src/aws_encryption_sdk/materials_managers/default.py +++ b/src/aws_encryption_sdk/materials_managers/default.py @@ -39,13 +39,17 @@ class DefaultCryptoMaterialsManager(CryptoMaterialsManager): """ algorithm = ALGORITHM - master_key_provider = attr.ib(validator=attr.validators.instance_of(MasterKeyProvider)) + master_key_provider = attr.ib( + validator=attr.validators.instance_of(MasterKeyProvider) + ) - def _generate_signing_key_and_update_encryption_context(self, algorithm, encryption_context): + def _generate_signing_key_and_update_encryption_context( + self, algorithm, encryption_context + ): """Generates a signing key based on the provided algorithm. - :param algorithm: Algorithm for which to generate signing key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite for which to generate signing key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context from request :returns: Signing key bytes :rtype: bytes or None @@ -54,7 +58,9 @@ def _generate_signing_key_and_update_encryption_context(self, algorithm, encrypt if algorithm.signing_algorithm_info is None: return None - signer = Signer(algorithm=algorithm, key=generate_ecc_signing_key(algorithm=algorithm)) + signer = Signer( + algorithm=algorithm, key=generate_ecc_signing_key(algorithm=algorithm) + ) encryption_context[ENCODED_SIGNER_KEY] = to_str(signer.encoded_public_key()) return signer.key_bytes() @@ -69,10 +75,14 @@ def get_encryption_materials(self, request): :raises MasterKeyProviderError: if the primary master key provided by the underlying master key provider is not included in the full set of master keys provided by that provider """ - algorithm = request.algorithm if request.algorithm is not None else self.algorithm + algorithm = ( + request.algorithm if request.algorithm is not None else self.algorithm + ) encryption_context = request.encryption_context.copy() - signing_key = self._generate_signing_key_and_update_encryption_context(algorithm, encryption_context) + signing_key = self._generate_signing_key_and_update_encryption_context( + algorithm, encryption_context + ) primary_master_key, master_keys = self.master_key_provider.master_keys_for_encryption( encryption_context=encryption_context, @@ -80,9 +90,13 @@ def get_encryption_materials(self, request): plaintext_length=request.plaintext_length, ) if not master_keys: - raise MasterKeyProviderError("No Master Keys available from Master Key Provider") + raise MasterKeyProviderError( + "No Master Keys available from Master Key Provider" + ) if primary_master_key not in master_keys: - raise MasterKeyProviderError("Primary Master Key not in provided Master Keys") + raise MasterKeyProviderError( + "Primary Master Key not in provided Master Keys" + ) data_encryption_key, encrypted_data_keys = prepare_data_keys( primary_master_key=primary_master_key, @@ -101,11 +115,13 @@ def get_encryption_materials(self, request): signing_key=signing_key, ) - def _load_verification_key_from_encryption_context(self, algorithm, encryption_context): + def _load_verification_key_from_encryption_context( + self, algorithm, encryption_context + ): """Loads the verification key from the encryption context if used by algorithm suite. - :param algorithm: Algorithm for which to generate signing key - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite for which to generate signing key + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param dict encryption_context: Encryption context from request :returns: Raw verification key :rtype: bytes @@ -113,15 +129,24 @@ def _load_verification_key_from_encryption_context(self, algorithm, encryption_c """ encoded_verification_key = encryption_context.get(ENCODED_SIGNER_KEY, None) - if algorithm.signing_algorithm_info is not None and encoded_verification_key is None: - raise SerializationError("No signature verification key found in header for signed algorithm.") + if ( + algorithm.signing_algorithm_info is not None + and encoded_verification_key is None + ): + raise SerializationError( + "No signature verification key found in header for signed algorithm." + ) if algorithm.signing_algorithm_info is None: if encoded_verification_key is not None: - raise SerializationError("Signature verification key found in header for non-signed algorithm.") + raise SerializationError( + "Signature verification key found in header for non-signed algorithm." + ) return None - verifier = Verifier.from_encoded_point(algorithm=algorithm, encoded_point=encoded_verification_key) + verifier = Verifier.from_encoded_point( + algorithm=algorithm, encoded_point=encoded_verification_key + ) return verifier.key_bytes() def decrypt_materials(self, request): diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 90dc9d25c..0ed383768 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -29,12 +29,18 @@ NotSupportedError, SerializationError, ) -from aws_encryption_sdk.identifiers import Algorithm, ContentType +from aws_encryption_sdk.identifiers import AlgorithmSuite, ContentType from aws_encryption_sdk.internal.crypto.authentication import Signer, Verifier from aws_encryption_sdk.internal.crypto.data_keys import derive_data_encryption_key from aws_encryption_sdk.internal.crypto.encryption import Decryptor, Encryptor, decrypt from aws_encryption_sdk.internal.crypto.iv import non_framed_body_iv -from aws_encryption_sdk.internal.defaults import FRAME_LENGTH, LINE_LENGTH, MAX_NON_FRAMED_SIZE, TYPE, VERSION +from aws_encryption_sdk.internal.defaults import ( + FRAME_LENGTH, + LINE_LENGTH, + MAX_NON_FRAMED_SIZE, + TYPE, + VERSION, +) from aws_encryption_sdk.internal.formatting.deserialize import ( deserialize_footer, deserialize_frame, @@ -44,7 +50,9 @@ deserialize_tag, validate_header, ) -from aws_encryption_sdk.internal.formatting.encryption_context import assemble_content_aad +from aws_encryption_sdk.internal.formatting.encryption_context import ( + assemble_content_aad, +) from aws_encryption_sdk.internal.formatting.serialize import ( serialize_footer, serialize_frame, @@ -54,7 +62,10 @@ serialize_non_framed_open, ) from aws_encryption_sdk.key_providers.base import MasterKeyProvider -from aws_encryption_sdk.materials_managers import DecryptionMaterialsRequest, EncryptionMaterialsRequest +from aws_encryption_sdk.materials_managers import ( + DecryptionMaterialsRequest, + EncryptionMaterialsRequest, +) from aws_encryption_sdk.materials_managers.base import CryptoMaterialsManager from aws_encryption_sdk.materials_managers.default import DefaultCryptoMaterialsManager from aws_encryption_sdk.structures import MessageHeader @@ -82,29 +93,53 @@ class _ClientConfig(object): will attempt to seek() to the end of the stream and tell() to find the length of source data. """ - source = attr.ib(hash=True, converter=aws_encryption_sdk.internal.utils.prep_stream_data) + source = attr.ib( + hash=True, converter=aws_encryption_sdk.internal.utils.prep_stream_data + ) materials_manager = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(CryptoMaterialsManager)) + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of(CryptoMaterialsManager) + ), ) key_provider = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(MasterKeyProvider)) + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of(MasterKeyProvider) + ), ) source_length = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of(six.integer_types) + ), ) line_length = attr.ib( - hash=True, default=LINE_LENGTH, validator=attr.validators.instance_of(six.integer_types) + hash=True, + default=LINE_LENGTH, + validator=attr.validators.instance_of(six.integer_types), ) # DEPRECATED: Value is no longer configurable here. Parameter left here to avoid breaking consumers. def __attrs_post_init__(self): """Normalize inputs to crypto material manager.""" - both_cmm_and_mkp_defined = self.materials_manager is not None and self.key_provider is not None - neither_cmm_nor_mkp_defined = self.materials_manager is None and self.key_provider is None + both_cmm_and_mkp_defined = ( + self.materials_manager is not None and self.key_provider is not None + ) + neither_cmm_nor_mkp_defined = ( + self.materials_manager is None and self.key_provider is None + ) if both_cmm_and_mkp_defined or neither_cmm_nor_mkp_defined: - raise TypeError("Exactly one of materials_manager or key_provider must be provided") + raise TypeError( + "Exactly one of materials_manager or key_provider must be provided" + ) if self.materials_manager is None: - self.materials_manager = DefaultCryptoMaterialsManager(master_key_provider=self.key_provider) + self.materials_manager = DefaultCryptoMaterialsManager( + master_key_provider=self.key_provider + ) class _EncryptionStream(io.IOBase): @@ -157,15 +192,21 @@ def __new__(cls, **kwargs): instance = super(_EncryptionStream, cls).__new__(cls) config = kwargs.pop("config", None) - if not isinstance(config, instance._config_class): # pylint: disable=protected-access - config = instance._config_class(**kwargs) # pylint: disable=protected-access + if not isinstance( + config, instance._config_class + ): # pylint: disable=protected-access + config = instance._config_class( + **kwargs + ) # pylint: disable=protected-access instance.config = config instance.bytes_read = 0 instance.output_buffer = b"" instance._message_prepped = False # pylint: disable=protected-access instance.source_stream = instance.config.source - instance._stream_length = instance.config.source_length # pylint: disable=protected-access + instance._stream_length = ( + instance.config.source_length + ) # pylint: disable=protected-access return instance @@ -333,8 +374,8 @@ class EncryptorConfig(_ClientConfig): this is not enforced if a `key_provider` is provided. :param dict encryption_context: Dictionary defining encryption context - :param algorithm: Algorithm to use for encryption (optional) - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption (optional) + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int frame_length: Frame length in bytes (optional) """ @@ -344,12 +385,20 @@ class EncryptorConfig(_ClientConfig): validator=attr.validators.instance_of(dict), ) algorithm = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(Algorithm)) + hash=True, + default=None, + validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite)), + ) + frame_length = attr.ib( + hash=True, + default=FRAME_LENGTH, + validator=attr.validators.instance_of(six.integer_types), ) - frame_length = attr.ib(hash=True, default=FRAME_LENGTH, validator=attr.validators.instance_of(six.integer_types)) -class StreamEncryptor(_EncryptionStream): # pylint: disable=too-many-instance-attributes +class StreamEncryptor( + _EncryptionStream +): # pylint: disable=too-many-instance-attributes """Provides a streaming encryptor for encrypting a stream source. Behaves as a standard file-like object. @@ -384,22 +433,27 @@ class StreamEncryptor(_EncryptionStream): # pylint: disable=too-many-instance-a this is not enforced if a `key_provider` is provided. :param dict encryption_context: Dictionary defining encryption context - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param int frame_length: Frame length in bytes """ _config_class = EncryptorConfig - def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called + def __init__( + self, **kwargs + ): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.sequence_number = 1 - self.content_type = aws_encryption_sdk.internal.utils.content_type(self.config.frame_length) + self.content_type = aws_encryption_sdk.internal.utils.content_type( + self.config.frame_length + ) self._bytes_encrypted = 0 if self.config.frame_length == 0 and ( - self.config.source_length is not None and self.config.source_length > MAX_NON_FRAMED_SIZE + self.config.source_length is not None + and self.config.source_length > MAX_NON_FRAMED_SIZE ): raise SerializationError("Source too large for non-framed message") @@ -431,31 +485,41 @@ def _prep_message(self): algorithm=self.config.algorithm, encryption_context=self.config.encryption_context.copy(), frame_length=self.config.frame_length, - plaintext_rostream=aws_encryption_sdk.internal.utils.streams.ROStream(self.source_stream), + plaintext_rostream=aws_encryption_sdk.internal.utils.streams.ROStream( + self.source_stream + ), plaintext_length=plaintext_length, ) self._encryption_materials = self.config.materials_manager.get_encryption_materials( request=encryption_materials_request ) - if self.config.algorithm is not None and self._encryption_materials.algorithm != self.config.algorithm: + if ( + self.config.algorithm is not None + and self._encryption_materials.algorithm != self.config.algorithm + ): raise ActionNotAllowedError( ( "Cryptographic materials manager provided algorithm suite" " differs from algorithm suite in request.\n" "Required: {requested}\n" "Provided: {provided}" - ).format(requested=self.config.algorithm, provided=self._encryption_materials.algorithm) + ).format( + requested=self.config.algorithm, + provided=self._encryption_materials.algorithm, + ) ) if self._encryption_materials.signing_key is None: self.signer = None else: self.signer = Signer.from_key_bytes( - algorithm=self._encryption_materials.algorithm, key_bytes=self._encryption_materials.signing_key + algorithm=self._encryption_materials.algorithm, + key_bytes=self._encryption_materials.signing_key, ) aws_encryption_sdk.internal.utils.validate_frame_length( - frame_length=self.config.frame_length, algorithm=self._encryption_materials.algorithm + frame_length=self.config.frame_length, + algorithm=self._encryption_materials.algorithm, ) self._derived_data_key = derive_data_encryption_key( @@ -545,14 +609,20 @@ def _read_bytes_to_non_framed_body(self, b): self.signer.update(ciphertext) if len(plaintext) < b: - _LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b) + _LOGGER.debug( + "Closing encryptor after receiving only %d bytes of %d bytes requested", + plaintext_length, + b, + ) closing = self.encryptor.finalize() if self.signer is not None: self.signer.update(closing) - closing += serialize_non_framed_close(tag=self.encryptor.tag, signer=self.signer) + closing += serialize_non_framed_close( + tag=self.encryptor.tag, signer=self.signer + ) if self.signer is not None: closing += serialize_footer(self.signer) @@ -574,7 +644,11 @@ def _read_bytes_to_framed_body(self, b): if b > 0: _frames_to_read = math.ceil(b / float(self.config.frame_length)) b = int(_frames_to_read * self.config.frame_length) - _LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b) + _LOGGER.debug( + "%d bytes requested; reading %d bytes after normalizing to frame length", + _b, + b, + ) plaintext = self.source_stream.read(b) plaintext_length = len(plaintext) @@ -596,7 +670,9 @@ def _read_bytes_to_framed_body(self, b): or (finalize and not final_frame_written) ): current_plaintext_length = len(plaintext) - is_final_frame = finalize and current_plaintext_length < self.config.frame_length + is_final_frame = ( + finalize and current_plaintext_length < self.config.frame_length + ) bytes_in_frame = min(current_plaintext_length, self.config.frame_length) _LOGGER.debug( "Writing %d bytes into%s frame %d", @@ -632,7 +708,9 @@ def _read_bytes(self, b): :param int b: Number of bytes to read :raises NotSupportedError: if content type is not supported """ - _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type) + _LOGGER.debug( + "%d bytes requested from stream with content type: %s", b, self.content_type + ) if 0 <= b <= len(self.output_buffer) or self.__message_complete: _LOGGER.debug("No need to read from source stream or source stream closed") return @@ -653,7 +731,8 @@ def _read_bytes(self, b): if self._bytes_encrypted > self.config.source_length: raise CustomMaximumValueExceeded( "Bytes encrypted has exceeded stated source length estimate:\n{actual:d} > {estimated:d}".format( - actual=self._bytes_encrypted, estimated=self.config.source_length + actual=self._bytes_encrypted, + estimated=self.config.source_length, ) ) @@ -686,11 +765,17 @@ class DecryptorConfig(_ClientConfig): """ max_body_length = attr.ib( - hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) + hash=True, + default=None, + validator=attr.validators.optional( + attr.validators.instance_of(six.integer_types) + ), ) -class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-attributes +class StreamDecryptor( + _EncryptionStream +): # pylint: disable=too-many-instance-attributes """Provides a streaming encryptor for encrypting a stream source. Behaves as a standard file-like object. @@ -723,7 +808,9 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a _config_class = DecryptorConfig - def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called + def __init__( + self, **kwargs + ): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.last_sequence_number = 0 self.__unframed_bytes_read = 0 @@ -762,23 +849,35 @@ def _read_header(self): algorithm=header.algorithm, encryption_context=header.encryption_context, ) - decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request) + decryption_materials = self.config.materials_manager.decrypt_materials( + request=decrypt_materials_request + ) if decryption_materials.verification_key is None: self.verifier = None else: self.verifier = Verifier.from_key_bytes( - algorithm=header.algorithm, key_bytes=decryption_materials.verification_key + algorithm=header.algorithm, + key_bytes=decryption_materials.verification_key, ) if self.verifier is not None: self.verifier.update(raw_header) header_auth = deserialize_header_auth( - stream=self.source_stream, algorithm=header.algorithm, verifier=self.verifier + stream=self.source_stream, + algorithm=header.algorithm, + verifier=self.verifier, ) self._derived_data_key = derive_data_encryption_key( - source_key=decryption_materials.data_key.data_key, algorithm=header.algorithm, message_id=header.message_id + source_key=decryption_materials.data_key.data_key, + algorithm=header.algorithm, + message_id=header.message_id, + ) + validate_header( + header=header, + header_auth=header_auth, + raw_header=raw_header, + data_key=self._derived_data_key, ) - validate_header(header=header, header_auth=header_auth, raw_header=raw_header, data_key=self._derived_data_key) return header, header_auth def _prep_non_framed(self): @@ -787,7 +886,10 @@ def _prep_non_framed(self): stream=self.source_stream, header=self._header, verifier=self.verifier ) - if self.config.max_body_length is not None and self.body_length > self.config.max_body_length: + if ( + self.config.max_body_length is not None + and self.body_length > self.config.max_body_length + ): raise CustomMaximumValueExceeded( "Non-framed message content length found larger than custom value: {found:d} > {custom:d}".format( found=self.body_length, custom=self.config.max_body_length @@ -814,12 +916,16 @@ def _read_bytes_from_non_framed_body(self, b): ciphertext = self.source_stream.read(bytes_to_read) if len(self.output_buffer) + len(ciphertext) < self.body_length: - raise SerializationError("Total message body contents less than specified in body description") + raise SerializationError( + "Total message body contents less than specified in body description" + ) if self.verifier is not None: self.verifier.update(ciphertext) - tag = deserialize_tag(stream=self.source_stream, header=self._header, verifier=self.verifier) + tag = deserialize_tag( + stream=self.source_stream, header=self._header, verifier=self.verifier + ) aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( content_type=self._header.content_type, is_final_frame=True @@ -841,7 +947,9 @@ def _read_bytes_from_non_framed_body(self, b): plaintext = self.decryptor.update(ciphertext) plaintext += self.decryptor.finalize() - self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) + self.footer = deserialize_footer( + stream=self.source_stream, verifier=self.verifier + ) return plaintext def _read_bytes_from_framed_body(self, b): @@ -864,7 +972,8 @@ def _read_bytes_from_framed_body(self, b): raise SerializationError("Malformed message: frames out of order") self.last_sequence_number += 1 aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( - content_type=self._header.content_type, is_final_frame=frame_data.final_frame + content_type=self._header.content_type, + is_final_frame=frame_data.final_frame, ) associated_data = assemble_content_aad( message_id=self._header.message_id, @@ -882,7 +991,9 @@ def _read_bytes_from_framed_body(self, b): _LOGGER.debug("bytes collected: %d", plaintext_length) if final_frame: _LOGGER.debug("Reading footer") - self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) + self.footer = deserialize_footer( + stream=self.source_stream, verifier=self.verifier + ) return plaintext @@ -898,7 +1009,11 @@ def _read_bytes(self, b): buffer_length = len(self.output_buffer) if 0 <= b <= buffer_length: - _LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length) + _LOGGER.debug( + "%d bytes requested less than or equal to current output buffer size %d", + b, + buffer_length, + ) return if self._header.content_type == ContentType.FRAMED_DATA: diff --git a/src/aws_encryption_sdk/structures.py b/src/aws_encryption_sdk/structures.py index 8229d65fb..f26ad70ee 100644 --- a/src/aws_encryption_sdk/structures.py +++ b/src/aws_encryption_sdk/structures.py @@ -26,8 +26,8 @@ class MessageHeader(object): :type version: aws_encryption_sdk.identifiers.SerializationVersion :param type: Message content type, per spec :type type: aws_encryption_sdk.identifiers.ObjectType - :param algorithm: Algorithm to use for encryption - :type algorithm: aws_encryption_sdk.identifiers.Algorithm + :param algorithm: Algorithm suite to use for encryption + :type algorithm: aws_encryption_sdk.identifiers.AlgorithmSuite :param bytes message_id: Message ID :param dict encryption_context: Dictionary defining encryption context :param encrypted_data_keys: Encrypted data keys @@ -40,17 +40,41 @@ class MessageHeader(object): """ version = attr.ib( - hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.SerializationVersion) + hash=True, + validator=attr.validators.instance_of( + aws_encryption_sdk.identifiers.SerializationVersion + ), + ) + type = attr.ib( + hash=True, + validator=attr.validators.instance_of( + aws_encryption_sdk.identifiers.ObjectType + ), + ) + algorithm = attr.ib( + hash=True, + validator=attr.validators.instance_of( + aws_encryption_sdk.identifiers.AlgorithmSuite + ), ) - type = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.ObjectType)) - algorithm = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.Algorithm)) message_id = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) encryption_context = attr.ib(hash=True, validator=attr.validators.instance_of(dict)) encrypted_data_keys = attr.ib(hash=True, validator=attr.validators.instance_of(set)) - content_type = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.ContentType)) - content_aad_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) - header_iv_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) - frame_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) + content_type = attr.ib( + hash=True, + validator=attr.validators.instance_of( + aws_encryption_sdk.identifiers.ContentType + ), + ) + content_aad_length = attr.ib( + hash=True, validator=attr.validators.instance_of(six.integer_types) + ) + header_iv_length = attr.ib( + hash=True, validator=attr.validators.instance_of(six.integer_types) + ) + frame_length = attr.ib( + hash=True, validator=attr.validators.instance_of(six.integer_types) + ) @attr.s(hash=True) @@ -61,8 +85,16 @@ class MasterKeyInfo(object): :param bytes key_info: MasterKey key_info value """ - provider_id = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_str) - key_info = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_bytes) + provider_id = attr.ib( + hash=True, + validator=attr.validators.instance_of((six.string_types, bytes)), + converter=to_str, + ) + key_info = attr.ib( + hash=True, + validator=attr.validators.instance_of((six.string_types, bytes)), + converter=to_bytes, + ) @attr.s(hash=True) @@ -74,8 +106,12 @@ class RawDataKey(object): :param bytes data_key: Plaintext data key """ - key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) - data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes)) + key_provider = attr.ib( + hash=True, validator=attr.validators.instance_of(MasterKeyInfo) + ) + data_key = attr.ib( + hash=True, repr=False, validator=attr.validators.instance_of(bytes) + ) @attr.s(hash=True) @@ -88,9 +124,15 @@ class DataKey(object): :param bytes encrypted_data_key: Encrypted data key """ - key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) - data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes)) - encrypted_data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) + key_provider = attr.ib( + hash=True, validator=attr.validators.instance_of(MasterKeyInfo) + ) + data_key = attr.ib( + hash=True, repr=False, validator=attr.validators.instance_of(bytes) + ) + encrypted_data_key = attr.ib( + hash=True, validator=attr.validators.instance_of(bytes) + ) @attr.s(hash=True) @@ -102,5 +144,9 @@ class EncryptedDataKey(object): :param bytes encrypted_data_key: Encrypted data key """ - key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) - encrypted_data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) + key_provider = attr.ib( + hash=True, validator=attr.validators.instance_of(MasterKeyInfo) + ) + encrypted_data_key = attr.ib( + hash=True, validator=attr.validators.instance_of(bytes) + ) diff --git a/test/integration/test_i_aws_encrytion_sdk_client.py b/test/integration/test_i_aws_encrytion_sdk_client.py index 18d541ddf..2247c3041 100644 --- a/test/integration/test_i_aws_encrytion_sdk_client.py +++ b/test/integration/test_i_aws_encrytion_sdk_client.py @@ -67,11 +67,18 @@ def test_encrypt_verify_user_agent_kms_master_key(caplog): def test_remove_bad_client(): test = KMSMasterKeyProvider() test.add_regional_client("us-fakey-12") - with pytest.raises(BotoCoreError): test._regional_clients["us-fakey-12"].list_keys() - assert not test._regional_clients + # I believe that because KMSMasterKeyProvider() sets a default regional client + # we want to test that the fake key was properly removed, instead of the dict (of regional clients) + # being empty. That is to say, after the first line of this test function + # the dict is NOT EMPTY, and this default first value will stay with us, so + # if we test for emptiness of the dict then we will get a non-passing test, when really + # it might be passing. The old line is commented out in case it matters later. + + # assert not test._regional_clients + assert "us-fakey-12" not in test._regional_clients class TestKMSThickClientIntegration(object): diff --git a/test/unit/test_crypto_elliptic_curve.py b/test/unit/test_crypto_elliptic_curve.py index b030db5c2..fadf9c35d 100644 --- a/test/unit/test_crypto_elliptic_curve.py +++ b/test/unit/test_crypto_elliptic_curve.py @@ -38,7 +38,9 @@ @pytest.yield_fixture def patch_default_backend(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "default_backend") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, "default_backend" + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve.default_backend @@ -56,31 +58,42 @@ def patch_pow(mocker): @pytest.yield_fixture def patch_encode_dss_signature(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "encode_dss_signature") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, "encode_dss_signature" + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve.encode_dss_signature @pytest.yield_fixture def patch_decode_dss_signature(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "decode_dss_signature") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, "decode_dss_signature" + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve.decode_dss_signature @pytest.yield_fixture def patch_ecc_decode_compressed_point(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "_ecc_decode_compressed_point") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, + "_ecc_decode_compressed_point", + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve._ecc_decode_compressed_point @pytest.yield_fixture def patch_verify_interface(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "verify_interface") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, "verify_interface" + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve.verify_interface @pytest.yield_fixture def patch_ecc_curve_parameters(mocker): - mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "_ECC_CURVE_PARAMETERS") + mocker.patch.object( + aws_encryption_sdk.internal.crypto.elliptic_curve, "_ECC_CURVE_PARAMETERS" + ) yield aws_encryption_sdk.internal.crypto.elliptic_curve._ECC_CURVE_PARAMETERS @@ -102,9 +115,45 @@ def test_ecc_curve_not_in_cryptography(): def test_ecc_curve_parameters_secp256r1(): """Verify values from http://www.secg.org/sec2-v2.pdf""" p = pow(2, 224) * (pow(2, 32) - 1) + pow(2, 192) + pow(2, 96) - 1 - a = int(("FFFFFFFF" "00000001" "00000000" "00000000" "00000000" "FFFFFFFF" "FFFFFFFF" "FFFFFFFC"), 16) - b = int(("5AC635D8" "AA3A93E7" "B3EBBD55" "769886BC" "651D06B0" "CC53B0F6" "3BCE3C3E" "27D2604B"), 16) - order = int(("FFFFFFFF" "00000000" "FFFFFFFF" "FFFFFFFF" "BCE6FAAD" "A7179E84" "F3B9CAC2" "FC632551"), 16) + a = int( + ( + "FFFFFFFF" + "00000001" + "00000000" + "00000000" + "00000000" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFC" + ), + 16, + ) + b = int( + ( + "5AC635D8" + "AA3A93E7" + "B3EBBD55" + "769886BC" + "651D06B0" + "CC53B0F6" + "3BCE3C3E" + "27D2604B" + ), + 16, + ) + order = int( + ( + "FFFFFFFF" + "00000000" + "FFFFFFFF" + "FFFFFFFF" + "BCE6FAAD" + "A7179E84" + "F3B9CAC2" + "FC632551" + ), + 16, + ) assert _ECC_CURVE_PARAMETERS["secp256r1"].p == p assert _ECC_CURVE_PARAMETERS["secp256r1"].a == a assert _ECC_CURVE_PARAMETERS["secp256r1"].b == b @@ -247,22 +296,34 @@ def test_ecc_curve_parameters_secp521r1(): def test_ecc_static_length_signature_first_try( - patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed + patch_default_backend, + patch_ec, + patch_encode_dss_signature, + patch_decode_dss_signature, + patch_prehashed, ): algorithm = MagicMock(signature_len=55) private_key = MagicMock() private_key.sign.return_value = b"a" * 55 - test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) + test_signature = _ecc_static_length_signature( + key=private_key, algorithm=algorithm, digest=sentinel.digest + ) patch_prehashed.assert_called_once_with(algorithm.signing_hash_type.return_value) patch_ec.ECDSA.assert_called_once_with(patch_prehashed.return_value) - private_key.sign.assert_called_once_with(sentinel.digest, patch_ec.ECDSA.return_value) + private_key.sign.assert_called_once_with( + sentinel.digest, patch_ec.ECDSA.return_value + ) assert not patch_encode_dss_signature.called assert not patch_decode_dss_signature.called assert test_signature is private_key.sign.return_value def test_ecc_static_length_signature_single_negation( - patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed + patch_default_backend, + patch_ec, + patch_encode_dss_signature, + patch_decode_dss_signature, + patch_prehashed, ): algorithm = MagicMock(signature_len=55) algorithm.signing_algorithm_info.name = "secp256r1" @@ -270,15 +331,23 @@ def test_ecc_static_length_signature_single_negation( private_key.sign.return_value = b"a" patch_decode_dss_signature.return_value = sentinel.r, 100 patch_encode_dss_signature.return_value = "a" * 55 - test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) + test_signature = _ecc_static_length_signature( + key=private_key, algorithm=algorithm, digest=sentinel.digest + ) assert len(private_key.sign.mock_calls) == 1 patch_decode_dss_signature.assert_called_once_with(b"a") - patch_encode_dss_signature.assert_called_once_with(sentinel.r, _ECC_CURVE_PARAMETERS["secp256r1"].order - 100) + patch_encode_dss_signature.assert_called_once_with( + sentinel.r, _ECC_CURVE_PARAMETERS["secp256r1"].order - 100 + ) assert test_signature is patch_encode_dss_signature.return_value def test_ecc_static_length_signature_recalculate( - patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed + patch_default_backend, + patch_ec, + patch_encode_dss_signature, + patch_decode_dss_signature, + patch_prehashed, ): algorithm = MagicMock(signature_len=55) algorithm.signing_algorithm_info.name = "secp256r1" @@ -286,7 +355,9 @@ def test_ecc_static_length_signature_recalculate( private_key.sign.side_effect = (b"a", b"b" * 55) patch_decode_dss_signature.return_value = sentinel.r, 100 patch_encode_dss_signature.return_value = "a" * 100 - test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) + test_signature = _ecc_static_length_signature( + key=private_key, algorithm=algorithm, digest=sentinel.digest + ) assert len(private_key.sign.mock_calls) == 2 assert len(patch_decode_dss_signature.mock_calls) == 1 assert len(patch_encode_dss_signature.mock_calls) == 1 @@ -294,7 +365,9 @@ def test_ecc_static_length_signature_recalculate( def test_ecc_encode_compressed_point_prime(): - compressed_point = _ecc_encode_compressed_point(private_key=VALUES["ecc_private_key_prime"]) + compressed_point = _ecc_encode_compressed_point( + private_key=VALUES["ecc_private_key_prime"] + ) assert compressed_point == VALUES["ecc_compressed_point"] @@ -313,82 +386,110 @@ def test_ecc_decode_compressed_point_infinity(): def test_ecc_decode_compressed_point_prime(): - x, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) + x, y = _ecc_decode_compressed_point( + curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] + ) numbers = VALUES["ecc_private_key_prime"].public_key().public_numbers() assert x == numbers.x assert y == numbers.y @pytest.mark.skipif( - sys.version_info.major == 3 and sys.version_info.minor == 4, reason='Patching builtin "pow" fails in Python3.4' + sys.version_info.major == 3 and sys.version_info.minor == 4, + reason='Patching builtin "pow" fails in Python3.4', ) def test_ecc_decode_compressed_point_prime_characteristic_two(patch_pow): patch_pow.return_value = 1 - _, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) + _, y = _ecc_decode_compressed_point( + curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] + ) assert y == 1 @pytest.mark.skipif( - sys.version_info.major == 3 and sys.version_info.minor == 4, reason='Patching builtin "pow" fails in Python3.4' + sys.version_info.major == 3 and sys.version_info.minor == 4, + reason='Patching builtin "pow" fails in Python3.4', ) def test_ecc_decode_compressed_point_prime_not_characteristic_two(patch_pow): patch_pow.return_value = 0 - _, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) + _, y = _ecc_decode_compressed_point( + curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] + ) assert y == _ECC_CURVE_PARAMETERS["secp384r1"].p def test_ecc_decode_compressed_point_prime_unsupported(): with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point(curve=ec.SECP192R1(), compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd") + _ecc_decode_compressed_point( + curve=ec.SECP192R1(), + compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd", + ) excinfo.match(r"Curve secp192r1 is not supported at this time") def test_ecc_decode_compressed_point_prime_complex(patch_ecc_curve_parameters): - patch_ecc_curve_parameters.__getitem__.return_value = _ECCCurveParameters(p=5, a=5, b=5, order=5) + patch_ecc_curve_parameters.__getitem__.return_value = _ECCCurveParameters( + p=5, a=5, b=5, order=5 + ) mock_curve = MagicMock() mock_curve.name = "secp_mock_curve" with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point(curve=mock_curve, compressed_point=VALUES["ecc_compressed_point"]) + _ecc_decode_compressed_point( + curve=mock_curve, compressed_point=VALUES["ecc_compressed_point"] + ) excinfo.match(r"S not 1 :: Curve not supported at this time") def test_ecc_decode_compressed_point_nonprime_characteristic_two(): with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point(curve=ec.SECT409K1(), compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd") + _ecc_decode_compressed_point( + curve=ec.SECT409K1(), + compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd", + ) excinfo.match(r"Non-prime curves are not supported at this time") -def test_ecc_public_numbers_from_compressed_point(patch_ec, patch_ecc_decode_compressed_point): +def test_ecc_public_numbers_from_compressed_point( + patch_ec, patch_ecc_decode_compressed_point +): patch_ecc_decode_compressed_point.return_value = sentinel.x, sentinel.y patch_ec.EllipticCurvePublicNumbers.return_value = sentinel.public_numbers_instance test = _ecc_public_numbers_from_compressed_point( curve=sentinel.curve_instance, compressed_point=sentinel.compressed_point ) - patch_ecc_decode_compressed_point.assert_called_once_with(sentinel.curve_instance, sentinel.compressed_point) + patch_ecc_decode_compressed_point.assert_called_once_with( + sentinel.curve_instance, sentinel.compressed_point + ) patch_ec.EllipticCurvePublicNumbers.assert_called_once_with( x=sentinel.x, y=sentinel.y, curve=sentinel.curve_instance ) assert test == sentinel.public_numbers_instance -def test_generate_ecc_signing_key_supported(patch_default_backend, patch_ec, patch_verify_interface): +def test_generate_ecc_signing_key_supported( + patch_default_backend, patch_ec, patch_verify_interface +): patch_ec.generate_private_key.return_value = sentinel.raw_signing_key mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info) mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) test_signing_key = generate_ecc_signing_key(algorithm=mock_algorithm) - patch_verify_interface.assert_called_once_with(patch_ec.EllipticCurve, mock_algorithm_info) + patch_verify_interface.assert_called_once_with( + patch_ec.EllipticCurve, mock_algorithm_info + ) patch_ec.generate_private_key.assert_called_once_with( curve=sentinel.algorithm_info, backend=patch_default_backend.return_value ) assert test_signing_key is sentinel.raw_signing_key -def test_generate_ecc_signing_key_unsupported(patch_default_backend, patch_ec, patch_verify_interface): +def test_generate_ecc_signing_key_unsupported( + patch_default_backend, patch_ec, patch_verify_interface +): patch_verify_interface.side_effect = InterfaceNotImplemented mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info) mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) From c62ed6824e00798b8971f7c2ca27074df3d9b91a Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Tue, 6 Aug 2019 17:43:04 -0700 Subject: [PATCH 3/6] Formatted better to pass linters. --- src/aws_encryption_sdk/__init__.py | 20 +- src/aws_encryption_sdk/identifiers.py | 64 +---- .../internal/crypto/authentication.py | 23 +- .../internal/crypto/elliptic_curve.py | 29 +-- .../internal/crypto/encryption.py | 12 +- src/aws_encryption_sdk/internal/defaults.py | 4 +- .../internal/formatting/serialize.py | 61 +---- .../internal/utils/__init__.py | 34 +-- src/aws_encryption_sdk/key_providers/base.py | 103 ++------ src/aws_encryption_sdk/key_providers/kms.py | 76 ++---- src/aws_encryption_sdk/key_providers/raw.py | 33 +-- .../materials_managers/__init__.py | 23 +- .../materials_managers/default.py | 49 +--- src/aws_encryption_sdk/streaming_client.py | 209 ++++----------- src/aws_encryption_sdk/structures.py | 78 ++---- .../test_f_aws_encryption_sdk_client.py | 239 ++++-------------- test/functional/test_f_crypto.py | 39 +-- test/functional/test_f_xcompat.py | 50 +--- .../test_i_aws_encrytion_sdk_client.py | 134 +++------- test/integration/test_i_xcompat_kms.py | 8 +- test/unit/test_crypto_elliptic_curve.py | 165 +++--------- 21 files changed, 323 insertions(+), 1130 deletions(-) diff --git a/src/aws_encryption_sdk/__init__.py b/src/aws_encryption_sdk/__init__.py index 85e02884b..3aae2c9c7 100644 --- a/src/aws_encryption_sdk/__init__.py +++ b/src/aws_encryption_sdk/__init__.py @@ -15,16 +15,9 @@ from aws_encryption_sdk.caches.local import LocalCryptoMaterialsCache # noqa from aws_encryption_sdk.caches.null import NullCryptoMaterialsCache # noqa from aws_encryption_sdk.identifiers import AlgorithmSuite, __version__ # noqa -from aws_encryption_sdk.key_providers.kms import ( - KMSMasterKeyProvider, - KMSMasterKeyProviderConfig, -) # noqa -from aws_encryption_sdk.materials_managers.caching import ( - CachingCryptoMaterialsManager, -) # noqa -from aws_encryption_sdk.materials_managers.default import ( - DefaultCryptoMaterialsManager, -) # noqa +from aws_encryption_sdk.key_providers.kms import KMSMasterKeyProvider, KMSMasterKeyProviderConfig # noqa +from aws_encryption_sdk.materials_managers.caching import CachingCryptoMaterialsManager # noqa +from aws_encryption_sdk.materials_managers.default import DefaultCryptoMaterialsManager # noqa from aws_encryption_sdk.streaming_client import ( # noqa DecryptorConfig, EncryptorConfig, @@ -184,12 +177,7 @@ def stream(**kwargs): :raises ValueError: if supplied with an unsupported mode value """ mode = kwargs.pop("mode") - _stream_map = { - "e": StreamEncryptor, - "encrypt": StreamEncryptor, - "d": StreamDecryptor, - "decrypt": StreamDecryptor, - } + _stream_map = {"e": StreamEncryptor, "encrypt": StreamEncryptor, "d": StreamDecryptor, "decrypt": StreamDecryptor} try: return _stream_map[mode.lower()](**kwargs) except KeyError: diff --git a/src/aws_encryption_sdk/identifiers.py b/src/aws_encryption_sdk/identifiers.py index 663884227..b8ec35910 100644 --- a/src/aws_encryption_sdk/identifiers.py +++ b/src/aws_encryption_sdk/identifiers.py @@ -50,15 +50,7 @@ class EncryptionSuite(Enum): AES_192_GCM_IV12_TAG16 = (algorithms.AES, modes.GCM, 24, 12, 16) AES_256_GCM_IV12_TAG16 = (algorithms.AES, modes.GCM, 32, 12, 16) - def __init__( - self, - algorithm, - mode, - data_key_length, - iv_length, - auth_length, - auth_key_length=0, - ): + def __init__(self, algorithm, mode, data_key_length, iv_length, auth_length, auth_key_length=0): """Prepare a new EncryptionSuite.""" self.algorithm = algorithm self.mode = mode @@ -165,21 +157,9 @@ class AlgorithmSuite(Enum): # pylint: disable=too-many-instance-attributes AES_128_GCM_IV12_TAG16 = (0x0014, EncryptionSuite.AES_128_GCM_IV12_TAG16) AES_192_GCM_IV12_TAG16 = (0x0046, EncryptionSuite.AES_192_GCM_IV12_TAG16) AES_256_GCM_IV12_TAG16 = (0x0078, EncryptionSuite.AES_256_GCM_IV12_TAG16) - AES_128_GCM_IV12_TAG16_HKDF_SHA256 = ( - 0x0114, - EncryptionSuite.AES_128_GCM_IV12_TAG16, - KDFSuite.HKDF_SHA256, - ) - AES_192_GCM_IV12_TAG16_HKDF_SHA256 = ( - 0x0146, - EncryptionSuite.AES_192_GCM_IV12_TAG16, - KDFSuite.HKDF_SHA256, - ) - AES_256_GCM_IV12_TAG16_HKDF_SHA256 = ( - 0x0178, - EncryptionSuite.AES_256_GCM_IV12_TAG16, - KDFSuite.HKDF_SHA256, - ) + AES_128_GCM_IV12_TAG16_HKDF_SHA256 = (0x0114, EncryptionSuite.AES_128_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) + AES_192_GCM_IV12_TAG16_HKDF_SHA256 = (0x0146, EncryptionSuite.AES_192_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) + AES_256_GCM_IV12_TAG16_HKDF_SHA256 = (0x0178, EncryptionSuite.AES_256_GCM_IV12_TAG16, KDFSuite.HKDF_SHA256) AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 = ( 0x0214, EncryptionSuite.AES_128_GCM_IV12_TAG16, @@ -315,38 +295,12 @@ class WrappingAlgorithm(Enum): None, ) RSA_PKCS1 = (EncryptionType.ASYMMETRIC, rsa, padding.PKCS1v15, None, None) - RSA_OAEP_SHA1_MGF1 = ( - EncryptionType.ASYMMETRIC, - rsa, - padding.OAEP, - hashes.SHA1, - padding.MGF1, - ) - RSA_OAEP_SHA256_MGF1 = ( - EncryptionType.ASYMMETRIC, - rsa, - padding.OAEP, - hashes.SHA256, - padding.MGF1, - ) - RSA_OAEP_SHA384_MGF1 = ( - EncryptionType.ASYMMETRIC, - rsa, - padding.OAEP, - hashes.SHA384, - padding.MGF1, - ) - RSA_OAEP_SHA512_MGF1 = ( - EncryptionType.ASYMMETRIC, - rsa, - padding.OAEP, - hashes.SHA512, - padding.MGF1, - ) + RSA_OAEP_SHA1_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA1, padding.MGF1) + RSA_OAEP_SHA256_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA256, padding.MGF1) + RSA_OAEP_SHA384_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA384, padding.MGF1) + RSA_OAEP_SHA512_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA512, padding.MGF1) - def __init__( - self, encryption_type, algorithm, padding_type, padding_algorithm, padding_mgf - ): + def __init__(self, encryption_type, algorithm, padding_type, padding_algorithm, padding_mgf): """Prepares new WrappingAlgorithm.""" self.encryption_type = encryption_type self.algorithm = algorithm diff --git a/src/aws_encryption_sdk/internal/crypto/authentication.py b/src/aws_encryption_sdk/internal/crypto/authentication.py index 67b938d6e..dc9929bf7 100644 --- a/src/aws_encryption_sdk/internal/crypto/authentication.py +++ b/src/aws_encryption_sdk/internal/crypto/authentication.py @@ -58,9 +58,7 @@ def _build_hasher(self): :returns: Hasher object """ - return hashes.Hash( - self.algorithm.signing_hash_type(), backend=default_backend() - ) + return hashes.Hash(self.algorithm.signing_hash_type(), backend=default_backend()) class Signer(_PrehashingAuthenticator): @@ -81,9 +79,7 @@ def from_key_bytes(cls, algorithm, key_bytes): :param bytes key_bytes: Raw signing key :rtype: aws_encryption_sdk.internal.crypto.Signer """ - key = serialization.load_der_private_key( - data=key_bytes, password=None, backend=default_backend() - ) + key = serialization.load_der_private_key(data=key_bytes, password=None, backend=default_backend()) return cls(algorithm, key) def key_bytes(self): @@ -122,9 +118,7 @@ def finalize(self): :rtype: bytes """ prehashed_digest = self._hasher.finalize() - return _ecc_static_length_signature( - key=self.key, algorithm=self.algorithm, digest=prehashed_digest - ) + return _ecc_static_length_signature(key=self.key, algorithm=self.algorithm, digest=prehashed_digest) class Verifier(_PrehashingAuthenticator): @@ -152,8 +146,7 @@ def from_encoded_point(cls, algorithm, encoded_point): return cls( algorithm=algorithm, key=_ecc_public_numbers_from_compressed_point( - curve=algorithm.signing_algorithm_info(), - compressed_point=base64.b64decode(encoded_point), + curve=algorithm.signing_algorithm_info(), compressed_point=base64.b64decode(encoded_point) ).public_key(default_backend()), ) @@ -168,10 +161,7 @@ def from_key_bytes(cls, algorithm, key_bytes): :rtype: aws_encryption_sdk.internal.crypto.Verifier """ return cls( - algorithm=algorithm, - key=serialization.load_der_public_key( - data=key_bytes, backend=default_backend() - ), + algorithm=algorithm, key=serialization.load_der_public_key(data=key_bytes, backend=default_backend()) ) def key_bytes(self): @@ -180,8 +170,7 @@ def key_bytes(self): :rtype: bytes """ return self.key.public_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo, + encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo ) def update(self, data): diff --git a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py index b35ae1c22..56f49976e 100644 --- a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py +++ b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py @@ -17,17 +17,8 @@ import six from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric.utils import ( - Prehashed, - decode_dss_signature, - encode_dss_signature, -) -from cryptography.utils import ( - InterfaceNotImplemented, - int_from_bytes, - int_to_bytes, - verify_interface, -) +from cryptography.hazmat.primitives.asymmetric.utils import Prehashed, decode_dss_signature, encode_dss_signature +from cryptography.utils import InterfaceNotImplemented, int_from_bytes, int_to_bytes, verify_interface from ...exceptions import NotSupportedError from ..str_ops import to_bytes @@ -76,18 +67,14 @@ def _ecc_static_length_signature(key, algorithm, digest): signature = b"" while len(signature) != algorithm.signature_len: _LOGGER.debug( - "Signature length %d is not desired length %d. Recalculating.", - len(signature), - algorithm.signature_len, + "Signature length %d is not desired length %d. Recalculating.", len(signature), algorithm.signature_len ) signature = key.sign(digest, pre_hashed_algorithm) if len(signature) != algorithm.signature_len: # Most of the time, a signature of the wrong length can be fixed # by negating s in the signature relative to the group order. _LOGGER.debug( - "Signature length %d is not desired length %d. Negating s.", - len(signature), - algorithm.signature_len, + "Signature length %d is not desired length %d. Negating s.", len(signature), algorithm.signature_len ) r, s = decode_dss_signature(signature) s = _ECC_CURVE_PARAMETERS[algorithm.signing_algorithm_info.name].order - s @@ -149,9 +136,7 @@ def _ecc_decode_compressed_point(curve, compressed_point): try: params = _ECC_CURVE_PARAMETERS[curve.name] except KeyError: - raise NotSupportedError( - "Curve {name} is not supported at this time".format(name=curve.name) - ) + raise NotSupportedError("Curve {name} is not supported at this time".format(name=curve.name)) alpha = (pow(x, 3, params.p) + (params.a * x % params.p) + params.b) % params.p # Only works for p % 4 == 3 at this time. # This is the case for all currently supported algorithms. @@ -199,8 +184,6 @@ def generate_ecc_signing_key(algorithm): """ try: verify_interface(ec.EllipticCurve, algorithm.signing_algorithm_info) - return ec.generate_private_key( - curve=algorithm.signing_algorithm_info(), backend=default_backend() - ) + return ec.generate_private_key(curve=algorithm.signing_algorithm_info(), backend=default_backend()) except InterfaceNotImplemented: raise NotSupportedError("Unsupported signing algorithm info") diff --git a/src/aws_encryption_sdk/internal/crypto/encryption.py b/src/aws_encryption_sdk/internal/crypto/encryption.py index f549b483c..b6251e3ca 100644 --- a/src/aws_encryption_sdk/internal/crypto/encryption.py +++ b/src/aws_encryption_sdk/internal/crypto/encryption.py @@ -39,9 +39,7 @@ def __init__(self, algorithm, key, associated_data, iv): # This is intentionally generic to leave an option for non-Cipher encryptor types in the future. self.iv = iv self._encryptor = Cipher( - algorithm.encryption_algorithm(key), - algorithm.encryption_mode(self.iv), - backend=default_backend(), + algorithm.encryption_algorithm(key), algorithm.encryption_mode(self.iv), backend=default_backend() ).encryptor() # associated_data will be authenticated but not encrypted, @@ -110,9 +108,7 @@ def __init__(self, algorithm, key, associated_data, iv, tag): # Construct a decryptor object with the given key and a provided IV. # This is intentionally generic to leave an option for non-Cipher decryptor types in the future. self._decryptor = Cipher( - algorithm.encryption_algorithm(key), - algorithm.encryption_mode(iv, tag), - backend=default_backend(), + algorithm.encryption_algorithm(key), algorithm.encryption_mode(iv, tag), backend=default_backend() ).decryptor() # Put associated_data back in or the tag will fail to verify when the _decryptor is finalized. @@ -151,7 +147,5 @@ def decrypt(algorithm, key, encrypted_data, associated_data): :returns: Plaintext of body :rtype: bytes """ - decryptor = Decryptor( - algorithm, key, associated_data, encrypted_data.iv, encrypted_data.tag - ) + decryptor = Decryptor(algorithm, key, associated_data, encrypted_data.iv, encrypted_data.tag) return decryptor.update(encrypted_data.ciphertext) + decryptor.finalize() diff --git a/src/aws_encryption_sdk/internal/defaults.py b/src/aws_encryption_sdk/internal/defaults.py index 5e4c905bd..2dc4ae0c4 100644 --- a/src/aws_encryption_sdk/internal/defaults.py +++ b/src/aws_encryption_sdk/internal/defaults.py @@ -29,9 +29,7 @@ #: Default message structure Type as defined in specification TYPE = aws_encryption_sdk.identifiers.ObjectType.CUSTOMER_AE_DATA #: Default algorithm as defined in specification -ALGORITHM = ( - aws_encryption_sdk.identifiers.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 -) +ALGORITHM = aws_encryption_sdk.identifiers.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 #: Key to add encoded signing key to encryption context dictionary as defined in specification ENCODED_SIGNER_KEY = "aws-crypto-public-key" diff --git a/src/aws_encryption_sdk/internal/formatting/serialize.py b/src/aws_encryption_sdk/internal/formatting/serialize.py index 57bcdd39b..ee6e388c1 100644 --- a/src/aws_encryption_sdk/internal/formatting/serialize.py +++ b/src/aws_encryption_sdk/internal/formatting/serialize.py @@ -17,11 +17,7 @@ import aws_encryption_sdk.internal.defaults import aws_encryption_sdk.internal.formatting.encryption_context from aws_encryption_sdk.exceptions import SerializationError -from aws_encryption_sdk.identifiers import ( - ContentAADString, - EncryptionType, - SequenceIdentifier, -) +from aws_encryption_sdk.identifiers import ContentAADString, EncryptionType, SequenceIdentifier from aws_encryption_sdk.internal.crypto.encryption import encrypt from aws_encryption_sdk.internal.crypto.iv import frame_iv, header_auth_iv from aws_encryption_sdk.internal.str_ops import to_bytes @@ -114,12 +110,7 @@ def serialize_header(header, signer=None): "I" # frame length ) header_bytes.extend( - struct.pack( - header_close_format, - header.content_type.value, - header.algorithm.iv_len, - header.frame_length, - ) + struct.pack(header_close_format, header.content_type.value, header.algorithm.iv_len, header.frame_length) ) output = bytes(header_bytes) if signer is not None: @@ -147,9 +138,7 @@ def serialize_header_auth(algorithm, header, data_encryption_key, signer=None): iv=header_auth_iv(algorithm), ) output = struct.pack( - ">{iv_len}s{tag_len}s".format( - iv_len=algorithm.iv_len, tag_len=algorithm.tag_len - ), + ">{iv_len}s{tag_len}s".format(iv_len=algorithm.iv_len, tag_len=algorithm.tag_len), header_auth.iv, header_auth.tag, ) @@ -170,9 +159,7 @@ def serialize_non_framed_open(algorithm, iv, plaintext_length, signer=None): :returns: Serialized body start block :rtype: bytes """ - body_start_format = (">" "{iv_length}s" "Q").format( - iv_length=algorithm.iv_len - ) # nonce (IV) # content length + body_start_format = (">" "{iv_length}s" "Q").format(iv_length=algorithm.iv_len) # nonce (IV) # content length body_start = struct.pack(body_start_format, iv, plaintext_length) if signer: signer.update(body_start) @@ -195,14 +182,7 @@ def serialize_non_framed_close(tag, signer=None): def serialize_frame( - algorithm, - plaintext, - message_id, - data_encryption_key, - frame_length, - sequence_number, - is_final_frame, - signer=None, + algorithm, plaintext, message_id, data_encryption_key, frame_length, sequence_number, is_final_frame, signer=None ): """Receives a message plaintext, breaks off a frame, encrypts and serializes the frame, and returns the encrypted frame and the remaining plaintext. @@ -247,9 +227,7 @@ def serialize_frame( _LOGGER.debug("Serializing final frame") packed_frame = struct.pack( ">II{iv_len}sI{content_len}s{auth_len}s".format( - iv_len=algorithm.iv_len, - content_len=len(frame_ciphertext.ciphertext), - auth_len=algorithm.auth_len, + iv_len=algorithm.iv_len, content_len=len(frame_ciphertext.ciphertext), auth_len=algorithm.auth_len ), SequenceIdentifier.SEQUENCE_NUMBER_END.value, sequence_number, @@ -262,9 +240,7 @@ def serialize_frame( _LOGGER.debug("Serializing frame") packed_frame = struct.pack( ">I{iv_len}s{content_len}s{auth_len}s".format( - iv_len=algorithm.iv_len, - content_len=frame_length, - auth_len=algorithm.auth_len, + iv_len=algorithm.iv_len, content_len=frame_length, auth_len=algorithm.auth_len ), sequence_number, frame_ciphertext.iv, @@ -288,9 +264,7 @@ def serialize_footer(signer): footer = b"" if signer is not None: signature = signer.finalize() - footer = struct.pack( - ">H{sig_len}s".format(sig_len=len(signature)), len(signature), signature - ) + footer = struct.pack(">H{sig_len}s".format(sig_len=len(signature)), len(signature), signature) return footer @@ -303,10 +277,7 @@ def serialize_raw_master_key_prefix(raw_master_key): :returns: Serialized key_info prefix :rtype: bytes """ - if ( - raw_master_key.config.wrapping_key.wrapping_algorithm.encryption_type - is EncryptionType.ASYMMETRIC - ): + if raw_master_key.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.ASYMMETRIC: return to_bytes(raw_master_key.key_id) return struct.pack( ">{}sII".format(len(raw_master_key.key_id)), @@ -317,9 +288,7 @@ def serialize_raw_master_key_prefix(raw_master_key): ) -def serialize_wrapped_key( - key_provider, wrapping_algorithm, wrapping_key_id, encrypted_wrapped_key -): +def serialize_wrapped_key(key_provider, wrapping_algorithm, wrapping_key_id, encrypted_wrapped_key): """Serializes EncryptedData into a Wrapped EncryptedDataKey. :param key_provider: Info for Wrapping MasterKey @@ -338,19 +307,15 @@ def serialize_wrapped_key( else: key_info = struct.pack( ">{key_id_len}sII{iv_len}s".format( - key_id_len=len(wrapping_key_id), - iv_len=wrapping_algorithm.algorithm.iv_len, + key_id_len=len(wrapping_key_id), iv_len=wrapping_algorithm.algorithm.iv_len ), to_bytes(wrapping_key_id), - len(encrypted_wrapped_key.tag) - * 8, # Tag Length is stored in bits, not bytes + len(encrypted_wrapped_key.tag) * 8, # Tag Length is stored in bits, not bytes wrapping_algorithm.algorithm.iv_len, encrypted_wrapped_key.iv, ) key_ciphertext = encrypted_wrapped_key.ciphertext + encrypted_wrapped_key.tag return EncryptedDataKey( - key_provider=MasterKeyInfo( - provider_id=key_provider.provider_id, key_info=key_info - ), + key_provider=MasterKeyInfo(provider_id=key_provider.provider_id, key_info=key_info), encrypted_data_key=key_ciphertext, ) diff --git a/src/aws_encryption_sdk/internal/utils/__init__.py b/src/aws_encryption_sdk/internal/utils/__init__.py index d3087e24e..2288b1ebc 100644 --- a/src/aws_encryption_sdk/internal/utils/__init__.py +++ b/src/aws_encryption_sdk/internal/utils/__init__.py @@ -18,11 +18,7 @@ import six import aws_encryption_sdk.internal.defaults -from aws_encryption_sdk.exceptions import ( - InvalidDataKeyError, - SerializationError, - UnknownIdentityError, -) +from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError from aws_encryption_sdk.identifiers import ContentAADString, ContentType from aws_encryption_sdk.internal.str_ops import to_bytes from aws_encryption_sdk.structures import EncryptedDataKey @@ -54,10 +50,7 @@ def validate_frame_length(frame_length, algorithm): :raises SerializationError: if frame size is negative or not a multiple of the algorithm block size :raises SerializationError: if frame size is larger than the maximum allowed frame size """ - if ( - frame_length < 0 - or frame_length % algorithm.encryption_algorithm.block_size != 0 - ): + if frame_length < 0 or frame_length % algorithm.encryption_algorithm.block_size != 0: raise SerializationError( "Frame size must be a non-negative multiple of the block size of the crypto algorithm: {block_size}".format( block_size=algorithm.encryption_algorithm.block_size @@ -66,8 +59,7 @@ def validate_frame_length(frame_length, algorithm): if frame_length > aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE: raise SerializationError( "Frame size too large: {frame} > {max}".format( - frame=frame_length, - max=aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE, + frame=frame_length, max=aws_encryption_sdk.internal.defaults.MAX_FRAME_SIZE ) ) @@ -119,31 +111,21 @@ def prepare_data_keys(primary_master_key, master_keys, algorithm, encryption_con """ encrypted_data_keys = set() encrypted_data_encryption_key = None - data_encryption_key = primary_master_key.generate_data_key( - algorithm, encryption_context - ) - _LOGGER.debug( - "encryption data generated with master key: %s", - data_encryption_key.key_provider, - ) + data_encryption_key = primary_master_key.generate_data_key(algorithm, encryption_context) + _LOGGER.debug("encryption data generated with master key: %s", data_encryption_key.key_provider) for master_key in master_keys: # Don't re-encrypt the encryption data key; we already have the ciphertext if master_key is primary_master_key: encrypted_data_encryption_key = EncryptedDataKey( - key_provider=data_encryption_key.key_provider, - encrypted_data_key=data_encryption_key.encrypted_data_key, + key_provider=data_encryption_key.key_provider, encrypted_data_key=data_encryption_key.encrypted_data_key ) encrypted_data_keys.add(encrypted_data_encryption_key) continue encrypted_key = master_key.encrypt_data_key( - data_key=data_encryption_key, - algorithm=algorithm, - encryption_context=encryption_context, + data_key=data_encryption_key, algorithm=algorithm, encryption_context=encryption_context ) encrypted_data_keys.add(encrypted_key) - _LOGGER.debug( - "encryption key encrypted with master key: %s", master_key.key_provider - ) + _LOGGER.debug("encryption key encrypted with master key: %s", master_key.key_provider) return data_encryption_key, encrypted_data_keys diff --git a/src/aws_encryption_sdk/key_providers/base.py b/src/aws_encryption_sdk/key_providers/base.py index b272d8836..9554eb44d 100644 --- a/src/aws_encryption_sdk/key_providers/base.py +++ b/src/aws_encryption_sdk/key_providers/base.py @@ -71,12 +71,8 @@ def __new__(cls, **kwargs): """ instance = super(MasterKeyProvider, cls).__new__(cls) config = kwargs.pop("config", None) - if not isinstance( - config, instance._config_class - ): # pylint: disable=protected-access - config = instance._config_class( - **kwargs - ) # pylint: disable=protected-access + if not isinstance(config, instance._config_class): # pylint: disable=protected-access + config = instance._config_class(**kwargs) # pylint: disable=protected-access instance.config = config #: Index matching key IDs to existing MasterKey objects. instance._encrypt_key_index = {} # pylint: disable=protected-access @@ -92,15 +88,11 @@ def __repr__(self): name=self.__class__.__name__, kwargs=", ".join( "{key}={value}".format(key=key, value=value) - for key, value in sorted( - attr.asdict(self.config, recurse=True).items(), key=lambda x: x[0] - ) + for key, value in sorted(attr.asdict(self.config, recurse=True).items(), key=lambda x: x[0]) ), ) - def master_keys_for_encryption( - self, encryption_context, plaintext_rostream, plaintext_length=None - ): + def master_keys_for_encryption(self, encryption_context, plaintext_rostream, plaintext_length=None): """Returns a set containing all Master Keys added to this Provider, or any member Providers, which should be used to encrypt data keys for the specified data. @@ -133,9 +125,7 @@ def master_keys_for_encryption( primary = _primary master_keys.extend(_master_keys) if not master_keys: - raise MasterKeyProviderError( - "No Master Keys available from Master Key Provider" - ) + raise MasterKeyProviderError("No Master Keys available from Master Key Provider") return primary, master_keys @abc.abstractmethod @@ -240,38 +230,26 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): _LOGGER.debug("starting decrypt data key attempt") for member in [self] + self._members: if member.provider_id == encrypted_data_key.key_provider.provider_id: - _LOGGER.debug( - "attempting to locate master key from key provider: %s", - member.provider_id, - ) + _LOGGER.debug("attempting to locate master key from key provider: %s", member.provider_id) if isinstance(member, MasterKey): _LOGGER.debug("using existing master key") master_key = member elif self.vend_masterkey_on_decrypt: try: - _LOGGER.debug( - "attempting to add master key: %s", - encrypted_data_key.key_provider.key_info, - ) - master_key = member.master_key_for_decrypt( - encrypted_data_key.key_provider.key_info - ) + _LOGGER.debug("attempting to add master key: %s", encrypted_data_key.key_provider.key_info) + master_key = member.master_key_for_decrypt(encrypted_data_key.key_provider.key_info) except InvalidKeyIdError: _LOGGER.debug( - "master key %s not available in provider", - encrypted_data_key.key_provider.key_info, + "master key %s not available in provider", encrypted_data_key.key_provider.key_info ) continue else: continue try: _LOGGER.debug( - "attempting to decrypt data key with provider %s", - encrypted_data_key.key_provider.key_info, - ) - data_key = master_key.decrypt_data_key( - encrypted_data_key, algorithm, encryption_context + "attempting to decrypt data key with provider %s", encrypted_data_key.key_provider.key_info ) + data_key = master_key.decrypt_data_key(encrypted_data_key, algorithm, encryption_context) except (IncorrectMasterKeyError, DecryptKeyError) as error: _LOGGER.debug( "%s raised when attempting to decrypt data key with master key %s", @@ -284,9 +262,7 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): raise DecryptKeyError("Unable to decrypt data key") return data_key - def decrypt_data_key_from_list( - self, encrypted_data_keys, algorithm, encryption_context - ): + def decrypt_data_key_from_list(self, encrypted_data_keys, algorithm, encryption_context): """Receives a list of encrypted data keys and returns the first one which this provider is able to decrypt. :param encrypted_data_keys: List of encrypted data keys @@ -301,9 +277,7 @@ def decrypt_data_key_from_list( data_key = None for encrypted_data_key in encrypted_data_keys: try: - data_key = self.decrypt_data_key( - encrypted_data_key, algorithm, encryption_context - ) + data_key = self.decrypt_data_key(encrypted_data_key, algorithm, encryption_context) # MasterKeyProvider.decrypt_data_key throws DecryptKeyError # but MasterKey.decrypt_data_key throws IncorrectMasterKeyError except (DecryptKeyError, IncorrectMasterKeyError): @@ -322,18 +296,12 @@ class MasterKeyConfig(object): :param bytes key_id: Key ID for Master Key """ - key_id = attr.ib( - hash=True, - validator=attr.validators.instance_of((six.string_types, bytes)), - converter=to_bytes, - ) + key_id = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_bytes) def __attrs_post_init__(self): """Verify that children of this class define a "provider_id" attribute.""" if not hasattr(self, "provider_id"): - raise TypeError( - 'Instances of MasterKeyConfig must have a "provider_id" attribute defined.' - ) + raise TypeError('Instances of MasterKeyConfig must have a "provider_id" attribute defined.') @six.add_metaclass(abc.ABCMeta) @@ -350,9 +318,7 @@ def __new__(cls, **kwargs): instance = super(MasterKey, cls).__new__(cls, **kwargs) if not hasattr(instance.config, "provider_id"): - raise TypeError( - 'MasterKey config classes must have a "provider_id" attribute defined.' - ) + raise TypeError('MasterKey config classes must have a "provider_id" attribute defined.') if instance.config.provider_id is not None: # Only allow override if provider_id is NOT set to non-None for the class @@ -361,14 +327,11 @@ def __new__(cls, **kwargs): elif instance.provider_id != instance.config.provider_id: raise ConfigMismatchError( "Config provider_id does not match MasterKey provider_id: {config} != {instance}".format( - config=instance.config.provider_id, - instance=instance.provider_id, + config=instance.config.provider_id, instance=instance.provider_id ) ) instance.key_id = instance.config.key_id - instance._encrypt_key_index = { - instance.key_id: instance - } # pylint: disable=protected-access + instance._encrypt_key_index = {instance.key_id: instance} # pylint: disable=protected-access # We cannot make any general statements about key_info, so specifically enforce that decrypt index is empty. instance._decrypt_key_index = {} # pylint: disable=protected-access instance._members = [instance] # pylint: disable=protected-access @@ -397,9 +360,7 @@ def owns_data_key(self, data_key): return True return False - def master_keys_for_encryption( - self, encryption_context, plaintext_rostream, plaintext_length=None - ): + def master_keys_for_encryption(self, encryption_context, plaintext_rostream, plaintext_length=None): """Returns self and a list containing self, to match the format of output for a Master Key Provider. .. warning:: @@ -455,12 +416,8 @@ def generate_data_key(self, algorithm, encryption_context): :returns: Generated data key :rtype: aws_encryption_sdk.structures.DataKey """ - _LOGGER.info( - "generating data key with encryption context: %s", encryption_context - ) - generated_data_key = self._generate_data_key( - algorithm=algorithm, encryption_context=encryption_context - ) + _LOGGER.info("generating data key with encryption context: %s", encryption_context) + generated_data_key = self._generate_data_key(algorithm=algorithm, encryption_context=encryption_context) aws_encryption_sdk.internal.utils.source_data_key_length_check( source_data_key=generated_data_key, algorithm=algorithm ) @@ -493,14 +450,8 @@ def encrypt_data_key(self, data_key, algorithm, encryption_context): :rtype: aws_encryption_sdk.structures.EncryptedDataKey :raises IncorrectMasterKeyError: if Data Key's key provider does not match this Master Key """ - _LOGGER.info( - "encrypting data key with encryption context: %s", encryption_context - ) - return self._encrypt_data_key( - data_key=data_key, - algorithm=algorithm, - encryption_context=encryption_context, - ) + _LOGGER.info("encrypting data key with encryption context: %s", encryption_context) + return self._encrypt_data_key(data_key=data_key, algorithm=algorithm, encryption_context=encryption_context) @abc.abstractmethod def _encrypt_data_key(self, data_key, algorithm, encryption_context): @@ -532,14 +483,10 @@ def decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): :rtype: aws_encryption_sdk.structures.DataKey :raises IncorrectMasterKeyError: if Data Key's key provider does not match this Master Key """ - _LOGGER.info( - "decrypting data key with encryption context: %s", encryption_context - ) + _LOGGER.info("decrypting data key with encryption context: %s", encryption_context) self._key_check(encrypted_data_key) decrypted_data_key = self._decrypt_data_key( - encrypted_data_key=encrypted_data_key, - algorithm=algorithm, - encryption_context=encryption_context, + encrypted_data_key=encrypted_data_key, algorithm=algorithm, encryption_context=encryption_context ) aws_encryption_sdk.internal.utils.source_data_key_length_check( source_data_key=decrypted_data_key, algorithm=algorithm diff --git a/src/aws_encryption_sdk/key_providers/kms.py b/src/aws_encryption_sdk/key_providers/kms.py index 3855a329f..8f29cbd0f 100644 --- a/src/aws_encryption_sdk/key_providers/kms.py +++ b/src/aws_encryption_sdk/key_providers/kms.py @@ -21,20 +21,10 @@ import botocore.session from botocore.exceptions import ClientError -from aws_encryption_sdk.exceptions import ( - DecryptKeyError, - EncryptKeyError, - GenerateKeyError, - UnknownRegionError, -) +from aws_encryption_sdk.exceptions import DecryptKeyError, EncryptKeyError, GenerateKeyError, UnknownRegionError from aws_encryption_sdk.identifiers import USER_AGENT_SUFFIX from aws_encryption_sdk.internal.str_ops import to_str -from aws_encryption_sdk.key_providers.base import ( - MasterKey, - MasterKeyConfig, - MasterKeyProvider, - MasterKeyProviderConfig, -) +from aws_encryption_sdk.key_providers.base import MasterKey, MasterKeyConfig, MasterKeyProvider, MasterKeyProviderConfig from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo _LOGGER = logging.getLogger(__name__) @@ -56,9 +46,7 @@ def _region_from_key_id(key_id, default_region=None): except IndexError: if default_region is None: raise UnknownRegionError( - "No default region found and no region determinable from key id: {}".format( - key_id - ) + "No default region found and no region determinable from key id: {}".format(key_id) ) region_name = default_region return region_name @@ -80,16 +68,10 @@ class KMSMasterKeyProviderConfig(MasterKeyProviderConfig): validator=attr.validators.instance_of(botocore.session.Session), ) key_ids = attr.ib( - hash=True, - default=attr.Factory(tuple), - validator=attr.validators.instance_of(tuple), - converter=tuple, + hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple ) region_names = attr.ib( - hash=True, - default=attr.Factory(tuple), - validator=attr.validators.instance_of(tuple), - converter=tuple, + hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple ) @@ -132,16 +114,12 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument def _process_config(self): """Traverses the config and adds master keys and regional clients as needed.""" - self._user_agent_adding_config = botocore.config.Config( - user_agent_extra=USER_AGENT_SUFFIX - ) + self._user_agent_adding_config = botocore.config.Config(user_agent_extra=USER_AGENT_SUFFIX) if self.config.region_names: self.add_regional_clients_from_list(self.config.region_names) self.default_region = self.config.region_names[0] else: - self.default_region = self.config.botocore_session.get_config_variable( - "region" - ) + self.default_region = self.config.botocore_session.get_config_variable("region") if self.default_region is not None: self.add_regional_client(self.default_region) if self.config.key_ids: @@ -160,9 +138,7 @@ def _wrap_client(self, region_name, method, *args, **kwargs): except botocore.exceptions.BotoCoreError: self._regional_clients.pop(region_name) _LOGGER.error( - 'Removing regional client "%s" from cache due to BotoCoreError on %s call', - region_name, - method.__name__, + 'Removing regional client "%s" from cache due to BotoCoreError on %s call', region_name, method.__name__ ) raise @@ -183,9 +159,7 @@ def add_regional_client(self, region_name): :param str region_name: AWS Region ID (ex: us-east-1) """ if region_name not in self._regional_clients: - session = boto3.session.Session( - region_name=region_name, botocore_session=self.config.botocore_session - ) + session = boto3.session.Session(region_name=region_name, botocore_session=self.config.botocore_session) client = session.client("kms", config=self._user_agent_adding_config) self._register_client(client, region_name) self._regional_clients[region_name] = client @@ -216,9 +190,7 @@ def _new_master_key(self, key_id): :raises InvalidKeyIdError: if key_id is not a valid KMS CMK ID to which this key provider has access """ _key_id = to_str(key_id) # KMS client requires str, not bytes - return KMSMasterKey( - config=KMSMasterKeyConfig(key_id=key_id, client=self._client(_key_id)) - ) + return KMSMasterKey(config=KMSMasterKeyConfig(key_id=key_id, client=self._client(_key_id))) @attr.s(hash=True) @@ -232,14 +204,9 @@ class KMSMasterKeyConfig(MasterKeyConfig): """ provider_id = _PROVIDER_ID - client = attr.ib( - hash=True, validator=attr.validators.instance_of(botocore.client.BaseClient) - ) + client = attr.ib(hash=True, validator=attr.validators.instance_of(botocore.client.BaseClient)) grant_tokens = attr.ib( - hash=True, - default=attr.Factory(tuple), - validator=attr.validators.instance_of(tuple), - converter=tuple, + hash=True, default=attr.Factory(tuple), validator=attr.validators.instance_of(tuple), converter=tuple ) @client.default @@ -293,9 +260,7 @@ def _generate_data_key(self, algorithm, encryption_context=None): ciphertext = response["CiphertextBlob"] key_id = response["KeyId"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to generate data key".format( - key_id=self._key_id - ) + error_message = "Master Key {key_id} unable to generate data key".format(key_id=self._key_id) _LOGGER.exception(error_message) raise GenerateKeyError(error_message) return DataKey( @@ -327,14 +292,11 @@ def _encrypt_data_key(self, data_key, algorithm, encryption_context=None): ciphertext = response["CiphertextBlob"] key_id = response["KeyId"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to encrypt data key".format( - key_id=self._key_id - ) + error_message = "Master Key {key_id} unable to encrypt data key".format(key_id=self._key_id) _LOGGER.exception(error_message) raise EncryptKeyError(error_message) return EncryptedDataKey( - key_provider=MasterKeyInfo(provider_id=self.provider_id, key_info=key_id), - encrypted_data_key=ciphertext, + key_provider=MasterKeyInfo(provider_id=self.provider_id, key_info=key_id), encrypted_data_key=ciphertext ) def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context=None): @@ -358,13 +320,9 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context=No response = self.config.client.decrypt(**kms_params) plaintext = response["Plaintext"] except (ClientError, KeyError): - error_message = "Master Key {key_id} unable to decrypt data key".format( - key_id=self._key_id - ) + error_message = "Master Key {key_id} unable to decrypt data key".format(key_id=self._key_id) _LOGGER.exception(error_message) raise DecryptKeyError(error_message) return DataKey( - key_provider=self.key_provider, - data_key=plaintext, - encrypted_data_key=encrypted_data_key.encrypted_data_key, + key_provider=self.key_provider, data_key=plaintext, encrypted_data_key=encrypted_data_key.encrypted_data_key ) diff --git a/src/aws_encryption_sdk/key_providers/raw.py b/src/aws_encryption_sdk/key_providers/raw.py index e8769457a..ca6c690d6 100644 --- a/src/aws_encryption_sdk/key_providers/raw.py +++ b/src/aws_encryption_sdk/key_providers/raw.py @@ -22,12 +22,7 @@ import aws_encryption_sdk.internal.formatting.serialize from aws_encryption_sdk.identifiers import EncryptionType from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey -from aws_encryption_sdk.key_providers.base import ( - MasterKey, - MasterKeyConfig, - MasterKeyProvider, - MasterKeyProviderConfig, -) +from aws_encryption_sdk.key_providers.base import MasterKey, MasterKeyConfig, MasterKeyProvider, MasterKeyProviderConfig from aws_encryption_sdk.structures import DataKey, RawDataKey _LOGGER = logging.getLogger(__name__) @@ -48,9 +43,7 @@ class RawMasterKeyConfig(MasterKeyConfig): validator=attr.validators.instance_of((six.string_types, bytes)), converter=aws_encryption_sdk.internal.str_ops.to_str, ) - wrapping_key = attr.ib( - hash=True, validator=attr.validators.instance_of(WrappingKey) - ) + wrapping_key = attr.ib(hash=True, validator=attr.validators.instance_of(WrappingKey)) class RawMasterKey(MasterKey): @@ -91,18 +84,13 @@ def owns_data_key(self, data_key): """ expected_key_info_len = -1 if ( - self.config.wrapping_key.wrapping_algorithm.encryption_type - is EncryptionType.ASYMMETRIC + self.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.ASYMMETRIC and data_key.key_provider == self.key_provider ): return True - elif ( - self.config.wrapping_key.wrapping_algorithm.encryption_type - is EncryptionType.SYMMETRIC - ): + elif self.config.wrapping_key.wrapping_algorithm.encryption_type is EncryptionType.SYMMETRIC: expected_key_info_len = ( - len(self._key_info_prefix) - + self.config.wrapping_key.wrapping_algorithm.algorithm.iv_len + len(self._key_info_prefix) + self.config.wrapping_key.wrapping_algorithm.algorithm.iv_len ) if ( data_key.key_provider.provider_id == self.provider_id @@ -135,9 +123,7 @@ def _generate_data_key(self, algorithm, encryption_context): """ plaintext_data_key = os.urandom(algorithm.kdf_input_len) encrypted_data_key = self._encrypt_data_key( - data_key=RawDataKey( - key_provider=self.key_provider, data_key=plaintext_data_key - ), + data_key=RawDataKey(key_provider=self.key_provider, data_key=plaintext_data_key), algorithm=algorithm, encryption_context=encryption_context, ) @@ -192,8 +178,7 @@ def _decrypt_data_key(self, encrypted_data_key, algorithm, encryption_context): ) # EncryptedData to raw key string plaintext_data_key = self.config.wrapping_key.decrypt( - encrypted_wrapped_data_key=encrypted_wrapped_key, - encryption_context=encryption_context, + encrypted_wrapped_data_key=encrypted_wrapped_key, encryption_context=encryption_context ) # Raw key string to DataKey return DataKey( @@ -237,7 +222,5 @@ def _new_master_key(self, key_id): _LOGGER.debug("Retrieving wrapping key with id: %s", key_id) wrapping_key = self._get_raw_key(key_id) return self._master_key_class( - config=RawMasterKeyConfig( - key_id=key_id, provider_id=self.provider_id, wrapping_key=wrapping_key - ) + config=RawMasterKeyConfig(key_id=key_id, provider_id=self.provider_id, wrapping_key=wrapping_key) ) diff --git a/src/aws_encryption_sdk/materials_managers/__init__.py b/src/aws_encryption_sdk/materials_managers/__init__.py index daad892ff..a086feb7c 100644 --- a/src/aws_encryption_sdk/materials_managers/__init__.py +++ b/src/aws_encryption_sdk/materials_managers/__init__.py @@ -43,18 +43,11 @@ class EncryptionMaterialsRequest(object): encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) frame_length = attr.ib(validator=attr.validators.instance_of(six.integer_types)) plaintext_rostream = attr.ib( - default=None, - validator=attr.validators.optional(attr.validators.instance_of(ROStream)), - ) - algorithm = attr.ib( - default=None, - validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite)), + default=None, validator=attr.validators.optional(attr.validators.instance_of(ROStream)) ) + algorithm = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite))) plaintext_length = attr.ib( - default=None, - validator=attr.validators.optional( - attr.validators.instance_of(six.integer_types) - ), + default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) ) @@ -78,10 +71,7 @@ class EncryptionMaterials(object): data_encryption_key = attr.ib(validator=attr.validators.instance_of(DataKey)) encrypted_data_keys = attr.ib(validator=attr.validators.instance_of(set)) encryption_context = attr.ib(validator=attr.validators.instance_of(dict)) - signing_key = attr.ib( - default=None, - validator=attr.validators.optional(attr.validators.instance_of(bytes)), - ) + signing_key = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(bytes))) @attr.s(hash=False) @@ -114,7 +104,4 @@ class DecryptionMaterials(object): """ data_key = attr.ib(validator=attr.validators.instance_of(DataKey)) - verification_key = attr.ib( - default=None, - validator=attr.validators.optional(attr.validators.instance_of(bytes)), - ) + verification_key = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(bytes))) diff --git a/src/aws_encryption_sdk/materials_managers/default.py b/src/aws_encryption_sdk/materials_managers/default.py index 42d5ee7a6..402a853ce 100644 --- a/src/aws_encryption_sdk/materials_managers/default.py +++ b/src/aws_encryption_sdk/materials_managers/default.py @@ -39,13 +39,9 @@ class DefaultCryptoMaterialsManager(CryptoMaterialsManager): """ algorithm = ALGORITHM - master_key_provider = attr.ib( - validator=attr.validators.instance_of(MasterKeyProvider) - ) + master_key_provider = attr.ib(validator=attr.validators.instance_of(MasterKeyProvider)) - def _generate_signing_key_and_update_encryption_context( - self, algorithm, encryption_context - ): + def _generate_signing_key_and_update_encryption_context(self, algorithm, encryption_context): """Generates a signing key based on the provided algorithm. :param algorithm: Algorithm suite for which to generate signing key @@ -58,9 +54,7 @@ def _generate_signing_key_and_update_encryption_context( if algorithm.signing_algorithm_info is None: return None - signer = Signer( - algorithm=algorithm, key=generate_ecc_signing_key(algorithm=algorithm) - ) + signer = Signer(algorithm=algorithm, key=generate_ecc_signing_key(algorithm=algorithm)) encryption_context[ENCODED_SIGNER_KEY] = to_str(signer.encoded_public_key()) return signer.key_bytes() @@ -75,14 +69,10 @@ def get_encryption_materials(self, request): :raises MasterKeyProviderError: if the primary master key provided by the underlying master key provider is not included in the full set of master keys provided by that provider """ - algorithm = ( - request.algorithm if request.algorithm is not None else self.algorithm - ) + algorithm = request.algorithm if request.algorithm is not None else self.algorithm encryption_context = request.encryption_context.copy() - signing_key = self._generate_signing_key_and_update_encryption_context( - algorithm, encryption_context - ) + signing_key = self._generate_signing_key_and_update_encryption_context(algorithm, encryption_context) primary_master_key, master_keys = self.master_key_provider.master_keys_for_encryption( encryption_context=encryption_context, @@ -90,13 +80,9 @@ def get_encryption_materials(self, request): plaintext_length=request.plaintext_length, ) if not master_keys: - raise MasterKeyProviderError( - "No Master Keys available from Master Key Provider" - ) + raise MasterKeyProviderError("No Master Keys available from Master Key Provider") if primary_master_key not in master_keys: - raise MasterKeyProviderError( - "Primary Master Key not in provided Master Keys" - ) + raise MasterKeyProviderError("Primary Master Key not in provided Master Keys") data_encryption_key, encrypted_data_keys = prepare_data_keys( primary_master_key=primary_master_key, @@ -115,9 +101,7 @@ def get_encryption_materials(self, request): signing_key=signing_key, ) - def _load_verification_key_from_encryption_context( - self, algorithm, encryption_context - ): + def _load_verification_key_from_encryption_context(self, algorithm, encryption_context): """Loads the verification key from the encryption context if used by algorithm suite. :param algorithm: Algorithm suite for which to generate signing key @@ -129,24 +113,15 @@ def _load_verification_key_from_encryption_context( """ encoded_verification_key = encryption_context.get(ENCODED_SIGNER_KEY, None) - if ( - algorithm.signing_algorithm_info is not None - and encoded_verification_key is None - ): - raise SerializationError( - "No signature verification key found in header for signed algorithm." - ) + if algorithm.signing_algorithm_info is not None and encoded_verification_key is None: + raise SerializationError("No signature verification key found in header for signed algorithm.") if algorithm.signing_algorithm_info is None: if encoded_verification_key is not None: - raise SerializationError( - "Signature verification key found in header for non-signed algorithm." - ) + raise SerializationError("Signature verification key found in header for non-signed algorithm.") return None - verifier = Verifier.from_encoded_point( - algorithm=algorithm, encoded_point=encoded_verification_key - ) + verifier = Verifier.from_encoded_point(algorithm=algorithm, encoded_point=encoded_verification_key) return verifier.key_bytes() def decrypt_materials(self, request): diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 0ed383768..b04750852 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -34,13 +34,7 @@ from aws_encryption_sdk.internal.crypto.data_keys import derive_data_encryption_key from aws_encryption_sdk.internal.crypto.encryption import Decryptor, Encryptor, decrypt from aws_encryption_sdk.internal.crypto.iv import non_framed_body_iv -from aws_encryption_sdk.internal.defaults import ( - FRAME_LENGTH, - LINE_LENGTH, - MAX_NON_FRAMED_SIZE, - TYPE, - VERSION, -) +from aws_encryption_sdk.internal.defaults import FRAME_LENGTH, LINE_LENGTH, MAX_NON_FRAMED_SIZE, TYPE, VERSION from aws_encryption_sdk.internal.formatting.deserialize import ( deserialize_footer, deserialize_frame, @@ -50,9 +44,7 @@ deserialize_tag, validate_header, ) -from aws_encryption_sdk.internal.formatting.encryption_context import ( - assemble_content_aad, -) +from aws_encryption_sdk.internal.formatting.encryption_context import assemble_content_aad from aws_encryption_sdk.internal.formatting.serialize import ( serialize_footer, serialize_frame, @@ -62,10 +54,7 @@ serialize_non_framed_open, ) from aws_encryption_sdk.key_providers.base import MasterKeyProvider -from aws_encryption_sdk.materials_managers import ( - DecryptionMaterialsRequest, - EncryptionMaterialsRequest, -) +from aws_encryption_sdk.materials_managers import DecryptionMaterialsRequest, EncryptionMaterialsRequest from aws_encryption_sdk.materials_managers.base import CryptoMaterialsManager from aws_encryption_sdk.materials_managers.default import DefaultCryptoMaterialsManager from aws_encryption_sdk.structures import MessageHeader @@ -93,53 +82,29 @@ class _ClientConfig(object): will attempt to seek() to the end of the stream and tell() to find the length of source data. """ - source = attr.ib( - hash=True, converter=aws_encryption_sdk.internal.utils.prep_stream_data - ) + source = attr.ib(hash=True, converter=aws_encryption_sdk.internal.utils.prep_stream_data) materials_manager = attr.ib( - hash=True, - default=None, - validator=attr.validators.optional( - attr.validators.instance_of(CryptoMaterialsManager) - ), + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(CryptoMaterialsManager)) ) key_provider = attr.ib( - hash=True, - default=None, - validator=attr.validators.optional( - attr.validators.instance_of(MasterKeyProvider) - ), + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(MasterKeyProvider)) ) source_length = attr.ib( - hash=True, - default=None, - validator=attr.validators.optional( - attr.validators.instance_of(six.integer_types) - ), + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) ) line_length = attr.ib( - hash=True, - default=LINE_LENGTH, - validator=attr.validators.instance_of(six.integer_types), + hash=True, default=LINE_LENGTH, validator=attr.validators.instance_of(six.integer_types) ) # DEPRECATED: Value is no longer configurable here. Parameter left here to avoid breaking consumers. def __attrs_post_init__(self): """Normalize inputs to crypto material manager.""" - both_cmm_and_mkp_defined = ( - self.materials_manager is not None and self.key_provider is not None - ) - neither_cmm_nor_mkp_defined = ( - self.materials_manager is None and self.key_provider is None - ) + both_cmm_and_mkp_defined = self.materials_manager is not None and self.key_provider is not None + neither_cmm_nor_mkp_defined = self.materials_manager is None and self.key_provider is None if both_cmm_and_mkp_defined or neither_cmm_nor_mkp_defined: - raise TypeError( - "Exactly one of materials_manager or key_provider must be provided" - ) + raise TypeError("Exactly one of materials_manager or key_provider must be provided") if self.materials_manager is None: - self.materials_manager = DefaultCryptoMaterialsManager( - master_key_provider=self.key_provider - ) + self.materials_manager = DefaultCryptoMaterialsManager(master_key_provider=self.key_provider) class _EncryptionStream(io.IOBase): @@ -192,21 +157,15 @@ def __new__(cls, **kwargs): instance = super(_EncryptionStream, cls).__new__(cls) config = kwargs.pop("config", None) - if not isinstance( - config, instance._config_class - ): # pylint: disable=protected-access - config = instance._config_class( - **kwargs - ) # pylint: disable=protected-access + if not isinstance(config, instance._config_class): # pylint: disable=protected-access + config = instance._config_class(**kwargs) # pylint: disable=protected-access instance.config = config instance.bytes_read = 0 instance.output_buffer = b"" instance._message_prepped = False # pylint: disable=protected-access instance.source_stream = instance.config.source - instance._stream_length = ( - instance.config.source_length - ) # pylint: disable=protected-access + instance._stream_length = instance.config.source_length # pylint: disable=protected-access return instance @@ -385,20 +344,12 @@ class EncryptorConfig(_ClientConfig): validator=attr.validators.instance_of(dict), ) algorithm = attr.ib( - hash=True, - default=None, - validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite)), - ) - frame_length = attr.ib( - hash=True, - default=FRAME_LENGTH, - validator=attr.validators.instance_of(six.integer_types), + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(AlgorithmSuite)) ) + frame_length = attr.ib(hash=True, default=FRAME_LENGTH, validator=attr.validators.instance_of(six.integer_types)) -class StreamEncryptor( - _EncryptionStream -): # pylint: disable=too-many-instance-attributes +class StreamEncryptor(_EncryptionStream): # pylint: disable=too-many-instance-attributes """Provides a streaming encryptor for encrypting a stream source. Behaves as a standard file-like object. @@ -440,20 +391,15 @@ class StreamEncryptor( _config_class = EncryptorConfig - def __init__( - self, **kwargs - ): # pylint: disable=unused-argument,super-init-not-called + def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.sequence_number = 1 - self.content_type = aws_encryption_sdk.internal.utils.content_type( - self.config.frame_length - ) + self.content_type = aws_encryption_sdk.internal.utils.content_type(self.config.frame_length) self._bytes_encrypted = 0 if self.config.frame_length == 0 and ( - self.config.source_length is not None - and self.config.source_length > MAX_NON_FRAMED_SIZE + self.config.source_length is not None and self.config.source_length > MAX_NON_FRAMED_SIZE ): raise SerializationError("Source too large for non-framed message") @@ -485,41 +431,31 @@ def _prep_message(self): algorithm=self.config.algorithm, encryption_context=self.config.encryption_context.copy(), frame_length=self.config.frame_length, - plaintext_rostream=aws_encryption_sdk.internal.utils.streams.ROStream( - self.source_stream - ), + plaintext_rostream=aws_encryption_sdk.internal.utils.streams.ROStream(self.source_stream), plaintext_length=plaintext_length, ) self._encryption_materials = self.config.materials_manager.get_encryption_materials( request=encryption_materials_request ) - if ( - self.config.algorithm is not None - and self._encryption_materials.algorithm != self.config.algorithm - ): + if self.config.algorithm is not None and self._encryption_materials.algorithm != self.config.algorithm: raise ActionNotAllowedError( ( "Cryptographic materials manager provided algorithm suite" " differs from algorithm suite in request.\n" "Required: {requested}\n" "Provided: {provided}" - ).format( - requested=self.config.algorithm, - provided=self._encryption_materials.algorithm, - ) + ).format(requested=self.config.algorithm, provided=self._encryption_materials.algorithm) ) if self._encryption_materials.signing_key is None: self.signer = None else: self.signer = Signer.from_key_bytes( - algorithm=self._encryption_materials.algorithm, - key_bytes=self._encryption_materials.signing_key, + algorithm=self._encryption_materials.algorithm, key_bytes=self._encryption_materials.signing_key ) aws_encryption_sdk.internal.utils.validate_frame_length( - frame_length=self.config.frame_length, - algorithm=self._encryption_materials.algorithm, + frame_length=self.config.frame_length, algorithm=self._encryption_materials.algorithm ) self._derived_data_key = derive_data_encryption_key( @@ -609,20 +545,14 @@ def _read_bytes_to_non_framed_body(self, b): self.signer.update(ciphertext) if len(plaintext) < b: - _LOGGER.debug( - "Closing encryptor after receiving only %d bytes of %d bytes requested", - plaintext_length, - b, - ) + _LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b) closing = self.encryptor.finalize() if self.signer is not None: self.signer.update(closing) - closing += serialize_non_framed_close( - tag=self.encryptor.tag, signer=self.signer - ) + closing += serialize_non_framed_close(tag=self.encryptor.tag, signer=self.signer) if self.signer is not None: closing += serialize_footer(self.signer) @@ -644,11 +574,7 @@ def _read_bytes_to_framed_body(self, b): if b > 0: _frames_to_read = math.ceil(b / float(self.config.frame_length)) b = int(_frames_to_read * self.config.frame_length) - _LOGGER.debug( - "%d bytes requested; reading %d bytes after normalizing to frame length", - _b, - b, - ) + _LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", _b, b) plaintext = self.source_stream.read(b) plaintext_length = len(plaintext) @@ -670,9 +596,7 @@ def _read_bytes_to_framed_body(self, b): or (finalize and not final_frame_written) ): current_plaintext_length = len(plaintext) - is_final_frame = ( - finalize and current_plaintext_length < self.config.frame_length - ) + is_final_frame = finalize and current_plaintext_length < self.config.frame_length bytes_in_frame = min(current_plaintext_length, self.config.frame_length) _LOGGER.debug( "Writing %d bytes into%s frame %d", @@ -708,9 +632,7 @@ def _read_bytes(self, b): :param int b: Number of bytes to read :raises NotSupportedError: if content type is not supported """ - _LOGGER.debug( - "%d bytes requested from stream with content type: %s", b, self.content_type - ) + _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type) if 0 <= b <= len(self.output_buffer) or self.__message_complete: _LOGGER.debug("No need to read from source stream or source stream closed") return @@ -731,8 +653,7 @@ def _read_bytes(self, b): if self._bytes_encrypted > self.config.source_length: raise CustomMaximumValueExceeded( "Bytes encrypted has exceeded stated source length estimate:\n{actual:d} > {estimated:d}".format( - actual=self._bytes_encrypted, - estimated=self.config.source_length, + actual=self._bytes_encrypted, estimated=self.config.source_length ) ) @@ -765,17 +686,11 @@ class DecryptorConfig(_ClientConfig): """ max_body_length = attr.ib( - hash=True, - default=None, - validator=attr.validators.optional( - attr.validators.instance_of(six.integer_types) - ), + hash=True, default=None, validator=attr.validators.optional(attr.validators.instance_of(six.integer_types)) ) -class StreamDecryptor( - _EncryptionStream -): # pylint: disable=too-many-instance-attributes +class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-attributes """Provides a streaming encryptor for encrypting a stream source. Behaves as a standard file-like object. @@ -808,9 +723,7 @@ class StreamDecryptor( _config_class = DecryptorConfig - def __init__( - self, **kwargs - ): # pylint: disable=unused-argument,super-init-not-called + def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.last_sequence_number = 0 self.__unframed_bytes_read = 0 @@ -849,35 +762,23 @@ def _read_header(self): algorithm=header.algorithm, encryption_context=header.encryption_context, ) - decryption_materials = self.config.materials_manager.decrypt_materials( - request=decrypt_materials_request - ) + decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request) if decryption_materials.verification_key is None: self.verifier = None else: self.verifier = Verifier.from_key_bytes( - algorithm=header.algorithm, - key_bytes=decryption_materials.verification_key, + algorithm=header.algorithm, key_bytes=decryption_materials.verification_key ) if self.verifier is not None: self.verifier.update(raw_header) header_auth = deserialize_header_auth( - stream=self.source_stream, - algorithm=header.algorithm, - verifier=self.verifier, + stream=self.source_stream, algorithm=header.algorithm, verifier=self.verifier ) self._derived_data_key = derive_data_encryption_key( - source_key=decryption_materials.data_key.data_key, - algorithm=header.algorithm, - message_id=header.message_id, - ) - validate_header( - header=header, - header_auth=header_auth, - raw_header=raw_header, - data_key=self._derived_data_key, + source_key=decryption_materials.data_key.data_key, algorithm=header.algorithm, message_id=header.message_id ) + validate_header(header=header, header_auth=header_auth, raw_header=raw_header, data_key=self._derived_data_key) return header, header_auth def _prep_non_framed(self): @@ -886,10 +787,7 @@ def _prep_non_framed(self): stream=self.source_stream, header=self._header, verifier=self.verifier ) - if ( - self.config.max_body_length is not None - and self.body_length > self.config.max_body_length - ): + if self.config.max_body_length is not None and self.body_length > self.config.max_body_length: raise CustomMaximumValueExceeded( "Non-framed message content length found larger than custom value: {found:d} > {custom:d}".format( found=self.body_length, custom=self.config.max_body_length @@ -916,16 +814,12 @@ def _read_bytes_from_non_framed_body(self, b): ciphertext = self.source_stream.read(bytes_to_read) if len(self.output_buffer) + len(ciphertext) < self.body_length: - raise SerializationError( - "Total message body contents less than specified in body description" - ) + raise SerializationError("Total message body contents less than specified in body description") if self.verifier is not None: self.verifier.update(ciphertext) - tag = deserialize_tag( - stream=self.source_stream, header=self._header, verifier=self.verifier - ) + tag = deserialize_tag(stream=self.source_stream, header=self._header, verifier=self.verifier) aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( content_type=self._header.content_type, is_final_frame=True @@ -947,9 +841,7 @@ def _read_bytes_from_non_framed_body(self, b): plaintext = self.decryptor.update(ciphertext) plaintext += self.decryptor.finalize() - self.footer = deserialize_footer( - stream=self.source_stream, verifier=self.verifier - ) + self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) return plaintext def _read_bytes_from_framed_body(self, b): @@ -972,8 +864,7 @@ def _read_bytes_from_framed_body(self, b): raise SerializationError("Malformed message: frames out of order") self.last_sequence_number += 1 aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( - content_type=self._header.content_type, - is_final_frame=frame_data.final_frame, + content_type=self._header.content_type, is_final_frame=frame_data.final_frame ) associated_data = assemble_content_aad( message_id=self._header.message_id, @@ -991,9 +882,7 @@ def _read_bytes_from_framed_body(self, b): _LOGGER.debug("bytes collected: %d", plaintext_length) if final_frame: _LOGGER.debug("Reading footer") - self.footer = deserialize_footer( - stream=self.source_stream, verifier=self.verifier - ) + self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) return plaintext @@ -1009,11 +898,7 @@ def _read_bytes(self, b): buffer_length = len(self.output_buffer) if 0 <= b <= buffer_length: - _LOGGER.debug( - "%d bytes requested less than or equal to current output buffer size %d", - b, - buffer_length, - ) + _LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length) return if self._header.content_type == ContentType.FRAMED_DATA: diff --git a/src/aws_encryption_sdk/structures.py b/src/aws_encryption_sdk/structures.py index f26ad70ee..577e6fd9c 100644 --- a/src/aws_encryption_sdk/structures.py +++ b/src/aws_encryption_sdk/structures.py @@ -40,41 +40,17 @@ class MessageHeader(object): """ version = attr.ib( - hash=True, - validator=attr.validators.instance_of( - aws_encryption_sdk.identifiers.SerializationVersion - ), - ) - type = attr.ib( - hash=True, - validator=attr.validators.instance_of( - aws_encryption_sdk.identifiers.ObjectType - ), - ) - algorithm = attr.ib( - hash=True, - validator=attr.validators.instance_of( - aws_encryption_sdk.identifiers.AlgorithmSuite - ), + hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.SerializationVersion) ) + type = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.ObjectType)) + algorithm = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.AlgorithmSuite)) message_id = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) encryption_context = attr.ib(hash=True, validator=attr.validators.instance_of(dict)) encrypted_data_keys = attr.ib(hash=True, validator=attr.validators.instance_of(set)) - content_type = attr.ib( - hash=True, - validator=attr.validators.instance_of( - aws_encryption_sdk.identifiers.ContentType - ), - ) - content_aad_length = attr.ib( - hash=True, validator=attr.validators.instance_of(six.integer_types) - ) - header_iv_length = attr.ib( - hash=True, validator=attr.validators.instance_of(six.integer_types) - ) - frame_length = attr.ib( - hash=True, validator=attr.validators.instance_of(six.integer_types) - ) + content_type = attr.ib(hash=True, validator=attr.validators.instance_of(aws_encryption_sdk.identifiers.ContentType)) + content_aad_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) + header_iv_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) + frame_length = attr.ib(hash=True, validator=attr.validators.instance_of(six.integer_types)) @attr.s(hash=True) @@ -85,16 +61,8 @@ class MasterKeyInfo(object): :param bytes key_info: MasterKey key_info value """ - provider_id = attr.ib( - hash=True, - validator=attr.validators.instance_of((six.string_types, bytes)), - converter=to_str, - ) - key_info = attr.ib( - hash=True, - validator=attr.validators.instance_of((six.string_types, bytes)), - converter=to_bytes, - ) + provider_id = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_str) + key_info = attr.ib(hash=True, validator=attr.validators.instance_of((six.string_types, bytes)), converter=to_bytes) @attr.s(hash=True) @@ -106,12 +74,8 @@ class RawDataKey(object): :param bytes data_key: Plaintext data key """ - key_provider = attr.ib( - hash=True, validator=attr.validators.instance_of(MasterKeyInfo) - ) - data_key = attr.ib( - hash=True, repr=False, validator=attr.validators.instance_of(bytes) - ) + key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) + data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes)) @attr.s(hash=True) @@ -124,15 +88,9 @@ class DataKey(object): :param bytes encrypted_data_key: Encrypted data key """ - key_provider = attr.ib( - hash=True, validator=attr.validators.instance_of(MasterKeyInfo) - ) - data_key = attr.ib( - hash=True, repr=False, validator=attr.validators.instance_of(bytes) - ) - encrypted_data_key = attr.ib( - hash=True, validator=attr.validators.instance_of(bytes) - ) + key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) + data_key = attr.ib(hash=True, repr=False, validator=attr.validators.instance_of(bytes)) + encrypted_data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) @attr.s(hash=True) @@ -144,9 +102,5 @@ class EncryptedDataKey(object): :param bytes encrypted_data_key: Encrypted data key """ - key_provider = attr.ib( - hash=True, validator=attr.validators.instance_of(MasterKeyInfo) - ) - encrypted_data_key = attr.ib( - hash=True, validator=attr.validators.instance_of(bytes) - ) + key_provider = attr.ib(hash=True, validator=attr.validators.instance_of(MasterKeyInfo)) + encrypted_data_key = attr.ib(hash=True, validator=attr.validators.instance_of(bytes)) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index b0a367253..6e05ec8ee 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -28,27 +28,15 @@ import aws_encryption_sdk from aws_encryption_sdk import KMSMasterKeyProvider -from aws_encryption_sdk.caches import ( - build_decryption_materials_cache_key, - build_encryption_materials_cache_key, -) +from aws_encryption_sdk.caches import build_decryption_materials_cache_key, build_encryption_materials_cache_key from aws_encryption_sdk.exceptions import CustomMaximumValueExceeded -from aws_encryption_sdk.identifiers import ( - AlgorithmSuite, - EncryptionKeyType, - WrappingAlgorithm, -) +from aws_encryption_sdk.identifiers import AlgorithmSuite, EncryptionKeyType, WrappingAlgorithm from aws_encryption_sdk.internal.crypto.wrapping_keys import WrappingKey from aws_encryption_sdk.internal.defaults import LINE_LENGTH -from aws_encryption_sdk.internal.formatting.encryption_context import ( - serialize_encryption_context, -) +from aws_encryption_sdk.internal.formatting.encryption_context import serialize_encryption_context from aws_encryption_sdk.key_providers.base import MasterKeyProviderConfig from aws_encryption_sdk.key_providers.raw import RawMasterKeyProvider -from aws_encryption_sdk.materials_managers import ( - DecryptionMaterialsRequest, - EncryptionMaterialsRequest, -) +from aws_encryption_sdk.materials_managers import DecryptionMaterialsRequest, EncryptionMaterialsRequest pytestmark = [pytest.mark.functional, pytest.mark.local] @@ -191,9 +179,7 @@ class FakeRawMasterKeyProvider(RawMasterKeyProvider): def _get_raw_key(self, key_id): wrapping_key = VALUES["raw"][key_id][self.config.encryption_key_type] if key_id == b"sym1": - wrapping_key = wrapping_key[ - : self.config.wrapping_algorithm.algorithm.data_key_len - ] + wrapping_key = wrapping_key[: self.config.wrapping_algorithm.algorithm.data_key_len] return WrappingKey( wrapping_algorithm=self.config.wrapping_algorithm, wrapping_key=wrapping_key, @@ -203,16 +189,12 @@ def _get_raw_key(self, key_id): def _mgf1_sha256_supported(): wk = serialization.load_pem_private_key( - data=VALUES["raw"][b"asym1"][EncryptionKeyType.PRIVATE], - password=None, - backend=default_backend(), + data=VALUES["raw"][b"asym1"][EncryptionKeyType.PRIVATE], password=None, backend=default_backend() ) try: wk.public_key().encrypt( plaintext=b"aosdjfoiajfoiaj;foijae;rogijaerg", - padding=padding.OAEP( - mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None - ), + padding=padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None), ) except cryptography.exceptions.UnsupportedAlgorithm: return False @@ -279,12 +261,7 @@ def test_encrypt_load_header(): header_length += 34 header_length += algorithm.iv_len header_length += algorithm.auth_len - header_length += ( - 6 - + 7 - + len(VALUES["arn"]) - + len(VALUES["data_keys"][algorithm.kdf_input_len]["encrypted"]) - ) + header_length += 6 + 7 + len(VALUES["arn"]) + len(VALUES["data_keys"][algorithm.kdf_input_len]["encrypted"]) with aws_encryption_sdk.stream( mode="e", source=VALUES["plaintext_128"], @@ -306,14 +283,11 @@ def test_encrypt_decrypt_header_only(): key_provider=fake_kms_key_provider(), encryption_context=VALUES["encryption_context"], ) - with aws_encryption_sdk.stream( - mode="d", source=ciphertext, key_provider=fake_kms_key_provider() - ) as decryptor: + with aws_encryption_sdk.stream(mode="d", source=ciphertext, key_provider=fake_kms_key_provider()) as decryptor: decryptor_header = decryptor.header assert decryptor.output_buffer == b"" assert all( - pair in decryptor_header.encryption_context.items() - for pair in encryptor_header.encryption_context.items() + pair in decryptor_header.encryption_context.items() for pair in encryptor_header.encryption_context.items() ) @@ -344,53 +318,25 @@ def test_encrypt_ciphertext_message(frame_length, algorithm, encryption_context) @pytest.mark.parametrize( "wrapping_algorithm, encryption_key_type, decryption_key_type", ( - ( - WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING, - EncryptionKeyType.SYMMETRIC, - EncryptionKeyType.SYMMETRIC, - ), - ( - WrappingAlgorithm.RSA_PKCS1, - EncryptionKeyType.PRIVATE, - EncryptionKeyType.PRIVATE, - ), - ( - WrappingAlgorithm.RSA_PKCS1, - EncryptionKeyType.PUBLIC, - EncryptionKeyType.PRIVATE, - ), - ( - WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, - EncryptionKeyType.PRIVATE, - EncryptionKeyType.PRIVATE, - ), - ( - WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, - EncryptionKeyType.PUBLIC, - EncryptionKeyType.PRIVATE, - ), + (WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING, EncryptionKeyType.SYMMETRIC, EncryptionKeyType.SYMMETRIC), + (WrappingAlgorithm.RSA_PKCS1, EncryptionKeyType.PRIVATE, EncryptionKeyType.PRIVATE), + (WrappingAlgorithm.RSA_PKCS1, EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE), + (WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, EncryptionKeyType.PRIVATE, EncryptionKeyType.PRIVATE), + (WrappingAlgorithm.RSA_OAEP_SHA1_MGF1, EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE), ), ) -def test_encryption_cycle_raw_mkp( - caplog, wrapping_algorithm, encryption_key_type, decryption_key_type -): +def test_encryption_cycle_raw_mkp(caplog, wrapping_algorithm, encryption_key_type, decryption_key_type): caplog.set_level(logging.DEBUG) - encrypting_key_provider = build_fake_raw_key_provider( - wrapping_algorithm, encryption_key_type - ) - decrypting_key_provider = build_fake_raw_key_provider( - wrapping_algorithm, decryption_key_type - ) + encrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, encryption_key_type) + decrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, decryption_key_type) ciphertext, _ = aws_encryption_sdk.encrypt( source=VALUES["plaintext_128"], key_provider=encrypting_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=decrypting_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=decrypting_key_provider) assert plaintext == VALUES["plaintext_128"] for member in encrypting_key_provider._members: @@ -398,8 +344,7 @@ def test_encryption_cycle_raw_mkp( @pytest.mark.skipif( - not _mgf1_sha256_supported(), - reason="MGF1-SHA2 not supported by this backend: OpenSSL required v1.0.2+", + not _mgf1_sha256_supported(), reason="MGF1-SHA2 not supported by this backend: OpenSSL required v1.0.2+" ) @pytest.mark.parametrize( "wrapping_algorithm", @@ -409,28 +354,18 @@ def test_encryption_cycle_raw_mkp( WrappingAlgorithm.RSA_OAEP_SHA512_MGF1, ), ) -@pytest.mark.parametrize( - "encryption_key_type", (EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE) -) -def test_encryption_cycle_raw_mkp_openssl_102_plus( - wrapping_algorithm, encryption_key_type -): +@pytest.mark.parametrize("encryption_key_type", (EncryptionKeyType.PUBLIC, EncryptionKeyType.PRIVATE)) +def test_encryption_cycle_raw_mkp_openssl_102_plus(wrapping_algorithm, encryption_key_type): decryption_key_type = EncryptionKeyType.PRIVATE - encrypting_key_provider = build_fake_raw_key_provider( - wrapping_algorithm, encryption_key_type - ) - decrypting_key_provider = build_fake_raw_key_provider( - wrapping_algorithm, decryption_key_type - ) + encrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, encryption_key_type) + decrypting_key_provider = build_fake_raw_key_provider(wrapping_algorithm, decryption_key_type) ciphertext, _ = aws_encryption_sdk.encrypt( source=VALUES["plaintext_128"], key_provider=encrypting_key_provider, encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=decrypting_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=decrypting_key_provider) assert plaintext == VALUES["plaintext_128"] @@ -454,9 +389,7 @@ def test_encryption_cycle_oneshot_kms(frame_length, algorithm, encryption_contex encryption_context=encryption_context, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert plaintext == VALUES["plaintext_128"] * 10 @@ -487,9 +420,7 @@ def test_encryption_cycle_stream_kms(frame_length, algorithm, encryption_context ciphertext = bytes(ciphertext) plaintext = bytearray() - with aws_encryption_sdk.stream( - mode="d", source=io.BytesIO(ciphertext), key_provider=key_provider - ) as decryptor: + with aws_encryption_sdk.stream(mode="d", source=io.BytesIO(ciphertext), key_provider=key_provider) as decryptor: for chunk in decryptor: plaintext.extend(chunk) plaintext = bytes(plaintext) @@ -502,9 +433,7 @@ def test_encryption_cycle_stream_kms(frame_length, algorithm, encryption_context def test_decrypt_legacy_provided_message(): """Tests backwards compatiblity against some legacy provided ciphertext.""" region = "us-west-2" - key_info = ( - "arn:aws:kms:us-west-2:249645522726:key/d1720f4e-953b-44bb-b9dd-fc8b9d0baa5f" - ) + key_info = "arn:aws:kms:us-west-2:249645522726:key/d1720f4e-953b-44bb-b9dd-fc8b9d0baa5f" mock_kms_client = fake_kms_client() mock_kms_client.decrypt.return_value = {"Plaintext": VALUES["provided"]["key"]} mock_kms_key_provider = fake_kms_key_provider() @@ -522,10 +451,7 @@ def test_encryption_cycle_with_caching(): key_provider = fake_kms_key_provider(algorithm.kdf_input_len) cache = aws_encryption_sdk.LocalCryptoMaterialsCache(capacity=10) ccmm = aws_encryption_sdk.CachingCryptoMaterialsManager( - master_key_provider=key_provider, - cache=cache, - max_age=3600.0, - max_messages_encrypted=5, + master_key_provider=key_provider, cache=cache, max_age=3600.0, max_messages_encrypted=5 ) encrypt_kwargs = dict( source=VALUES["plaintext_128"], @@ -585,9 +511,7 @@ def test_encrypt_source_length_enforcement(): plaintext = io.BytesIO(VALUES["plaintext_128"]) with pytest.raises(CustomMaximumValueExceeded) as excinfo: aws_encryption_sdk.encrypt( - source=plaintext, - materials_manager=cmm, - source_length=int(len(VALUES["plaintext_128"]) / 2), + source=plaintext, materials_manager=cmm, source_length=int(len(VALUES["plaintext_128"]) / 2) ) excinfo.match(r"Bytes encrypted has exceeded stated source length estimate:*") @@ -600,9 +524,7 @@ def test_encrypt_source_length_enforcement_legacy_support(): # provider is provided. key_provider = fake_kms_key_provider() aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], - key_provider=key_provider, - source_length=int(len(VALUES["plaintext_128"]) / 2), + source=VALUES["plaintext_128"], key_provider=key_provider, source_length=int(len(VALUES["plaintext_128"]) / 2) ) @@ -625,15 +547,11 @@ def test_stream_encryptor_no_seek_input(): plaintext = NoSeekBytesIO(VALUES["plaintext_128"]) ciphertext = io.BytesIO() with aws_encryption_sdk.StreamEncryptor( - source=plaintext, - key_provider=key_provider, - encryption_context=VALUES["encryption_context"], + source=plaintext, key_provider=key_provider, encryption_context=VALUES["encryption_context"] ) as encryptor: for chunk in encryptor: ciphertext.write(chunk) - decrypted, _header = aws_encryption_sdk.decrypt( - source=ciphertext.getvalue(), key_provider=key_provider - ) + decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext.getvalue(), key_provider=key_provider) assert decrypted == VALUES["plaintext_128"] @@ -641,15 +559,11 @@ def test_stream_decryptor_no_seek_input(): """Test that StreamDecryptor can handle an input stream that is not seekable.""" key_provider = fake_kms_key_provider() ciphertext, _header = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], - key_provider=key_provider, - encryption_context=VALUES["encryption_context"], + source=VALUES["plaintext_128"], key_provider=key_provider, encryption_context=VALUES["encryption_context"] ) ciphertext_no_seek = NoSeekBytesIO(ciphertext) decrypted = io.BytesIO() - with aws_encryption_sdk.StreamDecryptor( - source=ciphertext_no_seek, key_provider=key_provider - ) as decryptor: + with aws_encryption_sdk.StreamDecryptor(source=ciphertext_no_seek, key_provider=key_provider) as decryptor: for chunk in decryptor: decrypted.write(chunk) assert decrypted.getvalue() == VALUES["plaintext_128"] @@ -660,13 +574,9 @@ def test_encrypt_oneshot_no_seek_input(): key_provider = fake_kms_key_provider() plaintext = NoSeekBytesIO(VALUES["plaintext_128"]) ciphertext, _header = aws_encryption_sdk.encrypt( - source=plaintext, - key_provider=key_provider, - encryption_context=VALUES["encryption_context"], - ) - decrypted, _header = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider + source=plaintext, key_provider=key_provider, encryption_context=VALUES["encryption_context"] ) + decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert decrypted == VALUES["plaintext_128"] @@ -674,14 +584,10 @@ def test_decrypt_oneshot_no_seek_input(): """Test that decrypt can handle an input stream that is not seekable.""" key_provider = fake_kms_key_provider() ciphertext, _header = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], - key_provider=key_provider, - encryption_context=VALUES["encryption_context"], + source=VALUES["plaintext_128"], key_provider=key_provider, encryption_context=VALUES["encryption_context"] ) ciphertext_no_seek = NoSeekBytesIO(ciphertext) - decrypted, _header = aws_encryption_sdk.decrypt( - source=ciphertext_no_seek, key_provider=key_provider - ) + decrypted, _header = aws_encryption_sdk.decrypt(source=ciphertext_no_seek, key_provider=key_provider) assert decrypted == VALUES["plaintext_128"] @@ -689,9 +595,7 @@ def test_stream_encryptor_readable(): """Verify that open StreamEncryptor instances report as readable.""" key_provider = fake_kms_key_provider() plaintext = io.BytesIO(VALUES["plaintext_128"]) - with aws_encryption_sdk.StreamEncryptor( - source=plaintext, key_provider=key_provider - ) as handler: + with aws_encryption_sdk.StreamEncryptor(source=plaintext, key_provider=key_provider) as handler: assert handler.readable() handler.read() assert not handler.readable() @@ -701,12 +605,8 @@ def test_stream_decryptor_readable(): """Verify that open StreamEncryptor instances report as readable.""" key_provider = fake_kms_key_provider() plaintext = io.BytesIO(VALUES["plaintext_128"]) - ciphertext, _header = aws_encryption_sdk.encrypt( - source=plaintext, key_provider=key_provider - ) - with aws_encryption_sdk.StreamDecryptor( - source=ciphertext, key_provider=key_provider - ) as handler: + ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider) + with aws_encryption_sdk.StreamDecryptor(source=ciphertext, key_provider=key_provider) as handler: assert handler.readable() handler.read() assert not handler.readable() @@ -767,9 +667,7 @@ def test_incomplete_read_stream_cycle(frame_length): decrypted = b"" cycle_count = 0 with aws_encryption_sdk.stream( - mode="decrypt", - source=SometimesIncompleteReaderIO(ciphertext), - key_provider=key_provider, + mode="decrypt", source=SometimesIncompleteReaderIO(ciphertext), key_provider=key_provider ) as decryptor: while True: cycle_count += 1 @@ -817,12 +715,9 @@ def _error_check(capsys_instance): assert "Call stack:" not in stderr +@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2)) @pytest.mark.parametrize( - "frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2) -) -@pytest.mark.parametrize( - "plaintext_length", - (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2), + "plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2) ) def test_plaintext_logs_oneshot(caplog, capsys, plaintext_length, frame_size): plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length) @@ -835,22 +730,16 @@ def test_plaintext_logs_oneshot(caplog, capsys, plaintext_length, frame_size): _error_check(capsys) +@pytest.mark.parametrize("frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2)) @pytest.mark.parametrize( - "frame_size", (0, LINE_LENGTH // 2, LINE_LENGTH, LINE_LENGTH * 2) -) -@pytest.mark.parametrize( - "plaintext_length", - (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2), + "plaintext_length", (1, LINE_LENGTH // 2, LINE_LENGTH, int(LINE_LENGTH * 1.5), LINE_LENGTH * 2) ) def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size): plaintext, key_provider = _prep_plaintext_and_logs(caplog, plaintext_length) ciphertext = b"" with aws_encryption_sdk.stream( - mode="encrypt", - source=plaintext, - key_provider=key_provider, - frame_length=frame_size, + mode="encrypt", source=plaintext, key_provider=key_provider, frame_length=frame_size ) as encryptor: for line in encryptor: ciphertext += line @@ -891,9 +780,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): source=plaintext, key_provider=key_provider, frame_length=frame_length ) ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider - ) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert raw_plaintext == decrypted @@ -906,9 +793,7 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( source=plaintext, key_provider=key_provider, frame_length=frame_length ) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider - ) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert raw_plaintext == decrypted @@ -921,15 +806,11 @@ def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class): source=plaintext, key_provider=key_provider, frame_length=frame_length ) ciphertext = wrapping_class(io.BytesIO(raw_ciphertext)) - decrypted, _decrypt_header = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider - ) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert plaintext == decrypted -def _assert_deprecated_but_not_yet_removed( - logcap, instance, attribute_name, error_message, no_later_than -): +def _assert_deprecated_but_not_yet_removed(logcap, instance, attribute_name, error_message, no_later_than): assert hasattr(instance, attribute_name) assert error_message in logcap.text assert aws_encryption_sdk.__version__ < no_later_than @@ -940,19 +821,13 @@ def _assert_decrypted_and_removed(instance, attribute_name, removed_in): assert aws_encryption_sdk.__version__ >= removed_in -@pytest.mark.parametrize( - "attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0")) -) +@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0"))) def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): caplog.set_level(logging.WARNING) plaintext = exact_length_plaintext(100) key_provider = fake_kms_key_provider() - ciphertext, _header = aws_encryption_sdk.encrypt( - source=plaintext, key_provider=key_provider, frame_length=0 - ) - with aws_encryption_sdk.stream( - mode="decrypt", source=ciphertext, key_provider=key_provider - ) as decryptor: + ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0) + with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor: decrypted = decryptor.read() assert decrypted == plaintext @@ -967,6 +842,4 @@ def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): no_later_than=no_later_than, ) else: - _assert_decrypted_and_removed( - instance=decryptor, attribute_name=attribute, removed_in=no_later_than - ) + _assert_decrypted_and_removed(instance=decryptor, attribute_name=attribute, removed_in=no_later_than) diff --git a/test/functional/test_f_crypto.py b/test/functional/test_f_crypto.py index 3c90b957d..9d71c040a 100644 --- a/test/functional/test_f_crypto.py +++ b/test/functional/test_f_crypto.py @@ -18,9 +18,7 @@ import aws_encryption_sdk from aws_encryption_sdk.internal.crypto.authentication import Signer -from aws_encryption_sdk.internal.crypto.elliptic_curve import ( - _ecc_static_length_signature, -) +from aws_encryption_sdk.internal.crypto.elliptic_curve import _ecc_static_length_signature pytestmark = [pytest.mark.functional, pytest.mark.local] @@ -28,46 +26,27 @@ # Run several of each type to make get a high probability of forcing signature length correction @pytest.mark.parametrize( "algorithm", - [ - aws_encryption_sdk.AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 - for i in range(10) - ] - + [ - aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 - for i in range(10) - ], + [aws_encryption_sdk.AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256 for i in range(10)] + + [aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384 for i in range(10)], ) def test_ecc_static_length_signature(algorithm): - private_key = ec.generate_private_key( - curve=algorithm.signing_algorithm_info(), backend=default_backend() - ) + private_key = ec.generate_private_key(curve=algorithm.signing_algorithm_info(), backend=default_backend()) hasher = hashes.Hash(algorithm.signing_hash_type(), backend=default_backend()) data = b"aifuhaw9fe48haw9e8cnavwp9e8fhaw9438fnhjzsudfvhnsa89w74fhp90se8rhgfi" hasher.update(data) digest = hasher.finalize() - signature = _ecc_static_length_signature( - key=private_key, algorithm=algorithm, digest=digest - ) + signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=digest) assert len(signature) == algorithm.signature_len private_key.public_key().verify( - signature=signature, - data=data, - signature_algorithm=ec.ECDSA(algorithm.signing_hash_type()), + signature=signature, data=data, signature_algorithm=ec.ECDSA(algorithm.signing_hash_type()) ) def test_signer_key_bytes_cycle(): key = ec.generate_private_key(curve=ec.SECP384R1, backend=default_backend()) - signer = Signer( - algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - key=key, - ) + signer = Signer(algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, key=key) key_bytes = signer.key_bytes() new_signer = Signer.from_key_bytes( - algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - key_bytes=key_bytes, - ) - assert ( - new_signer.key.private_numbers().private_value - == signer.key.private_numbers().private_value + algorithm=aws_encryption_sdk.AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, key_bytes=key_bytes ) + assert new_signer.key.private_numbers().private_value == signer.key.private_numbers().private_value diff --git a/test/functional/test_f_xcompat.py b/test/functional/test_f_xcompat.py index 231fe5d26..d97835b4c 100644 --- a/test/functional/test_f_xcompat.py +++ b/test/functional/test_f_xcompat.py @@ -63,10 +63,7 @@ def _file_root(): } ), } -_KEY_TYPES_MAP = { - b"AES": EncryptionKeyType.SYMMETRIC, - b"RSA": EncryptionKeyType.PRIVATE, -} +_KEY_TYPES_MAP = {b"AES": EncryptionKeyType.SYMMETRIC, b"RSA": EncryptionKeyType.PRIVATE} _STATIC_KEYS = defaultdict(dict) @@ -78,19 +75,13 @@ class StaticStoredMasterKeyProvider(RawMasterKeyProvider): def _get_raw_key(self, key_id): """Finds a loaded raw key.""" try: - algorithm, key_bits, padding_algorithm, padding_hash = key_id.upper().split( - b".", 3 - ) + algorithm, key_bits, padding_algorithm, padding_hash = key_id.upper().split(b".", 3) key_bits = int(key_bits) key_type = _KEY_TYPES_MAP[algorithm] - wrapping_algorithm = _WRAPPING_ALGORITHM_MAP[algorithm][key_bits][ - padding_algorithm - ][padding_hash] + wrapping_algorithm = _WRAPPING_ALGORITHM_MAP[algorithm][key_bits][padding_algorithm][padding_hash] static_key = _STATIC_KEYS[algorithm][key_bits] return WrappingKey( - wrapping_algorithm=wrapping_algorithm, - wrapping_key=static_key, - wrapping_key_type=key_type, + wrapping_algorithm=wrapping_algorithm, wrapping_key=static_key, wrapping_key_type=key_type ) except KeyError: _LOGGER.exception("Unknown Key ID: %s", key_id) @@ -101,9 +92,7 @@ def _get_raw_key(self, key_id): class RawKeyDescription(object): """Customer raw key descriptor used by StaticStoredMasterKeyProvider.""" - encryption_algorithm = attr.ib( - validator=attr.validators.instance_of(six.string_types) - ) + encryption_algorithm = attr.ib(validator=attr.validators.instance_of(six.string_types)) key_bits = attr.ib(validator=attr.validators.instance_of(int)) padding_algorithm = attr.ib(validator=attr.validators.instance_of(six.string_types)) padding_hash = attr.ib(validator=attr.validators.instance_of(six.string_types)) @@ -111,26 +100,15 @@ class RawKeyDescription(object): @property def key_id(self): """Build a key ID from instance parameters.""" - return ".".join( - [ - self.encryption_algorithm, - str(self.key_bits), - self.padding_algorithm, - self.padding_hash, - ] - ) + return ".".join([self.encryption_algorithm, str(self.key_bits), self.padding_algorithm, self.padding_hash]) @attr.s class Scenario(object): """Scenario details.""" - plaintext_filename = attr.ib( - validator=attr.validators.instance_of(six.string_types) - ) - ciphertext_filename = attr.ib( - validator=attr.validators.instance_of(six.string_types) - ) + plaintext_filename = attr.ib(validator=attr.validators.instance_of(six.string_types)) + ciphertext_filename = attr.ib(validator=attr.validators.instance_of(six.string_types)) key_ids = attr.ib(validator=attr.validators.instance_of(list)) @@ -142,9 +120,7 @@ def _generate_test_cases(): # noqa=C901 if not os.path.isdir(root_dir): root_dir = os.getcwd() base_dir = os.path.join(root_dir, "aws_encryption_sdk_resources") - ciphertext_manifest_path = os.path.join( - base_dir, "manifests", "ciphertext.manifest" - ) + ciphertext_manifest_path = os.path.join(base_dir, "manifests", "ciphertext.manifest") if not os.path.isfile(ciphertext_manifest_path): # Make no test cases if the ciphertext file is not found @@ -171,9 +147,7 @@ def _generate_test_cases(): # noqa=C901 # Collect test cases from ciphertext manifest for test_case in ciphertext_manifest["test_cases"]: key_ids = [] - algorithm = aws_encryption_sdk.AlgorithmSuite.get_by_id( - int(test_case["algorithm"], 16) - ) + algorithm = aws_encryption_sdk.AlgorithmSuite.get_by_id(int(test_case["algorithm"], 16)) for key in test_case["master_keys"]: sys.stderr.write("XC:: " + json.dumps(key) + "\n") if key["provider_id"] == StaticStoredMasterKeyProvider.provider_id: @@ -205,7 +179,5 @@ def test_decrypt_from_file(scenario): plaintext = infile.read() key_provider = StaticStoredMasterKeyProvider() key_provider.add_master_keys_from_list(scenario.key_ids) - decrypted_ciphertext, _header = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=key_provider - ) + decrypted_ciphertext, _header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert decrypted_ciphertext == plaintext diff --git a/test/integration/test_i_aws_encrytion_sdk_client.py b/test/integration/test_i_aws_encrytion_sdk_client.py index 2247c3041..2f880c00e 100644 --- a/test/integration/test_i_aws_encrytion_sdk_client.py +++ b/test/integration/test_i_aws_encrytion_sdk_client.py @@ -44,10 +44,7 @@ def test_encrypt_verify_user_agent_kms_master_key_provider(caplog): mkp = setup_kms_master_key_provider() mk = mkp.master_key(get_cmk_arn()) - mk.generate_data_key( - algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - encryption_context={}, - ) + mk.generate_data_key(algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, encryption_context={}) assert USER_AGENT_SUFFIX in caplog.text @@ -56,10 +53,7 @@ def test_encrypt_verify_user_agent_kms_master_key(caplog): caplog.set_level(level=logging.DEBUG) mk = KMSMasterKey(key_id=get_cmk_arn()) - mk.generate_data_key( - algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, - encryption_context={}, - ) + mk.generate_data_key(algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, encryption_context={}) assert USER_AGENT_SUFFIX in caplog.text @@ -99,9 +93,7 @@ def test_encryption_cycle_default_algorithm_framed_stream(self): ciphertext = encryptor.read() header_1 = encryptor.header with aws_encryption_sdk.stream( - source=io.BytesIO(ciphertext), - key_provider=self.kms_master_key_provider, - mode="d", + source=io.BytesIO(ciphertext), key_provider=self.kms_master_key_provider, mode="d" ) as decryptor: plaintext = decryptor.read() header_2 = decryptor.header @@ -125,9 +117,7 @@ def test_encryption_cycle_default_algorithm_framed_stream_many_lines(self): header_1 = encryptor.header plaintext = b"" with aws_encryption_sdk.stream( - source=io.BytesIO(ciphertext), - key_provider=self.kms_master_key_provider, - mode="d", + source=io.BytesIO(ciphertext), key_provider=self.kms_master_key_provider, mode="d" ) as decryptor: for chunk in decryptor: plaintext += chunk @@ -145,9 +135,7 @@ def test_encryption_cycle_default_algorithm_non_framed(self): encryption_context=VALUES["encryption_context"], frame_length=0, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_non_framed_no_encryption_context(self): @@ -155,13 +143,9 @@ def test_encryption_cycle_default_algorithm_non_framed_no_encryption_context(sel for a non-framed message using the default algorithm. """ ciphertext, _ = aws_encryption_sdk.encrypt( - source=VALUES["plaintext_128"], - key_provider=self.kms_master_key_provider, - frame_length=0, - ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider + source=VALUES["plaintext_128"], key_provider=self.kms_master_key_provider, frame_length=0 ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_single_frame(self): @@ -174,9 +158,7 @@ def test_encryption_cycle_default_algorithm_single_frame(self): encryption_context=VALUES["encryption_context"], frame_length=1024, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_default_algorithm_multiple_frames(self): @@ -190,9 +172,7 @@ def test_encryption_cycle_default_algorithm_multiple_frames(self): encryption_context=VALUES["encryption_context"], frame_length=1024, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] * 100 def test_encryption_cycle_aes_128_gcm_iv12_tag16_single_frame(self): @@ -207,9 +187,7 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_non_framed(self): @@ -224,9 +202,7 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_single_frame(self): @@ -241,9 +217,7 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_non_framed(self): @@ -258,9 +232,7 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_single_frame(self): @@ -275,9 +247,7 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_non_framed(self): @@ -292,9 +262,7 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -309,9 +277,7 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -326,9 +292,7 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -343,9 +307,7 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -360,9 +322,7 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha256_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_single_frame(self): @@ -377,9 +337,7 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_single_frame(self): frame_length=1024, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_non_framed(self): @@ -394,14 +352,10 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha256_non_framed(self): frame_length=0, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_frame( - self - ): + def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_frame(self): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256 algorithm. @@ -413,14 +367,10 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_single_f frame_length=1024, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_framed( - self - ): + def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_framed(self): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256 algorithm. @@ -432,14 +382,10 @@ def test_encryption_cycle_aes_128_gcm_iv12_tag16_hkdf_sha256_ecdsa_p256_non_fram frame_length=0, algorithm=AlgorithmSuite.AES_128_GCM_IV12_TAG16_HKDF_SHA256_ECDSA_P256, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame( - self - ): + def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame(self): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -451,14 +397,10 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_f frame_length=1024, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed( - self - ): + def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed(self): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -470,14 +412,10 @@ def test_encryption_cycle_aes_192_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_fram frame_length=0, algorithm=AlgorithmSuite.AES_192_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame( - self - ): + def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_frame(self): """Test that the enrypt/decrypt cycle completes successfully for a single frame message using the aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -489,14 +427,10 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_single_f frame_length=1024, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] - def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed( - self - ): + def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_framed(self): """Test that the enrypt/decrypt cycle completes successfully for a single block message using the aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384 algorithm. @@ -508,7 +442,5 @@ def test_encryption_cycle_aes_256_gcm_iv12_tag16_hkdf_sha384_ecdsa_p384_non_fram frame_length=0, algorithm=AlgorithmSuite.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, ) - plaintext, _ = aws_encryption_sdk.decrypt( - source=ciphertext, key_provider=self.kms_master_key_provider - ) + plaintext, _ = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=self.kms_master_key_provider) assert plaintext == VALUES["plaintext_128"] diff --git a/test/integration/test_i_xcompat_kms.py b/test/integration/test_i_xcompat_kms.py index aef83716d..b6f22a3ff 100644 --- a/test/integration/test_i_xcompat_kms.py +++ b/test/integration/test_i_xcompat_kms.py @@ -42,9 +42,7 @@ def _generate_test_cases(): if not os.path.isdir(root_dir): root_dir = os.getcwd() base_dir = os.path.join(root_dir, "aws_encryption_sdk_resources") - ciphertext_manifest_path = os.path.join( - base_dir, "manifests", "ciphertext.manifest" - ) + ciphertext_manifest_path = os.path.join(base_dir, "manifests", "ciphertext.manifest") if not os.path.isfile(ciphertext_manifest_path): # Make no test cases if the ciphertext file is not found @@ -68,9 +66,7 @@ def _generate_test_cases(): return _test_cases -@pytest.mark.parametrize( - "plaintext_filename, ciphertext_filename", _generate_test_cases() -) +@pytest.mark.parametrize("plaintext_filename, ciphertext_filename", _generate_test_cases()) def test_decrypt_from_file(plaintext_filename, ciphertext_filename): """Tests decrypt from known good files.""" with open(ciphertext_filename, "rb") as infile: diff --git a/test/unit/test_crypto_elliptic_curve.py b/test/unit/test_crypto_elliptic_curve.py index fadf9c35d..b030db5c2 100644 --- a/test/unit/test_crypto_elliptic_curve.py +++ b/test/unit/test_crypto_elliptic_curve.py @@ -38,9 +38,7 @@ @pytest.yield_fixture def patch_default_backend(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, "default_backend" - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "default_backend") yield aws_encryption_sdk.internal.crypto.elliptic_curve.default_backend @@ -58,42 +56,31 @@ def patch_pow(mocker): @pytest.yield_fixture def patch_encode_dss_signature(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, "encode_dss_signature" - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "encode_dss_signature") yield aws_encryption_sdk.internal.crypto.elliptic_curve.encode_dss_signature @pytest.yield_fixture def patch_decode_dss_signature(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, "decode_dss_signature" - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "decode_dss_signature") yield aws_encryption_sdk.internal.crypto.elliptic_curve.decode_dss_signature @pytest.yield_fixture def patch_ecc_decode_compressed_point(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, - "_ecc_decode_compressed_point", - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "_ecc_decode_compressed_point") yield aws_encryption_sdk.internal.crypto.elliptic_curve._ecc_decode_compressed_point @pytest.yield_fixture def patch_verify_interface(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, "verify_interface" - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "verify_interface") yield aws_encryption_sdk.internal.crypto.elliptic_curve.verify_interface @pytest.yield_fixture def patch_ecc_curve_parameters(mocker): - mocker.patch.object( - aws_encryption_sdk.internal.crypto.elliptic_curve, "_ECC_CURVE_PARAMETERS" - ) + mocker.patch.object(aws_encryption_sdk.internal.crypto.elliptic_curve, "_ECC_CURVE_PARAMETERS") yield aws_encryption_sdk.internal.crypto.elliptic_curve._ECC_CURVE_PARAMETERS @@ -115,45 +102,9 @@ def test_ecc_curve_not_in_cryptography(): def test_ecc_curve_parameters_secp256r1(): """Verify values from http://www.secg.org/sec2-v2.pdf""" p = pow(2, 224) * (pow(2, 32) - 1) + pow(2, 192) + pow(2, 96) - 1 - a = int( - ( - "FFFFFFFF" - "00000001" - "00000000" - "00000000" - "00000000" - "FFFFFFFF" - "FFFFFFFF" - "FFFFFFFC" - ), - 16, - ) - b = int( - ( - "5AC635D8" - "AA3A93E7" - "B3EBBD55" - "769886BC" - "651D06B0" - "CC53B0F6" - "3BCE3C3E" - "27D2604B" - ), - 16, - ) - order = int( - ( - "FFFFFFFF" - "00000000" - "FFFFFFFF" - "FFFFFFFF" - "BCE6FAAD" - "A7179E84" - "F3B9CAC2" - "FC632551" - ), - 16, - ) + a = int(("FFFFFFFF" "00000001" "00000000" "00000000" "00000000" "FFFFFFFF" "FFFFFFFF" "FFFFFFFC"), 16) + b = int(("5AC635D8" "AA3A93E7" "B3EBBD55" "769886BC" "651D06B0" "CC53B0F6" "3BCE3C3E" "27D2604B"), 16) + order = int(("FFFFFFFF" "00000000" "FFFFFFFF" "FFFFFFFF" "BCE6FAAD" "A7179E84" "F3B9CAC2" "FC632551"), 16) assert _ECC_CURVE_PARAMETERS["secp256r1"].p == p assert _ECC_CURVE_PARAMETERS["secp256r1"].a == a assert _ECC_CURVE_PARAMETERS["secp256r1"].b == b @@ -296,34 +247,22 @@ def test_ecc_curve_parameters_secp521r1(): def test_ecc_static_length_signature_first_try( - patch_default_backend, - patch_ec, - patch_encode_dss_signature, - patch_decode_dss_signature, - patch_prehashed, + patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed ): algorithm = MagicMock(signature_len=55) private_key = MagicMock() private_key.sign.return_value = b"a" * 55 - test_signature = _ecc_static_length_signature( - key=private_key, algorithm=algorithm, digest=sentinel.digest - ) + test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) patch_prehashed.assert_called_once_with(algorithm.signing_hash_type.return_value) patch_ec.ECDSA.assert_called_once_with(patch_prehashed.return_value) - private_key.sign.assert_called_once_with( - sentinel.digest, patch_ec.ECDSA.return_value - ) + private_key.sign.assert_called_once_with(sentinel.digest, patch_ec.ECDSA.return_value) assert not patch_encode_dss_signature.called assert not patch_decode_dss_signature.called assert test_signature is private_key.sign.return_value def test_ecc_static_length_signature_single_negation( - patch_default_backend, - patch_ec, - patch_encode_dss_signature, - patch_decode_dss_signature, - patch_prehashed, + patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed ): algorithm = MagicMock(signature_len=55) algorithm.signing_algorithm_info.name = "secp256r1" @@ -331,23 +270,15 @@ def test_ecc_static_length_signature_single_negation( private_key.sign.return_value = b"a" patch_decode_dss_signature.return_value = sentinel.r, 100 patch_encode_dss_signature.return_value = "a" * 55 - test_signature = _ecc_static_length_signature( - key=private_key, algorithm=algorithm, digest=sentinel.digest - ) + test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) assert len(private_key.sign.mock_calls) == 1 patch_decode_dss_signature.assert_called_once_with(b"a") - patch_encode_dss_signature.assert_called_once_with( - sentinel.r, _ECC_CURVE_PARAMETERS["secp256r1"].order - 100 - ) + patch_encode_dss_signature.assert_called_once_with(sentinel.r, _ECC_CURVE_PARAMETERS["secp256r1"].order - 100) assert test_signature is patch_encode_dss_signature.return_value def test_ecc_static_length_signature_recalculate( - patch_default_backend, - patch_ec, - patch_encode_dss_signature, - patch_decode_dss_signature, - patch_prehashed, + patch_default_backend, patch_ec, patch_encode_dss_signature, patch_decode_dss_signature, patch_prehashed ): algorithm = MagicMock(signature_len=55) algorithm.signing_algorithm_info.name = "secp256r1" @@ -355,9 +286,7 @@ def test_ecc_static_length_signature_recalculate( private_key.sign.side_effect = (b"a", b"b" * 55) patch_decode_dss_signature.return_value = sentinel.r, 100 patch_encode_dss_signature.return_value = "a" * 100 - test_signature = _ecc_static_length_signature( - key=private_key, algorithm=algorithm, digest=sentinel.digest - ) + test_signature = _ecc_static_length_signature(key=private_key, algorithm=algorithm, digest=sentinel.digest) assert len(private_key.sign.mock_calls) == 2 assert len(patch_decode_dss_signature.mock_calls) == 1 assert len(patch_encode_dss_signature.mock_calls) == 1 @@ -365,9 +294,7 @@ def test_ecc_static_length_signature_recalculate( def test_ecc_encode_compressed_point_prime(): - compressed_point = _ecc_encode_compressed_point( - private_key=VALUES["ecc_private_key_prime"] - ) + compressed_point = _ecc_encode_compressed_point(private_key=VALUES["ecc_private_key_prime"]) assert compressed_point == VALUES["ecc_compressed_point"] @@ -386,110 +313,82 @@ def test_ecc_decode_compressed_point_infinity(): def test_ecc_decode_compressed_point_prime(): - x, y = _ecc_decode_compressed_point( - curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] - ) + x, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) numbers = VALUES["ecc_private_key_prime"].public_key().public_numbers() assert x == numbers.x assert y == numbers.y @pytest.mark.skipif( - sys.version_info.major == 3 and sys.version_info.minor == 4, - reason='Patching builtin "pow" fails in Python3.4', + sys.version_info.major == 3 and sys.version_info.minor == 4, reason='Patching builtin "pow" fails in Python3.4' ) def test_ecc_decode_compressed_point_prime_characteristic_two(patch_pow): patch_pow.return_value = 1 - _, y = _ecc_decode_compressed_point( - curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] - ) + _, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) assert y == 1 @pytest.mark.skipif( - sys.version_info.major == 3 and sys.version_info.minor == 4, - reason='Patching builtin "pow" fails in Python3.4', + sys.version_info.major == 3 and sys.version_info.minor == 4, reason='Patching builtin "pow" fails in Python3.4' ) def test_ecc_decode_compressed_point_prime_not_characteristic_two(patch_pow): patch_pow.return_value = 0 - _, y = _ecc_decode_compressed_point( - curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"] - ) + _, y = _ecc_decode_compressed_point(curve=ec.SECP384R1(), compressed_point=VALUES["ecc_compressed_point"]) assert y == _ECC_CURVE_PARAMETERS["secp384r1"].p def test_ecc_decode_compressed_point_prime_unsupported(): with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point( - curve=ec.SECP192R1(), - compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd", - ) + _ecc_decode_compressed_point(curve=ec.SECP192R1(), compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd") excinfo.match(r"Curve secp192r1 is not supported at this time") def test_ecc_decode_compressed_point_prime_complex(patch_ecc_curve_parameters): - patch_ecc_curve_parameters.__getitem__.return_value = _ECCCurveParameters( - p=5, a=5, b=5, order=5 - ) + patch_ecc_curve_parameters.__getitem__.return_value = _ECCCurveParameters(p=5, a=5, b=5, order=5) mock_curve = MagicMock() mock_curve.name = "secp_mock_curve" with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point( - curve=mock_curve, compressed_point=VALUES["ecc_compressed_point"] - ) + _ecc_decode_compressed_point(curve=mock_curve, compressed_point=VALUES["ecc_compressed_point"]) excinfo.match(r"S not 1 :: Curve not supported at this time") def test_ecc_decode_compressed_point_nonprime_characteristic_two(): with pytest.raises(NotSupportedError) as excinfo: - _ecc_decode_compressed_point( - curve=ec.SECT409K1(), - compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd", - ) + _ecc_decode_compressed_point(curve=ec.SECT409K1(), compressed_point="\x02skdgaiuhgijudflkjsdgfkjsdflgjhsd") excinfo.match(r"Non-prime curves are not supported at this time") -def test_ecc_public_numbers_from_compressed_point( - patch_ec, patch_ecc_decode_compressed_point -): +def test_ecc_public_numbers_from_compressed_point(patch_ec, patch_ecc_decode_compressed_point): patch_ecc_decode_compressed_point.return_value = sentinel.x, sentinel.y patch_ec.EllipticCurvePublicNumbers.return_value = sentinel.public_numbers_instance test = _ecc_public_numbers_from_compressed_point( curve=sentinel.curve_instance, compressed_point=sentinel.compressed_point ) - patch_ecc_decode_compressed_point.assert_called_once_with( - sentinel.curve_instance, sentinel.compressed_point - ) + patch_ecc_decode_compressed_point.assert_called_once_with(sentinel.curve_instance, sentinel.compressed_point) patch_ec.EllipticCurvePublicNumbers.assert_called_once_with( x=sentinel.x, y=sentinel.y, curve=sentinel.curve_instance ) assert test == sentinel.public_numbers_instance -def test_generate_ecc_signing_key_supported( - patch_default_backend, patch_ec, patch_verify_interface -): +def test_generate_ecc_signing_key_supported(patch_default_backend, patch_ec, patch_verify_interface): patch_ec.generate_private_key.return_value = sentinel.raw_signing_key mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info) mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) test_signing_key = generate_ecc_signing_key(algorithm=mock_algorithm) - patch_verify_interface.assert_called_once_with( - patch_ec.EllipticCurve, mock_algorithm_info - ) + patch_verify_interface.assert_called_once_with(patch_ec.EllipticCurve, mock_algorithm_info) patch_ec.generate_private_key.assert_called_once_with( curve=sentinel.algorithm_info, backend=patch_default_backend.return_value ) assert test_signing_key is sentinel.raw_signing_key -def test_generate_ecc_signing_key_unsupported( - patch_default_backend, patch_ec, patch_verify_interface -): +def test_generate_ecc_signing_key_unsupported(patch_default_backend, patch_ec, patch_verify_interface): patch_verify_interface.side_effect = InterfaceNotImplemented mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info) mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) From 7b130f268ff3c564fd000fa63542320ef7502330 Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Wed, 14 Aug 2019 20:12:48 -0700 Subject: [PATCH 4/6] Ragona Comments --- src/aws_encryption_sdk/identifiers.py | 4 +--- .../internal/crypto/authentication.py | 2 +- test/integration/test_i_aws_encrytion_sdk_client.py | 10 ++-------- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/aws_encryption_sdk/identifiers.py b/src/aws_encryption_sdk/identifiers.py index b8ec35910..d9033a315 100644 --- a/src/aws_encryption_sdk/identifiers.py +++ b/src/aws_encryption_sdk/identifiers.py @@ -239,9 +239,7 @@ def safe_to_cache(self): """Determine whether encryption materials for this algorithm suite should be cached.""" return self.kdf is not KDFSuite.NONE - -# algorithm is just an alias for AlgorithmSuite ... but Sphinx does not recognize this fact -# so we need to go through and fix the references +# sphinx linking won't work if you reference Algorithm instead of AlgorithmSuite Algorithm = AlgorithmSuite diff --git a/src/aws_encryption_sdk/internal/crypto/authentication.py b/src/aws_encryption_sdk/internal/crypto/authentication.py index dc9929bf7..c0f14a8d2 100644 --- a/src/aws_encryption_sdk/internal/crypto/authentication.py +++ b/src/aws_encryption_sdk/internal/crypto/authentication.py @@ -46,7 +46,7 @@ def __init__(self, algorithm, key): self._hasher = self._build_hasher() def _set_signature_type(self): - """Ensures that the algorithm (suite) signature type is a known type and sets a reference value.""" + """Ensures that the algorithm suite signature type is a known type and sets a reference value.""" try: verify_interface(ec.EllipticCurve, self.algorithm.signing_algorithm_info) return ec.EllipticCurve diff --git a/test/integration/test_i_aws_encrytion_sdk_client.py b/test/integration/test_i_aws_encrytion_sdk_client.py index 2f880c00e..ec62dacfe 100644 --- a/test/integration/test_i_aws_encrytion_sdk_client.py +++ b/test/integration/test_i_aws_encrytion_sdk_client.py @@ -64,14 +64,8 @@ def test_remove_bad_client(): with pytest.raises(BotoCoreError): test._regional_clients["us-fakey-12"].list_keys() - # I believe that because KMSMasterKeyProvider() sets a default regional client - # we want to test that the fake key was properly removed, instead of the dict (of regional clients) - # being empty. That is to say, after the first line of this test function - # the dict is NOT EMPTY, and this default first value will stay with us, so - # if we test for emptiness of the dict then we will get a non-passing test, when really - # it might be passing. The old line is commented out in case it matters later. - - # assert not test._regional_clients + # instead of asserting emptiness we check for the specific bad key we added + # because there may be other keys depending on other factors assert "us-fakey-12" not in test._regional_clients From ea6def4c9db578a18903337cdee475e72558c532 Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Wed, 14 Aug 2019 20:22:01 -0700 Subject: [PATCH 5/6] Formatted for flake8 with autoformat. --- src/aws_encryption_sdk/identifiers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aws_encryption_sdk/identifiers.py b/src/aws_encryption_sdk/identifiers.py index d9033a315..cb91da78f 100644 --- a/src/aws_encryption_sdk/identifiers.py +++ b/src/aws_encryption_sdk/identifiers.py @@ -239,6 +239,7 @@ def safe_to_cache(self): """Determine whether encryption materials for this algorithm suite should be cached.""" return self.kdf is not KDFSuite.NONE + # sphinx linking won't work if you reference Algorithm instead of AlgorithmSuite Algorithm = AlgorithmSuite From 63b35a83f74be051ff6694baf6a2edc64a5a95b9 Mon Sep 17 00:00:00 2001 From: Adriano Hernandez Date: Thu, 15 Aug 2019 14:44:47 -0700 Subject: [PATCH 6/6] formatted regona comment area to be standardized --- src/aws_encryption_sdk/identifiers.py | 40 +++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/aws_encryption_sdk/identifiers.py b/src/aws_encryption_sdk/identifiers.py index cb91da78f..cc4d32a6a 100644 --- a/src/aws_encryption_sdk/identifiers.py +++ b/src/aws_encryption_sdk/identifiers.py @@ -293,11 +293,41 @@ class WrappingAlgorithm(Enum): None, None, ) - RSA_PKCS1 = (EncryptionType.ASYMMETRIC, rsa, padding.PKCS1v15, None, None) - RSA_OAEP_SHA1_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA1, padding.MGF1) - RSA_OAEP_SHA256_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA256, padding.MGF1) - RSA_OAEP_SHA384_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA384, padding.MGF1) - RSA_OAEP_SHA512_MGF1 = (EncryptionType.ASYMMETRIC, rsa, padding.OAEP, hashes.SHA512, padding.MGF1) + RSA_PKCS1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.PKCS1v15, + None, + None + ) + RSA_OAEP_SHA1_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA1, + padding.MGF1 + ) + RSA_OAEP_SHA256_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA256, + padding.MGF1 + ) + RSA_OAEP_SHA384_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA384, + padding.MGF1 + ) + RSA_OAEP_SHA512_MGF1 = ( + EncryptionType.ASYMMETRIC, + rsa, + padding.OAEP, + hashes.SHA512, + padding.MGF1 + ) def __init__(self, encryption_type, algorithm, padding_type, padding_algorithm, padding_mgf): """Prepares new WrappingAlgorithm."""