Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 7374fcb

Browse files
committedMar 26, 2024
unit tests
1 parent 2d26009 commit 7374fcb

7 files changed

+565
-72
lines changed
 

‎src/aws_encryption_sdk/streaming_client.py

Lines changed: 73 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -954,8 +954,72 @@ def _prep_message(self):
954954
self._prep_non_framed()
955955
self._message_prepped = True
956956

957-
# TODO-MPL: Refactor this function, remove linter disablers
958-
def _read_header(self): # noqa pylint: disable=too-many-branches
957+
def _create_decrypt_materials_request(self, header):
958+
"""
959+
Create a DecryptionMaterialsRequest based on whether
960+
the StreamDecryptor was provided encryption_context on decrypt
961+
(i.e. expects to use required encryption context CMM from the MPL).
962+
"""
963+
# If encryption_context is provided on decrypt,
964+
# pass it to the DecryptionMaterialsRequest as reproduced_encryption_context
965+
if hasattr(self.config, "encryption_context"):
966+
return DecryptionMaterialsRequest(
967+
encrypted_data_keys=header.encrypted_data_keys,
968+
algorithm=header.algorithm,
969+
encryption_context=header.encryption_context,
970+
commitment_policy=self.config.commitment_policy,
971+
reproduced_encryption_context=self.config.encryption_context
972+
)
973+
return DecryptionMaterialsRequest(
974+
encrypted_data_keys=header.encrypted_data_keys,
975+
algorithm=header.algorithm,
976+
encryption_context=header.encryption_context,
977+
commitment_policy=self.config.commitment_policy,
978+
)
979+
980+
def _validate_parsed_header(
981+
self,
982+
header,
983+
header_auth,
984+
raw_header,
985+
):
986+
"""
987+
Pass arguments from this StreamDecryptor to validate_header based on whether
988+
the StreamDecryptor has the _required_encryption_context attribute
989+
(i.e. is using the required encryption context CMM from the MPL).
990+
"""
991+
# If _required_encryption_context is present,
992+
# serialize it and pass it to validate_header.
993+
if hasattr(self, "_required_encryption_context") \
994+
and self._required_encryption_context is not None:
995+
# The authenticated only encryption context is all encryption context key-value pairs where the
996+
# key exists in Required Encryption Context Keys. It is then serialized according to the
997+
# message header Key Value Pairs.
998+
required_ec_serialized = \
999+
aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context(
1000+
self._required_encryption_context
1001+
)
1002+
1003+
validate_header(
1004+
header=header,
1005+
header_auth=header_auth,
1006+
# When verifying the header, the AAD input to the authenticated encryption algorithm
1007+
# specified by the algorithm suite is the message header body and the serialized
1008+
# authenticated only encryption context.
1009+
raw_header=raw_header + required_ec_serialized,
1010+
data_key=self._derived_data_key
1011+
)
1012+
else:
1013+
validate_header(
1014+
header=header,
1015+
header_auth=header_auth,
1016+
raw_header=raw_header,
1017+
data_key=self._derived_data_key
1018+
)
1019+
1020+
return header, header_auth
1021+
1022+
def _read_header(self):
9591023
"""Reads the message header from the input stream.
9601024
9611025
:returns: tuple containing deserialized header and header_auth objects
@@ -981,24 +1045,7 @@ def _read_header(self): # noqa pylint: disable=too-many-branches
9811045
)
9821046
)
9831047

984-
# If encryption_context is provided on decrypt,
985-
# pass it to the DecryptionMaterialsRequest
986-
if hasattr(self.config, "encryption_context"):
987-
decrypt_materials_request = DecryptionMaterialsRequest(
988-
encrypted_data_keys=header.encrypted_data_keys,
989-
algorithm=header.algorithm,
990-
encryption_context=header.encryption_context,
991-
commitment_policy=self.config.commitment_policy,
992-
reproduced_encryption_context=self.config.encryption_context
993-
)
994-
else:
995-
decrypt_materials_request = DecryptionMaterialsRequest(
996-
encrypted_data_keys=header.encrypted_data_keys,
997-
algorithm=header.algorithm,
998-
encryption_context=header.encryption_context,
999-
commitment_policy=self.config.commitment_policy,
1000-
)
1001-
1048+
decrypt_materials_request = self._create_decrypt_materials_request(header)
10021049
decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request)
10031050

10041051
# If the materials_manager passed required_encryption_context_keys,
@@ -1049,36 +1096,12 @@ def _read_header(self): # noqa pylint: disable=too-many-branches
10491096
"Key commitment validation failed. Key identity does not match the identity asserted in the "
10501097
"message. Halting processing of this message."
10511098
)
1052-
1053-
# If _required_encryption_context is present,
1054-
# serialize it and pass it to validate_header.
1055-
if self._required_encryption_context is not None:
1056-
# The authenticated only encryption context is all encryption context key-value pairs where the
1057-
# key exists in Required Encryption Context Keys. It is then serialized according to the
1058-
# message header Key Value Pairs.
1059-
required_ec_serialized = \
1060-
aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context(
1061-
self._required_encryption_context
1062-
)
1063-
1064-
validate_header(
1065-
header=header,
1066-
header_auth=header_auth,
1067-
# When verifying the header, the AAD input to the authenticated encryption algorithm
1068-
# specified by the algorithm suite is the message header body and the serialized
1069-
# authenticated only encryption context.
1070-
raw_header=raw_header + required_ec_serialized,
1071-
data_key=self._derived_data_key
1072-
)
1073-
else:
1074-
validate_header(
1075-
header=header,
1076-
header_auth=header_auth,
1077-
raw_header=raw_header,
1078-
data_key=self._derived_data_key
1079-
)
1080-
1081-
return header, header_auth
1099+
1100+
return self._validate_parsed_header(
1101+
header=header,
1102+
header_auth=header_auth,
1103+
raw_header=raw_header,
1104+
)
10821105

10831106
def _prep_non_framed(self):
10841107
"""Prepare the opening data for a non-framed message."""

‎test/mpl/unit/test_material_managers_mpl_cmm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
mock_mpl_cmm = MagicMock(__class__=MPL_ICryptographicMaterialsManager)
3939
mock_mpl_encryption_materials = MagicMock(__class__=MPL_EncryptionMaterials)
4040
mock_mpl_decrypt_materials = MagicMock(__class__=MPL_DecryptionMaterials)
41+
mock_reproduced_encryption_context = MagicMock(__class_=dict)
4142

4243

4344
mock_edk = MagicMock(__class__=Native_EncryptedDataKey)
@@ -259,6 +260,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques
259260
for mock_edks in [no_mock_edks, one_mock_edk, two_mock_edks]:
260261

261262
mock_decryption_materials_request.encrypted_data_keys = mock_edks
263+
mock_decryption_materials_request.reproduced_encryption_context = mock_reproduced_encryption_context
262264

263265
# When: _create_mpl_decrypt_materials_input_from_request
264266
output = CryptoMaterialsManagerFromMPL._create_mpl_decrypt_materials_input_from_request(
@@ -271,6 +273,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques
271273
assert output.algorithm_suite_id == mock_algorithm_id
272274
assert output.commitment_policy == mock_commitment_policy
273275
assert output.encryption_context == mock_decryption_materials_request.encryption_context
276+
assert output.reproduced_encryption_context == mock_reproduced_encryption_context
274277

275278
assert len(output.encrypted_data_keys) == len(mock_edks)
276279
for i in range(len(output.encrypted_data_keys)):

‎test/mpl/unit/test_material_managers_mpl_materials.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ def test_GIVEN_valid_signing_key_WHEN_EncryptionMaterials_get_signing_key_THEN_r
160160
assert output == mock_signing_key
161161

162162

163+
def test_GIVEN_valid_required_encryption_context_keys_WHEN_EncryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys():
164+
# Given: valid required encryption context keys
165+
mock_required_encryption_context_keys = MagicMock(__class__=bytes)
166+
mock_mpl_encryption_materials.required_encryption_context_keys = mock_required_encryption_context_keys
167+
168+
# When: get required encryption context keys
169+
mpl_encryption_materials = EncryptionMaterialsFromMPL(mpl_materials=mock_mpl_encryption_materials)
170+
output = mpl_encryption_materials.required_encryption_context_keys
171+
172+
# Then: returns required encryption context keys
173+
assert output == mock_required_encryption_context_keys
174+
175+
163176
def test_GIVEN_valid_data_key_WHEN_DecryptionMaterials_get_data_key_THEN_returns_data_key():
164177
# Given: valid MPL data key
165178
mock_data_key = MagicMock(__class__=bytes)
@@ -187,3 +200,29 @@ def test_GIVEN_valid_verification_key_WHEN_DecryptionMaterials_get_verification_
187200

188201
# Then: returns verification key
189202
assert output == mock_verification_key
203+
204+
205+
def test_GIVEN_valid_encryption_context_WHEN_DecryptionMaterials_get_encryption_context_THEN_returns_encryption_context():
206+
# Given: valid encryption context
207+
mock_encryption_context = MagicMock(__class__=Dict[str, str])
208+
mock_mpl_decrypt_materials.encryption_context = mock_encryption_context
209+
210+
# When: get encryption context
211+
mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials)
212+
output = mpl_decryption_materials.encryption_context
213+
214+
# Then: returns valid encryption context
215+
assert output == mock_encryption_context
216+
217+
218+
def test_GIVEN_valid_required_encryption_context_keys_WHEN_DecryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys():
219+
# Given: valid required encryption context keys
220+
mock_required_encryption_context_keys = MagicMock(__class__=bytes)
221+
mock_mpl_decrypt_materials.required_encryption_context_keys = mock_required_encryption_context_keys
222+
223+
# When: get required encryption context keys
224+
mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials)
225+
output = mpl_decryption_materials.required_encryption_context_keys
226+
227+
# Then: returns required encryption context keys
228+
assert output == mock_required_encryption_context_keys

‎test/unit/test_serialize.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def apply_fixtures(self):
7979
"aws_encryption_sdk.internal.formatting.serialize.aws_encryption_sdk.internal.utils.validate_frame_length"
8080
)
8181
self.mock_valid_frame_length = self.mock_valid_frame_length_patcher.start()
82+
self.mock_required_ec_bytes = MagicMock()
8283
# Set up mock signer
8384
self.mock_signer = MagicMock()
8485
self.mock_signer.update.return_value = None
@@ -167,6 +168,31 @@ def test_serialize_header_auth_v1_no_signer(self):
167168
data_encryption_key=VALUES["data_key_obj"],
168169
)
169170

171+
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
172+
def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v1_THEN_aad_has_required_ec_bytes(self, mock_header_auth_iv):
173+
"""Validate that the _create_header_auth function
174+
behaves as expected for SerializationVersion.V1
175+
when required_ec_bytes are provided.
176+
"""
177+
self.mock_encrypt.return_value = VALUES["header_auth_base"]
178+
test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth(
179+
version=SerializationVersion.V1,
180+
algorithm=self.mock_algorithm,
181+
header=VALUES["serialized_header"],
182+
data_encryption_key=sentinel.encryption_key,
183+
signer=self.mock_signer,
184+
required_ec_bytes=self.mock_required_ec_bytes,
185+
)
186+
self.mock_encrypt.assert_called_once_with(
187+
algorithm=self.mock_algorithm,
188+
key=sentinel.encryption_key,
189+
plaintext=b"",
190+
associated_data=VALUES["serialized_header"] + self.mock_required_ec_bytes,
191+
iv=mock_header_auth_iv.return_value,
192+
)
193+
self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth"])
194+
assert test == VALUES["serialized_header_auth"]
195+
170196
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
171197
def test_serialize_header_auth_v2(self, mock_header_auth_iv):
172198
"""Validate that the _create_header_auth function
@@ -203,6 +229,30 @@ def test_serialize_header_auth_v2_no_signer(self):
203229
data_encryption_key=VALUES["data_key_obj"],
204230
)
205231

