Skip to content

Commit 8299039

Browse files
committed
Removed args and ItsDangerous and commented on tests
1 parent 0193ee6 commit 8299039

File tree

4 files changed

+124
-122
lines changed

4 files changed

+124
-122
lines changed

aws_lambda_powertools/utilities/data_masking/base.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,49 @@ def __init__(self, provider=None):
1111
else:
1212
self.provider = provider
1313

14-
def encrypt(self, data, fields=None, **kwargs):
15-
return self._apply_action(data, fields, self.provider.encrypt, **kwargs)
14+
def encrypt(self, data, fields=None, **provider_options):
15+
return self._apply_action(data, fields, self.provider.encrypt, **provider_options)
1616

17-
def decrypt(self, data, fields=None, **kwargs):
18-
return self._apply_action(data, fields, self.provider.decrypt, **kwargs)
17+
def decrypt(self, data, fields=None, **provider_options):
18+
return self._apply_action(data, fields, self.provider.decrypt, **provider_options)
1919

20-
def mask(self, data, fields=None, **kwargs):
21-
return self._apply_action(data, fields, self.provider.mask, **kwargs)
20+
def mask(self, data, fields=None, **provider_options):
21+
return self._apply_action(data, fields, self.provider.mask, **provider_options)
2222

23-
def _apply_action(self, data, fields, action, *args, **kwargs):
23+
def _apply_action(self, data, fields, action, **provider_options):
2424
if fields is not None:
25-
return self._apply_action_to_fields(data, fields, action, *args, **kwargs)
25+
return self._apply_action_to_fields(data, fields, action, **provider_options)
2626
else:
27-
return action(data, *args, **kwargs)
27+
return action(data, **provider_options)
28+
29+
def _apply_action_to_fields(self, data: Union[dict, str], fields, action, **provider_options) -> str:
30+
"""
31+
Apply the specified action to the specified fields in the input data.
32+
33+
This method takes the input data, which can be either a dictionary or a JSON string representation
34+
of a dictionary, and applies a mask, an encryption, or a decryption to the specified fields.
35+
36+
Parameters:
37+
data (Union[dict, str]): The input data to process. It can be either a dictionary or a JSON string
38+
representation of a dictionary.
39+
fields (list): A list of fields to apply the action to. Each field can be specified as a string or
40+
a list of strings representing nested keys in the dictionary.
41+
action (callable): The action to apply to the fields. It should be a callable that takes the current
42+
value of the field as the first argument and any additional arguments that might be required
43+
for the action. It performs an operation on the current value using the provided arguments and
44+
returns the modified value.
45+
**provider_options: Additional keyword arguments to pass to the 'action' function.
46+
47+
Returns:
48+
dict: The modified dictionary (parsed first when a JSON string was passed in) after applying the action to the
49+
specified fields.
50+
51+
Raises:
52+
ValueError: If 'fields' parameter is None.
53+
TypeError: If the 'data' parameter is not a dictionary or a JSON string representation of a dictionary.
54+
KeyError: If any of the specified 'fields' does not exist in the input data.
55+
"""
2856

29-
def _apply_action_to_fields(self, data: Union[dict, str], fields, action, *args, **kwargs) -> str:
3057
if fields is None:
3158
raise ValueError("No fields specified.")
3259

