diff --git a/src/aws_encryption_sdk/__init__.py b/src/aws_encryption_sdk/__init__.py index aba067601..88da93d25 100644 --- a/src/aws_encryption_sdk/__init__.py +++ b/src/aws_encryption_sdk/__init__.py @@ -1,6 +1,8 @@ # Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 """High level AWS Encryption SDK client functions.""" +import copy + # Below are imported for ease of use by implementors from aws_encryption_sdk.caches.local import LocalCryptoMaterialsCache # noqa from aws_encryption_sdk.caches.null import NullCryptoMaterialsCache # noqa @@ -84,7 +86,10 @@ def encrypt(**kwargs): with StreamEncryptor(**kwargs) as encryptor: ciphertext = encryptor.read() - return CryptoResult(result=ciphertext, header=encryptor.header, keyring_trace=encryptor.keyring_trace) + header_copy = copy.deepcopy(encryptor.header) + keyring_trace_copy = copy.deepcopy(encryptor.keyring_trace) + + return CryptoResult(result=ciphertext, header=header_copy, keyring_trace=keyring_trace_copy) def decrypt(**kwargs): @@ -143,7 +148,10 @@ def decrypt(**kwargs): with StreamDecryptor(**kwargs) as decryptor: plaintext = decryptor.read() - return CryptoResult(result=plaintext, header=decryptor.header, keyring_trace=decryptor.keyring_trace) + header_copy = copy.deepcopy(decryptor.header) + keyring_trace_copy = copy.deepcopy(decryptor.keyring_trace) + + return CryptoResult(result=plaintext, header=header_copy, keyring_trace=keyring_trace_copy) def stream(**kwargs): diff --git a/test/unit/test_client.py b/test/unit/test_client.py index d6b763b49..bd40eb03a 100644 --- a/test/unit/test_client.py +++ b/test/unit/test_client.py @@ -53,13 +53,15 @@ def test_encrypt(self): test_ciphertext, test_header = aws_encryption_sdk.encrypt(a=sentinel.a, b=sentinel.b, c=sentinel.b) self.mock_stream_encryptor.called_once_with(a=sentinel.a, b=sentinel.b, c=sentinel.b) assert test_ciphertext is _CIPHERTEXT - assert test_header is _HEADER + assert test_header == _HEADER + assert test_header is not _HEADER def test_decrypt(self): test_plaintext, test_header = aws_encryption_sdk.decrypt(a=sentinel.a, b=sentinel.b, c=sentinel.b) self.mock_stream_encryptor.called_once_with(a=sentinel.a, b=sentinel.b, c=sentinel.b) assert test_plaintext is _PLAINTEXT - assert test_header is _HEADER + assert test_header == _HEADER + assert test_header is not _HEADER def test_stream_encryptor_e(self): test = aws_encryption_sdk.stream(mode="e", a=sentinel.a, b=sentinel.b, c=sentinel.b)