232+
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
233+
def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v2_THEN_aad_has_required_ec_bytes(self, mock_header_auth_iv):
234+
"""Validate that the _create_header_auth function
235+
behaves as expected for SerializationVersion.V2.
236+
"""
237+
self.mock_encrypt.return_value = VALUES["header_auth_base"]
238+
test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth(
239+
version=SerializationVersion.V2,
240+
algorithm=self.mock_algorithm,
241+
header=VALUES["serialized_header_v2_committing"],
242+
data_encryption_key=sentinel.encryption_key,
243+
signer=self.mock_signer,
244+
required_ec_bytes=self.mock_required_ec_bytes,
245+
)
246+
self.mock_encrypt.assert_called_once_with(
247+
algorithm=self.mock_algorithm,
248+
key=sentinel.encryption_key,
249+
plaintext=b"",
250+
associated_data=VALUES["serialized_header_v2_committing"] + self.mock_required_ec_bytes,
251+
iv=mock_header_auth_iv.return_value,
252+
)
253+
self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth_v2"])
254+
assert test == VALUES["serialized_header_auth_v2"]
255+
206256
def test_serialize_non_framed_open(self):
207257
"""Validate that the serialize_non_framed_open
208258
function behaves as expected.

‎test/unit/test_streaming_client_configs.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717
import six
18-
from mock import patch
18+
from mock import MagicMock, patch
1919

2020
from aws_encryption_sdk import CommitmentPolicy
2121
from aws_encryption_sdk.internal.defaults import ALGORITHM, FRAME_LENGTH, LINE_LENGTH
@@ -33,7 +33,10 @@
3333
# Ideally, this logic would be based on mocking imports and testing logic,
3434
# but doing that introduces errors that cause other tests to fail.
3535
try:
36-
from aws_cryptographic_materialproviders.mpl.references import IKeyring
36+
from aws_cryptographic_materialproviders.mpl.references import (
37+
ICryptographicMaterialsManager,
38+
IKeyring,
39+
)
3740
HAS_MPL = True
3841

3942
from aws_encryption_sdk.materials_managers.mpl.cmm import CryptoMaterialsManagerFromMPL
@@ -236,24 +239,21 @@ def test_client_configs_with_mpl(
236239
assert test.materials_manager is not None
237240

238241
# If materials manager was provided, it should be directly used
239-
if hasattr(kwargs, "materials_manager"):
242+
if "materials_manager" in kwargs:
240243
assert kwargs["materials_manager"] == test.materials_manager
241244

242-
# If MPL keyring was provided, it should be wrapped in MPL materials manager
243-
if hasattr(kwargs, "keyring"):
244-
assert test.keyring is not None
245-
assert test.keyring == kwargs["keyring"]
246-
assert isinstance(test.keyring, IKeyring)
247-
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
248-
249245
# If native key_provider was provided, it should be wrapped in native materials manager
250-
if hasattr(kwargs, "key_provider"):
246+
elif "key_provider" in kwargs:
251247
assert test.key_provider is not None
252248
assert test.key_provider == kwargs["key_provider"]
253249
assert isinstance(test.materials_manager, DefaultCryptoMaterialsManager)
254250

251+
else:
252+
raise ValueError(f"Test did not find materials_manager or key_provider. {kwargs}")
253+
255254

256-
# This needs its own test; pytest parametrize cannot use a conditionally-loaded type
255+
# This is an addition to test_client_configs_with_mpl;
256+
# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (IKeyring)
257257
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
258258
def test_keyring_client_config_with_mpl(
259259
):
@@ -265,16 +265,30 @@ def test_keyring_client_config_with_mpl(
265265

266266
test = _ClientConfig(**kwargs)
267267

268-
# In all cases, config should have a materials manager
269268
assert test.materials_manager is not None
270269

271-
# If materials manager was provided, it should be directly used
272-
if hasattr(kwargs, "materials_manager"):
273-
assert kwargs["materials_manager"] == test.materials_manager
270+
assert test.keyring is not None
271+
assert test.keyring == kwargs["keyring"]
272+
assert isinstance(test.keyring, IKeyring)
273+
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
274+
275+
276+
# This is an addition to test_client_configs_with_mpl;
277+
# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (MPL CMM)
278+
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
279+
def test_mpl_cmm_client_config_with_mpl(
280+
):
281+
mock_mpl_cmm = MagicMock(__class__=ICryptographicMaterialsManager)
282+
kwargs = {
283+
"source": b"",
284+
"materials_manager": mock_mpl_cmm,
285+
"commitment_policy": CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT
286+
}
287+
288+
test = _ClientConfig(**kwargs)
274289

275-
# If MPL keyring was provided, it should be wrapped in MPL materials manager
276-
if hasattr(kwargs, "keyring"):
277-
assert test.keyring is not None
278-
assert test.keyring == kwargs["keyring"]
279-
assert isinstance(test.keyring, IKeyring)
280-
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
290+
assert test.materials_manager is not None
291+
# Assert that the MPL CMM is wrapped in the native interface
292+
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
293+
# Assert the MPL CMM is used by the native interface
294+
assert test.materials_manager.mpl_cmm == mock_mpl_cmm

‎test/unit/test_streaming_client_stream_decryptor.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,3 +924,213 @@ def test_close_no_footer(self, mock_close):
924924
with pytest.raises(SerializationError) as excinfo:
925925
test_decryptor.close()
926926
excinfo.match("Footer not read")
927+
928+
@patch("aws_encryption_sdk.streaming_client.validate_header")
929+
def test_GIVEN_does_not_have_required_EC_WHEN_validate_parsed_header_THEN_validate_header(
930+
self,
931+
mock_validate_header
932+
):
933+
self.mock_header.content_type = ContentType.FRAMED_DATA
934+
test_decryptor = StreamDecryptor(
935+
materials_manager=self.mock_materials_manager,
936+
source=self.mock_input_stream,
937+
commitment_policy=self.mock_commitment_policy,
938+
)
939+
test_decryptor._derived_data_key = sentinel.derived_data_key
940+
# Given: test_decryptor does not have _required_encryption_context attribute
941+
# When: _validate_parsed_header
942+
test_decryptor._validate_parsed_header(
943+
header=self.mock_header,
944+
header_auth=sentinel.header_auth,
945+
raw_header=self.mock_raw_header
946+
)
947+
# Then: validate_header
948+
mock_validate_header.assert_called_once_with(
949+
header=self.mock_header,
950+
header_auth=sentinel.header_auth,
951+
raw_header=self.mock_raw_header,
952+
data_key=sentinel.derived_data_key,
953+
)
954+
955+
@patch("aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context")
956+
@patch("aws_encryption_sdk.streaming_client.validate_header")
957+
def test_GIVEN_has_required_EC_WHEN_validate_parsed_header_THEN_validate_header_with_serialized_required_EC(
958+
self,
959+
mock_validate_header,
960+
mock_serialize_encryption_context,
961+
):
962+
self.mock_header.content_type = ContentType.FRAMED_DATA
963+
test_decryptor = StreamDecryptor(
964+
materials_manager=self.mock_materials_manager,
965+
source=self.mock_input_stream,
966+
commitment_policy=self.mock_commitment_policy,
967+
)
968+
test_decryptor._derived_data_key = sentinel.derived_data_key
969+
# Given: test_decryptor has _required_encryption_context attribute
970+
mock_required_ec = MagicMock(__class__=dict)
971+
test_decryptor._required_encryption_context = mock_required_ec
972+
mock_serialized_required_ec = MagicMock(__class__=bytes)
973+
mock_serialize_encryption_context.return_value = mock_serialized_required_ec
974+
# When: _validate_parsed_header
975+
test_decryptor._validate_parsed_header(
976+
header=self.mock_header,
977+
header_auth=sentinel.header_auth,
978+
raw_header=self.mock_raw_header
979+
)
980+
# Then: call validate_header with serialized required EC
981+
mock_validate_header.assert_called_once_with(
982+
header=self.mock_header,
983+
header_auth=sentinel.header_auth,
984+
raw_header=self.mock_raw_header + mock_serialized_required_ec,
985+
data_key=sentinel.derived_data_key,
986+
)
987+
988+
def test_GIVEN_config_has_EC_WHEN_create_decrypt_materials_request_THEN_provide_reproduced_EC(
989+
self,
990+
):
991+
self.mock_header.content_type = ContentType.FRAMED_DATA
992+
test_decryptor = StreamDecryptor(
993+
materials_manager=self.mock_materials_manager,
994+
source=self.mock_input_stream,
995+
commitment_policy=self.mock_commitment_policy,
996+
)
997+
998+
# Given: StreamDecryptor.config has encryption_context attribute
999+
mock_reproduced_encryption_context = MagicMock(__class__=dict)
1000+
test_decryptor.config.encryption_context = mock_reproduced_encryption_context
1001+
# Type checking on header encryption context seems to require concrete instance,
1002+
# neither MagicMock nor sentinel value work
1003+
self.mock_header.encryption_context = {"some_key_to_pass_type_validation": "some_value"}
1004+
1005+
# When: _create_decrypt_materials_request
1006+
output = test_decryptor._create_decrypt_materials_request(
1007+
header=self.mock_header,
1008+
)
1009+
1010+
# Then: decrypt_materials_request has reproduced_encryption_context attribute
1011+
assert hasattr(output, "reproduced_encryption_context")
1012+
assert output.reproduced_encryption_context == mock_reproduced_encryption_context
1013+
1014+
def test_GIVEN_config_does_not_have_EC_WHEN_create_decrypt_materials_request_THEN_request_does_not_have_reproduced_EC(
1015+
self,
1016+
):
1017+
self.mock_header.content_type = ContentType.FRAMED_DATA
1018+
test_decryptor = StreamDecryptor(
1019+
materials_manager=self.mock_materials_manager,
1020+
source=self.mock_input_stream,
1021+
commitment_policy=self.mock_commitment_policy,
1022+
)
1023+
1024+
# Given: StreamDecryptor.config does not have an encryption_context attribute
1025+
del test_decryptor.config.encryption_context
1026+
# Type checking on header encryption context seems to require concrete instance,
1027+
# neither MagicMock nor sentinel value work
1028+
self.mock_header.encryption_context = {"some_key_to_pass_type_validation": "some_value"}
1029+
1030+
# When: _create_decrypt_materials_request
1031+
output = test_decryptor._create_decrypt_materials_request(
1032+
header=self.mock_header,
1033+
)
1034+
1035+
# Then: decrypt_materials_request.reproduced_encryption_context is None
1036+
assert output.reproduced_encryption_context is None
1037+
1038+
@patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key")
1039+
@patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest")
1040+
@patch("aws_encryption_sdk.streaming_client.Verifier")
1041+
def test_GIVEN_materials_has_no_required_encryption_context_keys_attr_WHEN_read_header_THEN_required_EC_is_None(
1042+
self,
1043+
mock_verifier,
1044+
*_
1045+
):
1046+
1047+
mock_verifier_instance = MagicMock()
1048+
mock_verifier.from_key_bytes.return_value = mock_verifier_instance
1049+
1050+
self.mock_header.content_type = ContentType.FRAMED_DATA
1051+
test_decryptor = StreamDecryptor(
1052+
materials_manager=self.mock_materials_manager,
1053+
source=self.mock_input_stream,
1054+
commitment_policy=self.mock_commitment_policy,
1055+
)
1056+
1057+
# Given: decryption_materials does not have a required_encryption_context_keys attribute
1058+
del self.mock_decrypt_materials.required_encryption_context_keys
1059+
1060+
# When: _read_header
1061+
test_decryptor._read_header()
1062+
1063+
# Then: StreamDecryptor._required_encryption_context is None
1064+
assert test_decryptor._required_encryption_context is None
1065+
1066+
@patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key")
1067+
@patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest")
1068+
@patch("aws_encryption_sdk.streaming_client.Verifier")
1069+
def test_GIVEN_materials_has_required_encryption_context_keys_attr_WHEN_read_header_THEN_creates_correct_required_EC(
1070+
self,
1071+
mock_verifier,
1072+
*_
1073+
):
1074+
required_encryption_context_keys_values = [
1075+
# Case of empty encryption context list is not allowed;
1076+
# if a list is provided, it must be non-empty.
1077+
# The MPL enforces this behavior on construction.
1078+
["one_key"],
1079+
["one_key", "two_key"],
1080+
["one_key", "two_key", "red_key"],
1081+
["one_key", "two_key", "red_key", "blue_key"],
1082+
]
1083+
1084+
encryption_context_values = [
1085+
{},
1086+
{"one_key": "some_value"},
1087+
{
1088+
"one_key": "some_value",
1089+
"two_key": "some_other_value",
1090+
},
1091+
{
1092+
"one_key": "some_value",
1093+
"two_key": "some_other_value",
1094+
"red_key": "some_red_value",
1095+
},
1096+
{
1097+
"one_key": "some_value",
1098+
"two_key": "some_other_value",
1099+
"red_key": "some_red_value",
1100+
"blue_key": "some_blue_value",
1101+
}
1102+
]
1103+
1104+
for required_encryption_context_keys in required_encryption_context_keys_values:
1105+
1106+
# Given: decryption_materials has required_encryption_context_keys
1107+
self.mock_decrypt_materials.required_encryption_context_keys = \
1108+
required_encryption_context_keys
1109+
1110+
for encryption_context in encryption_context_values:
1111+
1112+
self.mock_decrypt_materials.encryption_context = encryption_context
1113+
1114+
mock_verifier_instance = MagicMock()
1115+
mock_verifier.from_key_bytes.return_value = mock_verifier_instance
1116+
1117+
self.mock_header.content_type = ContentType.FRAMED_DATA
1118+
test_decryptor = StreamDecryptor(
1119+
materials_manager=self.mock_materials_manager,
1120+
source=self.mock_input_stream,
1121+
commitment_policy=self.mock_commitment_policy,
1122+
)
1123+
1124+
# When: _read_header
1125+
test_decryptor._read_header()
1126+
1127+
# Then: Assert correctness of partitioned EC
1128+
for k in encryption_context:
1129+
# If a key is in required_encryption_context_keys, then ...
1130+
if k in required_encryption_context_keys:
1131+
# ... its EC is in the StreamEncryptor._required_encryption_context
1132+
assert k in test_decryptor._required_encryption_context
1133+
# If a key is NOT in required_encryption_context_keys, then ...
1134+
else:
1135+
# ... its EC is NOT in the StreamEncryptor._required_encryption_context
1136+
assert k not in test_decryptor._required_encryption_context

‎test/unit/test_streaming_client_stream_encryptor.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,113 @@ def test_GIVEN_has_mpl_AND_has_MPLCMM_AND_uses_signer_WHEN_prep_message_THEN_sig
451451
encoding=serialization.Encoding.PEM
452452
)
453453

454+
# Given: has MPL
455+
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
456+
def test_GIVEN_has_mpl_AND_encryption_materials_has_required_EC_keys_WHEN_prep_message_THEN_paritions_stored_and_required_EC(self):
457+
# Create explicit values to explicitly test logic in smaller cases
458+
required_encryption_context_keys_values = [
459+
# Case of empty encryption context list is not allowed;
460+
# if a list is provided, it must be non-empty.
461+
# The MPL enforces this behavior on construction.
462+
["one_key"],
463+
["one_key", "two_key"],
464+
["one_key", "two_key", "red_key"],
465+
["one_key", "two_key", "red_key", "blue_key"],
466+
]
467+
468+
encryption_context_values = [
469+
{},
470+
{"one_key": "some_value"},
471+
{
472+
"one_key": "some_value",
473+
"two_key": "some_other_value",
474+
},
475+
{
476+
"one_key": "some_value",
477+
"two_key": "some_other_value",
478+
"red_key": "some_red_value",
479+
},
480+
{
481+
"one_key": "some_value",
482+
"two_key": "some_other_value",
483+
"red_key": "some_red_value",
484+
"blue_key": "some_blue_value",
485+
}
486+
]
487+
488+
self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16
489+
490+
for required_encryption_context_keys in required_encryption_context_keys_values:
491+
492+
# Given: encryption context has required_encryption_context_keys
493+
self.mock_encryption_materials.required_encryption_context_keys = \
494+
required_encryption_context_keys
495+
496+
for encryption_context in encryption_context_values:
497+
self.mock_encryption_materials.encryption_context = encryption_context
498+
499+
test_encryptor = StreamEncryptor(
500+
source=VALUES["data_128"],
501+
materials_manager=self.mock_mpl_materials_manager,
502+
frame_length=self.mock_frame_length,
503+
algorithm=Algorithm.AES_128_GCM_IV12_TAG16,
504+
commitment_policy=self.mock_commitment_policy,
505+
signature_policy=self.mock_signature_policy,
506+
)
507+
test_encryptor.content_type = ContentType.FRAMED_DATA
508+
# When: prep_message
509+
test_encryptor._prep_message()
510+
511+
# Then: Assert correctness of partitioned EC
512+
for k in encryption_context:
513+
# If a key is in required_encryption_context_keys, then
514+
if k in required_encryption_context_keys:
515+
# 1) Its EC is in the StreamEncryptor._required_encryption_context
516+
assert k in test_encryptor._required_encryption_context
517+
# 2) Its EC is NOT in the StreamEncryptor._stored_encryption_context
518+
assert k not in test_encryptor._stored_encryption_context
519+
# If a key is NOT in required_encryption_context_keys, then
520+
else:
521+
# 1) Its EC is NOT in the StreamEncryptor._required_encryption_context
522+
assert k not in test_encryptor._required_encryption_context
523+
# 2) Its EC is in the StreamEncryptor._stored_encryption_context
524+
assert k in test_encryptor._stored_encryption_context
525+
526+
# Assert size(stored_EC) + size(required_EC) == size(EC)
527+
# (i.e. every EC was sorted into one or the other)
528+
assert len(test_encryptor._required_encryption_context) \
529+
+ len(test_encryptor._stored_encryption_context) \
530+
== len(encryption_context)
531+
532+
# Given: has MPL
533+
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
534+
def test_GIVEN_has_mpl_AND_encryption_materials_does_not_have_required_EC_keys_WHEN_prep_message_THEN_stored_EC_is_EC(self):
535+
536+
self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16
537+
538+
mock_encryption_context = MagicMock(__class__=dict)
539+
self.mock_encryption_materials.encryption_context = mock_encryption_context
540+
# Given: encryption materials does not have required encryption context keys
541+
# (MagicMock default is to "make up" "Some" value here; this deletes that value)
542+
del self.mock_encryption_materials.required_encryption_context_keys
543+
544+
test_encryptor = StreamEncryptor(
545+
source=VALUES["data_128"],
546+
materials_manager=self.mock_mpl_materials_manager,
547+
frame_length=self.mock_frame_length,
548+
algorithm=Algorithm.AES_128_GCM_IV12_TAG16,
549+
commitment_policy=self.mock_commitment_policy,
550+
signature_policy=self.mock_signature_policy,
551+
)
552+
test_encryptor.content_type = ContentType.FRAMED_DATA
553+
# When: prep_message
554+
test_encryptor._prep_message()
555+
556+
# Then: _stored_encryption_context is the provided encryption_context
557+
assert test_encryptor._stored_encryption_context == mock_encryption_context
558+
# Then: _required_encryption_context is None
559+
assert test_encryptor._required_encryption_context is None
560+
454561
def test_prep_message_no_signer(self):
455562
self.mock_encryption_materials.algorithm = Algorithm.AES_128_GCM_IV12_TAG16
456563
test_encryptor = StreamEncryptor(
@@ -575,6 +682,53 @@ def test_write_header(self):
575682
)
576683
assert test_encryptor.output_buffer == b"1234567890"
577684

685+
@patch("aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context")
686+
# Given: has MPL
687+
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
688+
def test_GIVEN_has_mpl_AND_has_required_EC_WHEN_write_header_THEN_adds_serialized_required_ec_to_header_auth(
689+
self,
690+
serialize_encryption_context
691+
):
692+
self.mock_serialize_header.return_value = b"12345"
693+
self.mock_serialize_header_auth.return_value = b"67890"
694+
pt_stream = io.BytesIO(self.plaintext)
695+
test_encryptor = StreamEncryptor(
696+
source=pt_stream,
697+
materials_manager=self.mock_materials_manager,
698+
algorithm=aws_encryption_sdk.internal.defaults.ALGORITHM,
699+
frame_length=self.mock_frame_length,
700+
commitment_policy=self.mock_commitment_policy,
701+
signature_policy=self.mock_signature_policy,
702+
)
703+
test_encryptor.signer = sentinel.signer
704+
test_encryptor.content_type = sentinel.content_type
705+
test_encryptor._header = sentinel.header
706+
sentinel.header.version = SerializationVersion.V1
707+
test_encryptor.output_buffer = b""
708+
test_encryptor._encryption_materials = self.mock_encryption_materials
709+
test_encryptor._derived_data_key = sentinel.derived_data_key
710+
711+
# Given: StreamEncryptor has _required_encryption_context
712+
mock_required_ec = MagicMock(__class__=dict)
713+
test_encryptor._required_encryption_context = mock_required_ec
714+
mock_serialized_required_ec = MagicMock(__class__=bytes)
715+
serialize_encryption_context.return_value = mock_serialized_required_ec
716+
717+
# When: _write_header()
718+
test_encryptor._write_header()
719+
720+
self.mock_serialize_header.assert_called_once_with(header=test_encryptor._header, signer=sentinel.signer)
721+
self.mock_serialize_header_auth.assert_called_once_with(
722+
version=sentinel.header.version,
723+
algorithm=self.mock_encryption_materials.algorithm,
724+
header=b"12345",
725+
data_encryption_key=sentinel.derived_data_key,
726+
signer=sentinel.signer,
727+
# Then: Pass serialized required EC to serialize_header_auth
728+
required_ec_bytes=mock_serialized_required_ec,
729+
)
730+
assert test_encryptor.output_buffer == b"1234567890"
731+
578732
@patch("aws_encryption_sdk.streaming_client.non_framed_body_iv")
579733
def test_prep_non_framed(self, mock_non_framed_iv):
580734
self.mock_serialize_non_framed_open.return_value = b"1234567890"

0 commit comments

Comments
 (0)
Please sign in to comment.