Skip to content

Commit 96b7b8f

Browse files
michaelbrewerTom McCarthy
and
Tom McCarthy
authored
fix(ssm): Make decrypt an explicit option and refactoring (#123)
* fix(ssm): Make decrypt an explicit option * chore: declare as self * fix: update get_parameter and get_parameters Changes: ssm.py - get_parameters - pass through the **sdk_options and merge in the recursive and decrypt params ssm.py - get_parameter - add explicit option for decrypt * chore: fix typos and type hinting * tests: verify that the default kwargs are set - `decrypt` should be false by default - `recursive` should be true by default * fix(capture_method): should yield inside with (#124) Changes: * capture_method should yield from within the "with" statement * Add missing test cases Closes #112 * chore: version bump to 1.3.1 * refactor: reduce get_multiple complexity Changes: - base.py - update get_multiple to reduce the overall complexity - base.py - `_has_not_expired` returns whether a key exists and has not expired - base.py - `transform_value` add `raise_on_transform_error` and default to True - test_utilities_parameters.py - Add a direct test of transform_value * refactor: revert to a regular for each Changes: * Add type hint to `values` as it can change later on in transform * Use a slightly faster and easier to read for each over dict comprehension Co-authored-by: Tom McCarthy <[email protected]>
1 parent d0aaad5 commit 96b7b8f

File tree

4 files changed

+107
-41
lines changed

4 files changed

+107
-41
lines changed

aws_lambda_powertools/utilities/parameters/base.py

+37-34
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import ABC, abstractmethod
88
from collections import namedtuple
99
from datetime import datetime, timedelta
10-
from typing import Dict, Optional, Union
10+
from typing import Dict, Optional, Tuple, Union
1111

1212
from .exceptions import GetParameterError, TransformParameterError
1313

@@ -31,6 +31,9 @@ def __init__(self):
3131

3232
self.store = {}
3333

34+
def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
35+
return key in self.store and self.store[key].ttl >= datetime.now()
36+
3437
def get(
3538
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
3639
) -> Union[str, list, dict, bytes]:
@@ -70,24 +73,26 @@ def get(
7073
# an acceptable tradeoff.
7174
key = (name, transform)
7275

73-
if key not in self.store or self.store[key].ttl < datetime.now():
74-
try:
75-
value = self._get(name, **sdk_options)
76-
# Encapsulate all errors into a generic GetParameterError
77-
except Exception as exc:
78-
raise GetParameterError(str(exc))
76+
if self._has_not_expired(key):
77+
return self.store[key].value
78+
79+
try:
80+
value = self._get(name, **sdk_options)
81+
# Encapsulate all errors into a generic GetParameterError
82+
except Exception as exc:
83+
raise GetParameterError(str(exc))
7984

80-
if transform is not None:
81-
value = transform_value(value, transform)
85+
if transform is not None:
86+
value = transform_value(value, transform)
8287

83-
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),)
88+
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),)
8489

85-
return self.store[key].value
90+
return value
8691

8792
@abstractmethod
8893
def _get(self, name: str, **sdk_options) -> str:
8994
"""
90-
Retrieve paramater value from the underlying parameter store
95+
Retrieve parameter value from the underlying parameter store
9196
"""
9297
raise NotImplementedError()
9398

@@ -129,29 +134,22 @@ def get_multiple(
129134

130135
key = (path, transform)
131136

132-
if key not in self.store or self.store[key].ttl < datetime.now():
133-
try:
134-
values = self._get_multiple(path, **sdk_options)
135-
# Encapsulate all errors into a generic GetParameterError
136-
except Exception as exc:
137-
raise GetParameterError(str(exc))
137+
if self._has_not_expired(key):
138+
return self.store[key].value
138139

139-
if transform is not None:
140-
new_values = {}
141-
for key, value in values.items():
142-
try:
143-
new_values[key] = transform_value(value, transform)
144-
except Exception as exc:
145-
if raise_on_transform_error:
146-
raise exc
147-
else:
148-
new_values[key] = None
140+
try:
141+
values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options)
142+
# Encapsulate all errors into a generic GetParameterError
143+
except Exception as exc:
144+
raise GetParameterError(str(exc))
149145

150-
values = new_values
146+
if transform is not None:
147+
for (key, value) in values.items():
148+
values[key] = transform_value(value, transform, raise_on_transform_error)
151149

152-
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)
150+
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)
153151

154-
return self.store[key].value
152+
return values
155153

156154
@abstractmethod
157155
def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
@@ -161,16 +159,19 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
161159
raise NotImplementedError()
162160

163161

164-
def transform_value(value: str, transform: str) -> Union[dict, bytes]:
162+
def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
165163
"""
166164
Apply a transform to a value
167165
168166
Parameters
169167
---------
170168
value: str
171-
Parameter alue to transform
169+
Parameter value to transform
172170
transform: str
173171
Type of transform, supported values are "json" and "binary"
172+
raise_on_transform_error: bool, optional
173+
Raises an exception if any transform fails, otherwise this will
174+
return a None value for each transform that failed
174175
175176
Raises
176177
------
@@ -187,4 +188,6 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]:
187188
raise ValueError(f"Invalid transform type '{transform}'")
188189

189190
except Exception as exc:
190-
raise TransformParameterError(str(exc))
191+
if raise_on_transform_error:
192+
raise TransformParameterError(str(exc))
193+
return None

aws_lambda_powertools/utilities/parameters/secrets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _get(self, name: str, **sdk_options) -> str:
7777
----------
7878
name: str
7979
Name of the parameter
80-
sdk_options: dict
80+
sdk_options: dict, optional
8181
Dictionary of options that will be passed to the Secrets Manager get_secret_value API call
8282
"""
8383

