diff --git a/tests/e2e/data_masking/test_e2e_data_masking.py b/tests/e2e/data_masking/test_e2e_data_masking.py index a720a265d83..3ee2400b5cc 100644 --- a/tests/e2e/data_masking/test_e2e_data_masking.py +++ b/tests/e2e/data_masking/test_e2e_data_masking.py @@ -2,16 +2,23 @@ from uuid import uuid4 import pytest -from aws_encryption_sdk.exceptions import DecryptKeyError from aws_lambda_powertools.utilities.data_masking import DataMasking -from aws_lambda_powertools.utilities.data_masking.exceptions import DataMaskingContextMismatchError +from aws_lambda_powertools.utilities.data_masking.exceptions import ( + DataMaskingContextMismatchError, + DataMaskingDecryptKeyError, +) from aws_lambda_powertools.utilities.data_masking.provider.kms.aws_encryption_sdk import ( AWSEncryptionSDKProvider, ) from tests.e2e.utils import data_fetcher +@pytest.fixture +def security_context(): + return {"this": "is_secure"} + + @pytest.fixture def basic_handler_fn(infrastructure: dict) -> str: return infrastructure.get("BasicHandler", "") @@ -53,36 +60,35 @@ def test_encryption(data_masker): @pytest.mark.xdist_group(name="data_masking") -def test_encryption_context(data_masker): +def test_encryption_context(data_masker, security_context): # GIVEN an instantiation of DataMasking with the AWS encryption provider value = [1, 2, "string", 4.5] - context = {"this": "is_secure"} # WHEN encrypting and then decrypting the encrypted data with an encryption_context - encrypted_data = data_masker.encrypt(value, encryption_context=context) - decrypted_data = data_masker.decrypt(encrypted_data, encryption_context=context) + encrypted_data = data_masker.encrypt(value, **security_context) + decrypted_data = data_masker.decrypt(encrypted_data, **security_context) # THEN the result is the original input data assert decrypted_data == value @pytest.mark.xdist_group(name="data_masking") -def test_encryption_context_mismatch(data_masker): +def test_encryption_context_mismatch(data_masker, security_context): # GIVEN an instantiation of DataMasking with the AWS encryption provider value = [1, 2, "string", 4.5] # WHEN encrypting with a encryption_context - encrypted_data = data_masker.encrypt(value, encryption_context={"this": "is_secure"}) + encrypted_data = data_masker.encrypt(value, **security_context) # THEN decrypting with a different encryption_context should raise a ContextMismatchError with pytest.raises(DataMaskingContextMismatchError): - data_masker.decrypt(encrypted_data, encryption_context={"not": "same_context"}) + data_masker.decrypt(encrypted_data, this="different_context") @pytest.mark.xdist_group(name="data_masking") -def test_encryption_no_context_fail(data_masker): +def test_encryption_no_context_fail(data_masker, security_context): # GIVEN an instantiation of DataMasking with the AWS encryption provider value = [1, 2, "string", 4.5] @@ -92,7 +98,7 @@ def test_encryption_no_context_fail(data_masker): # THEN decrypting with an encryption_context should raise a ContextMismatchError with pytest.raises(DataMaskingContextMismatchError): - data_masker.decrypt(encrypted_data, encryption_context={"this": "is_secure"}) + data_masker.decrypt(encrypted_data, **security_context) @pytest.mark.xdist_group(name="data_masking") @@ -106,7 +112,7 @@ def test_encryption_decryption_key_mismatch(data_masker, kms_key2_arn): # THEN when decrypting with a different key it should fail data_masker_key2 = DataMasking(provider=AWSEncryptionSDKProvider(keys=[kms_key2_arn])) - with pytest.raises(DecryptKeyError): + with pytest.raises(DataMaskingDecryptKeyError): data_masker_key2.decrypt(encrypted_data)