@@ -53,6 +80,6 @@ def _apply_action_to_fields(self, data: Union[dict, str], fields, action, *args,
5380
for key in keys[:-1]:
5481
curr_dict = curr_dict[key]
5582
valtochange = curr_dict[(keys[-1])]
56-
curr_dict[keys[-1]] = action(valtochange, *args, **kwargs)
83+
curr_dict[keys[-1]] = action(valtochange, **provider_options)
5784

5885
return my_dict_parsed

aws_lambda_powertools/utilities/data_masking/providers/aws_encryption_sdk.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class SingletonMeta(type):
1717

1818
_instances: Dict["AwsEncryptionSdkProvider", Any] = {}
1919

20-
def __call__(cls, *args, **kwargs):
20+
def __call__(cls, *args, **provider_options):
2121
if cls not in cls._instances:
22-
instance = super().__call__(*args, **kwargs)
22+
instance = super().__call__(*args, **provider_options)
2323
cls._instances[cls] = instance
2424
return cls._instances[cls]
2525

@@ -45,12 +45,14 @@ def __init__(self, keys: List[str], client: Optional[EncryptionSDKClient] = None
4545
max_messages_encrypted=MAX_MESSAGES,
4646
)
4747

48-
def encrypt(self, data: Union[bytes, str], *args, **kwargs) -> str:
49-
ciphertext, _ = self.client.encrypt(source=data, materials_manager=self.cache_cmm, *args, **kwargs)
48+
def encrypt(self, data: Union[bytes, str], **provider_options) -> str:
49+
ciphertext, _ = self.client.encrypt(source=data, materials_manager=self.cache_cmm, **provider_options)
5050
ciphertext = base64.b64encode(ciphertext).decode()
5151
return ciphertext
5252

53-
def decrypt(self, data: str, *args, **kwargs) -> bytes:
53+
def decrypt(self, data: str, **provider_options) -> bytes:
5454
ciphertext_decoded = base64.b64decode(data)
55-
ciphertext, _ = self.client.decrypt(source=ciphertext_decoded, key_provider=self.key_provider, *args, **kwargs)
55+
ciphertext, _ = self.client.decrypt(
56+
source=ciphertext_decoded, key_provider=self.key_provider, **provider_options
57+
)
5658
return ciphertext

aws_lambda_powertools/utilities/data_masking/providers/itsdangerous.py

-53
This file was deleted.

tests/unit/data_masking/test_data_masking.py

+78-52
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,6 @@
77
from aws_lambda_powertools.shared.constants import DATA_MASKING_STRING
88
from aws_lambda_powertools.utilities.data_masking.base import DataMasking
99
from aws_lambda_powertools.utilities.data_masking.provider import Provider
10-
from aws_lambda_powertools.utilities.data_masking.providers.aws_encryption_sdk import (
11-
AwsEncryptionSdkProvider,
12-
)
13-
from aws_lambda_powertools.utilities.data_masking.providers.itsdangerous import (
14-
ItsDangerousProvider,
15-
)
16-
17-
AWS_SDK_KEY = "arn:aws:kms:us-west-2:683517028648:key/269301eb-81eb-4067-ac72-98e8e49bf2b3"
1810

1911

2012
class MyEncryptionProvider(Provider):
@@ -39,8 +31,6 @@ def decrypt(self, data: str) -> str:
3931

4032
data_maskers = [
4133
DataMasking(),
42-
DataMasking(provider=ItsDangerousProvider("mykey")),
43-
DataMasking(provider=AwsEncryptionSdkProvider(keys=[AWS_SDK_KEY])),
4434
DataMasking(provider=MyEncryptionProvider(keys="secret-key")),
4535
]
4636

@@ -121,101 +111,137 @@ def decrypt(self, data: str) -> str:
121111
@pytest.mark.parametrize("data_masker", data_maskers)
122112
@pytest.mark.parametrize("value, value_masked", data_types_and_masks)
123113
def test_mask_types(data_masker, value, value_masked):
114+
# GIVEN any data type
115+
116+
# WHEN mask is called with no fields argument
124117
masked_string = data_masker.mask(value)
118+
119+
# THEN the result is the full input data masked
125120
assert masked_string == value_masked
126121

127122

128123
@pytest.mark.parametrize("data_masker", data_maskers)
129124
def test_mask_with_fields(data_masker):
125+
# GIVEN the data type is a dictionary, or a json representation of a dictionary
126+
127+
# WHEN mask is called with a list of fields specified
130128
masked_string = data_masker.mask(python_dict, dict_fields)
129+
masked_json_string = data_masker.mask(json_dict, dict_fields)
130+
131+
# THEN only the specified fields are masked
131132
assert masked_string == masked_with_fields
132-
masked_string = data_masker.mask(json_dict, dict_fields)
133-
assert masked_string == masked_with_fields
133+
assert masked_json_string == masked_with_fields
134134

135135

136-
@pytest.mark.parametrize("data_masker", data_maskers)
137136
@pytest.mark.parametrize("value", data_types)
138-
def test_encrypt_decrypt(data_masker, value):
139-
if data_masker == data_maskers[0]:
140-
with pytest.raises(NotImplementedError):
141-
encrypted_data = data_masker.encrypt(value)
137+
def test_encrypt_decrypt(value):
138+
# GIVEN an instantiation of DataMasking with a Provider
139+
data_masker = DataMasking(provider=MyEncryptionProvider(keys="secret-key"))
142140

143-
else:
144-
if data_masker == data_maskers[2]:
145-
# AWS Encryption SDK encrypt method only takes in bytes or strings
146-
value = bytes(str(value), "utf-8")
141+
# WHEN encrypting and then decrypting the encrypted data
142+
encrypted_data = data_masker.encrypt(value)
143+
decrypted_data = data_masker.decrypt(encrypted_data)
147144

148-
encrypted_data = data_masker.encrypt(value)
149-
decrypted_data = data_masker.decrypt(encrypted_data)
150-
assert decrypted_data == value
145+
# THEN the result is the original input data
146+
assert decrypted_data == value
151147

152148

153-
@pytest.mark.parametrize("data_masker", data_maskers)
154149
@pytest.mark.parametrize("value, fields", zip(dictionaries, fields_to_mask))
155-
def test_encrypt_decrypt_with_fields(data_masker, value, fields):
156-
if data_masker == data_maskers[0]:
157-
with pytest.raises(NotImplementedError):
158-
encrypted_data = data_masker.encrypt(value)
150+
def test_encrypt_decrypt_with_fields(value, fields):
151+
# GIVEN an instantiation of DataMasking with a Provider
152+
data_masker = DataMasking(provider=MyEncryptionProvider(keys="secret-key"))
153+
154+
# WHEN encrypting and then decrypting the encrypted data with a list of fields
155+
encrypted_data = data_masker.encrypt(value, fields)
156+
decrypted_data = data_masker.decrypt(encrypted_data, fields)
159157

158+
# THEN the result is the original input data
159+
if value == json_dict:
160+
assert decrypted_data == json.loads(value)
160161
else:
161-
encrypted_data = data_masker.encrypt(value, fields)
162-
decrypted_data = data_masker.decrypt(encrypted_data, fields)
162+
assert decrypted_data == value
163+
164+
165+
def test_encrypt_not_implemented():
166+
# GIVEN DataMasking is not initialized with a Provider
167+
data_masker = DataMasking()
163168

164-
if data_masker == data_maskers[2]:
165-
# AWS Encryption SDK decrypt method only returns bytes
166-
if value == json_blob:
167-
assert decrypted_data == aws_encrypted_json_blob
168-
else:
169-
assert decrypted_data == aws_encrypted_with_fields
169+
# WHEN attempting to call the encrypt method on the data
170170

171-
else:
172-
if value == json_dict:
173-
assert decrypted_data == json.loads(value)
174-
else:
175-
assert decrypted_data == value
171+
# THEN the result is a NotImplementedError
172+
with pytest.raises(NotImplementedError):
173+
data_masker.encrypt("hello world")
176174

177175

178176
def test_decrypt_not_implemented():
179-
"""Test decrypting with no Provider"""
177+
# GIVEN DataMasking is not initialized with a Provider
180178
data_masker = DataMasking()
181-
with pytest.raises(NotImplementedError):
182-
data_masker.decrypt("hello world")
183179

180+
# WHEN attempting to call the decrypt method on the data
184181

185-
def test_aws_encryption_sdk_with_context():
186-
data_masker = DataMasking(provider=AwsEncryptionSdkProvider(keys=[AWS_SDK_KEY]))
187-
encrypted_data = data_masker.encrypt(
188-
str(python_dict), encryption_context={"not really": "a secret", "but adds": "some auth"}
189-
)
190-
decrypted_data = data_masker.decrypt(encrypted_data)
191-
assert decrypted_data == bytes(str(python_dict), "utf-8")
182+
# THEN the result is a NotImplementedError
183+
with pytest.raises(NotImplementedError):
184+
data_masker.decrypt("hello world")
192185

193186

194187
def test_parsing_unsupported_data_type():
188+
# GIVEN an initialization of the DataMasking class
195189
data_masker = DataMasking()
190+
191+
# WHEN attempting to pass in a list of fields with input data that is not a dict
192+
193+
# THEN the result is a TypeError
196194
with pytest.raises(TypeError):
197195
data_masker.mask(42, ["this.field"])
198196

199197

198+
def test_parsing_nonexistent_fields():
199+
# GIVEN an initialization of the DataMasking class
200+
data_masker = DataMasking()
201+
_python_dict = {
202+
"3": {
203+
"1": {"None": "hello", "four": "world"},
204+
"4": {"33": {"5": "goodbye", "e": "world"}},
205+
}
206+
}
207+
208+
# WHEN attempting to pass in fields that do not exist in the input data
209+
210+
# THEN the result is a KeyError
211+
with pytest.raises(KeyError):
212+
data_masker.mask(_python_dict, ["3.1.True"])
213+
214+
200215
def test_parsing_nonstring_fields():
216+
# GIVEN an initialization of the DataMasking class
201217
data_masker = DataMasking()
202218
_python_dict = {
203219
"3": {
204220
"1": {"None": "hello", "four": "world"},
205221
"4": {"33": {"5": "goodbye", "e": "world"}},
206222
}
207223
}
224+
225+
# WHEN attempting to pass in a list of fields that are not strings
208226
masked = data_masker.mask(_python_dict, fields=[3.4])
227+
228+
# THEN the value of the nested field is masked as normal
209229
assert masked == {"3": {"1": {"None": "hello", "four": "world"}, "4": DATA_MASKING_STRING}}
210230

211231

212232
def test_parsing_nonstring_keys_and_fields():
233+
# GIVEN an initialization of the DataMasking class
213234
data_masker = DataMasking()
235+
236+
# WHEN the input data is a dictionary with integer keys
214237
_python_dict = {
215238
3: {
216239
"1": {"None": "hello", "four": "world"},
217240
4: {"33": {"5": "goodbye", "e": "world"}},
218241
}
219242
}
243+
220244
masked = data_masker.mask(_python_dict, fields=[3.4])
245+
246+
# THEN the value of the nested field is masked as normal
221247
assert masked == {"3": {"1": {"None": "hello", "four": "world"}, "4": DATA_MASKING_STRING}}

0 commit comments

Comments
 (0)