aws_lambda_powertools/utilities/parameters/ssm.py

+56-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import boto3
99
from botocore.config import Config
1010

11-
from .base import DEFAULT_PROVIDERS, BaseProvider
11+
from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider
1212

1313

1414
class SSMProvider(BaseProvider):
@@ -86,6 +86,46 @@ def __init__(
8686

8787
super().__init__()
8888

89+
def get(
90+
self,
91+
name: str,
92+
max_age: int = DEFAULT_MAX_AGE_SECS,
93+
transform: Optional[str] = None,
94+
decrypt: bool = False,
95+
**sdk_options
96+
) -> Union[str, list, dict, bytes]:
97+
"""
98+
Retrieve a parameter value or return the cached value
99+
100+
Parameters
101+
----------
102+
name: str
103+
Parameter name
104+
max_age: int
105+
Maximum age of the cached value
106+
transform: str
107+
Optional transformation of the parameter value. Supported values
108+
are "json" for JSON strings and "binary" for base 64 encoded
109+
values.
110+
decrypt: bool, optional
111+
If the parameter value should be decrypted
112+
sdk_options: dict, optional
113+
Arguments that will be passed directly to the underlying API call
114+
115+
Raises
116+
------
117+
GetParameterError
118+
When the parameter provider fails to retrieve a parameter value for
119+
a given name.
120+
TransformParameterError
121+
When the parameter provider fails to transform a parameter value.
122+
"""
123+
124+
# Add to `decrypt` sdk_options to we can have an explicit option for this
125+
sdk_options["decrypt"] = decrypt
126+
127+
return super().get(name, max_age, transform, **sdk_options)
128+
89129
def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:
90130
"""
91131
Retrieve a parameter value from AWS Systems Manager Parameter Store
@@ -144,7 +184,9 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals
144184
return parameters
145185

146186

147-
def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, list, dict, bytes]:
187+
def get_parameter(
188+
name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options
189+
) -> Union[str, list, dict, bytes]:
148190
"""
149191
Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store
150192
@@ -154,6 +196,8 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) ->
154196
Name of the parameter
155197
transform: str, optional
156198
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
199+
decrypt: bool, optional
200+
If the parameter values should be decrypted
157201
sdk_options: dict, optional
158202
Dictionary of options that will be passed to the Parameter Store get_parameter API call
159203
@@ -190,7 +234,10 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) ->
190234
if "ssm" not in DEFAULT_PROVIDERS:
191235
DEFAULT_PROVIDERS["ssm"] = SSMProvider()
192236

193-
return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform)
237+
# Add to `decrypt` sdk_options to we can have an explicit option for this
238+
sdk_options["decrypt"] = decrypt
239+
240+
return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options)
194241

195242

196243
def get_parameters(
@@ -205,10 +252,10 @@ def get_parameters(
205252
Path to retrieve the parameters
206253
transform: str, optional
207254
Transforms the content from a JSON object ('json') or base64 binary string ('binary')
208-
decrypt: bool, optional
209-
If the parameter values should be decrypted
210255
recursive: bool, optional
211256
If this should retrieve the parameter values recursively or not, defaults to True
257+
decrypt: bool, optional
258+
If the parameter values should be decrypted
212259
sdk_options: dict, optional
213260
Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call
214261
@@ -245,4 +292,7 @@ def get_parameters(
245292
if "ssm" not in DEFAULT_PROVIDERS:
246293
DEFAULT_PROVIDERS["ssm"] = SSMProvider()
247294

248-
return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, recursive=recursive, decrypt=decrypt)
295+
sdk_options["recursive"] = recursive
296+
sdk_options["decrypt"] = decrypt
297+
298+
return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options)

tests/functional/test_utilities_parameters.py

+13
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,7 @@ def test_get_parameter_new(monkeypatch, mock_name, mock_value):
13101310
class TestProvider(BaseProvider):
13111311
def _get(self, name: str, **kwargs) -> str:
13121312
assert name == mock_name
1313+
assert not kwargs["decrypt"]
13131314
return mock_value
13141315

13151316
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
@@ -1355,6 +1356,8 @@ def _get(self, name: str, **kwargs) -> str:
13551356

13561357
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
13571358
assert path == mock_name
1359+
assert kwargs["recursive"]
1360+
assert not kwargs["decrypt"]
13581361
return mock_value
13591362

13601363
monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {})
@@ -1468,3 +1471,13 @@ def test_transform_value_wrong(mock_value):
14681471
parameters.base.transform_value(mock_value, "INCORRECT")
14691472

14701473
assert "Invalid transform type" in str(excinfo)
1474+
1475+
1476+
def test_transform_value_ignore_error(mock_value):
1477+
"""
1478+
Test transform_value() does not raise errors when raise_on_transform_error is False
1479+
"""
1480+
1481+
value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False)
1482+
1483+
assert value is None

0 commit comments

Comments
 (0)