Skip to content

Commit fce3268

Browse files
committed
feat: throw exception on failed transform for parameter utility
1 parent d53c373 commit fce3268

File tree

4 files changed

+164
-39
lines changed

4 files changed

+164
-39
lines changed

aws_lambda_powertools/utilities/parameters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .base import BaseProvider
88
from .dynamodb import DynamoDBProvider
9-
from .exceptions import GetParameterError
9+
from .exceptions import GetParameterError, TransformParameterError
1010
from .secrets import SecretsProvider, get_secret
1111
from .ssm import SSMProvider, get_parameter, get_parameters
1212

@@ -16,6 +16,7 @@
1616
"DynamoDBProvider",
1717
"SecretsProvider",
1818
"SSMProvider",
19+
"TransformParameterError",
1920
"get_parameter",
2021
"get_parameters",
2122
"get_secret",

aws_lambda_powertools/utilities/parameters/base.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from datetime import datetime, timedelta
1010
from typing import Dict, Optional, Union
1111

12-
from .exceptions import GetParameterError
12+
from .exceptions import GetParameterError, TransformParameterError
1313

1414
DEFAULT_MAX_AGE_SECS = 5
1515
ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"])
@@ -32,14 +32,13 @@ def __init__(self):
3232
self.store = {}
3333

3434
def get(
35-
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **kwargs
35+
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
3636
) -> Union[str, list, dict, bytes]:
3737
"""
3838
Retrieve a parameter value or return the cached value
3939
4040
Parameters
4141
----------
42-
4342
name: str
4443
Parameter name
4544
max_age: int
@@ -51,10 +50,11 @@ def get(
5150
5251
Raises
5352
------
54-
5553
GetParameterError
5654
When the parameter provider fails to retrieve a parameter value for
5755
a given name.
56+
TransformParameterError
57+
When the parameter provider fails to transform a parameter value.
5858
"""
5959

6060
# If there are multiple calls to the same parameter but in a different
@@ -70,54 +70,81 @@ def get(
7070

7171
if key not in self.store or self.store[key].ttl < datetime.now():
7272
try:
73-
value = self._get(name, **kwargs)
73+
value = self._get(name, **sdk_options)
7474
# Encapsulate all errors into a generic GetParameterError
7575
except Exception as exc:
7676
raise GetParameterError(str(exc))
7777

78-
if transform == "json":
79-
value = json.loads(value)
80-
elif transform == "binary":
81-
value = base64.b64decode(value)
78+
try:
79+
if transform == "json":
80+
value = json.loads(value)
81+
elif transform == "binary":
82+
value = base64.b64decode(value)
83+
# Encapsulate transform exceptions into TransformParameterError
84+
except Exception as exc:
85+
raise TransformParameterError(str(exc))
8286

8387
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),)
8488

8589
return self.store[key].value
8690

8791
@abstractmethod
88-
def _get(self, name: str, **kwargs) -> str:
92+
def _get(self, name: str, **sdk_options) -> str:
8993
"""
9094
Retrieve paramater value from the underlying parameter store
9195
"""
9296
raise NotImplementedError()
9397

9498
def get_multiple(
95-
self, path: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **kwargs
99+
self, path: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
96100
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
97101
"""
98102
Retrieve multiple parameters based on a path prefix
103+
104+
Parameters
105+
----------
106+
path: str
107+
Parameter path used to retrieve multiple parameters
108+
max_age: int
109+
Maximum age of the cached value
110+
transform: str
111+
Optional transformation of the parameter value. Supported values
112+
are "json" for JSON strings and "binary" for base 64 encoded
113+
values.
114+
115+
Raises
116+
------
117+
GetParameterError
118+
When the parameter provider fails to retrieve parameter values for
119+
a given path.
120+
TransformParameterError
121+
When the parameter provider fails to transform a parameter value.
99122
"""
100123

101124
key = (path, transform)
102125

103126
if key not in self.store or self.store[key].ttl < datetime.now():
104127
try:
105-
values = self._get_multiple(path, **kwargs)
128+
values = self._get_multiple(path, **sdk_options)
106129
# Encapsulate all errors into a generic GetParameterError
107130
except Exception as exc:
108131
raise GetParameterError(str(exc))
109132

110-
if transform == "json":
111-
values = {k: json.loads(v) for k, v in values.items()}
112-
elif transform == "binary":
113-
values = {k: base64.b64decode(v) for k, v in values.items()}
133+
try:
134+
if transform == "json":
135+
values = {k: json.loads(v) for k, v in values.items()}
136+
elif transform == "binary":
137+
values = {k: base64.b64decode(v) for k, v in values.items()}
138+
# Encapsulate transform exceptions into TransformParameterError
139+
except Exception as exc:
140+
raise TransformParameterError(str(exc))
114141

115142
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)
116143

117144
return self.store[key].value
118145

119146
@abstractmethod
120-
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
147+
def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
121148
"""
122149
Retrieve multiple parameter values from the underlying parameter store
123150
"""

aws_lambda_powertools/utilities/parameters/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@
55

66
class GetParameterError(Exception):
77
"""When a provider raises an exception on parameter retrieval"""
8+
9+
10+
class TransformParameterError(Exception):
11+
"""When a provider fails to transform a parameter value"""

tests/functional/test_utilities_parameters.py

Lines changed: 114 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,48 @@ def test_secrets_provider_get_sdk_options_overwrite(mock_name, mock_value, confi
957957
stubber.deactivate()
958958

959959

960+
def test_base_provider_get_exception(mock_name):
961+
"""
962+
Test BaseProvider.get() that raises an exception
963+
"""
964+
965+
class TestProvider(BaseProvider):
966+
def _get(self, name: str, **kwargs) -> str:
967+
assert name == mock_name
968+
raise Exception("test exception raised")
969+
970+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
971+
raise NotImplementedError()
972+
973+
provider = TestProvider()
974+
975+
with pytest.raises(parameters.GetParameterError) as excinfo:
976+
provider.get(mock_name)
977+
978+
assert "test exception raised" in str(excinfo)
979+
980+
981+
def test_base_provider_get_multiple_exception(mock_name):
982+
"""
983+
Test BaseProvider.get_multiple() that raises an exception
984+
"""
985+
986+
class TestProvider(BaseProvider):
987+
def _get(self, name: str, **kwargs) -> str:
988+
raise NotImplementedError()
989+
990+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
991+
assert path == mock_name
992+
raise Exception("test exception raised")
993+
994+
provider = TestProvider()
995+
996+
with pytest.raises(parameters.GetParameterError) as excinfo:
997+
provider.get_multiple(mock_name)
998+
999+
assert "test exception raised" in str(excinfo)
1000+
1001+
9601002
def test_base_provider_get_transform_json(mock_name, mock_value):
9611003
"""
9621004
Test BaseProvider.get() with a json transform
@@ -981,55 +1023,60 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
9811023
assert value[mock_name] == mock_value
9821024

9831025

984-
def test_base_provider_get_exception(mock_name):
1026+
def test_base_provider_get_transform_json_exception(mock_name, mock_value):
9851027
"""
986-
Test BaseProvider.get() that raises an exception
1028+
Test BaseProvider.get() with a json transform that raises an exception
9871029
"""
9881030

1031+
mock_data = json.dumps({mock_name: mock_value}) + "{"
1032+
9891033
class TestProvider(BaseProvider):
9901034
def _get(self, name: str, **kwargs) -> str:
9911035
assert name == mock_name
992-
raise Exception("test exception raised")
1036+
return mock_data
9931037

9941038
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
9951039
raise NotImplementedError()
9961040

9971041
provider = TestProvider()
9981042

999-
with pytest.raises(parameters.GetParameterError) as excinfo:
1000-
provider.get(mock_name)
1043+
with pytest.raises(parameters.TransformParameterError) as excinfo:
1044+
provider.get(mock_name, transform="json")
10011045

1002-
assert "test exception raised" in str(excinfo)
1046+
assert "Extra data" in str(excinfo)
10031047

10041048

1005-
def test_base_provider_get_multiple_exception(mock_name):
1049+
def test_base_provider_get_transform_binary(mock_name, mock_value):
10061050
"""
1007-
Test BaseProvider.get_multiple() that raises an exception
1051+
Test BaseProvider.get() with a binary transform
10081052
"""
10091053

1054+
mock_binary = mock_value.encode()
1055+
mock_data = base64.b64encode(mock_binary).decode()
1056+
10101057
class TestProvider(BaseProvider):
10111058
def _get(self, name: str, **kwargs) -> str:
1012-
raise NotImplementedError()
1059+
assert name == mock_name
1060+
return mock_data
10131061

10141062
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1015-
assert path == mock_name
1016-
raise Exception("test exception raised")
1063+
raise NotImplementedError()
10171064

10181065
provider = TestProvider()
10191066

1020-
with pytest.raises(parameters.GetParameterError) as excinfo:
1021-
provider.get_multiple(mock_name)
1067+
value = provider.get(mock_name, transform="binary")
10221068

1023-
assert "test exception raised" in str(excinfo)
1069+
assert isinstance(value, bytes)
1070+
assert value == mock_binary
10241071

10251072

1026-
def test_base_provider_get_transform_binary(mock_name, mock_value):
1073+
def test_base_provider_get_transform_binary_exception(mock_name):
10271074
"""
1028-
Test BaseProvider.get() with a binary transform
1075+
Test BaseProvider.get() with a binary transform that raises an exception
10291076
"""
10301077

1031-
mock_binary = mock_value.encode()
1032-
mock_data = base64.b64encode(mock_binary).decode()
1078+
mock_data = "qw"
1079+
print(mock_data)
10331080

10341081
class TestProvider(BaseProvider):
10351082
def _get(self, name: str, **kwargs) -> str:
@@ -1041,10 +1088,10 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
10411088

10421089
provider = TestProvider()
10431090

1044-
value = provider.get(mock_name, transform="binary")
1091+
with pytest.raises(parameters.TransformParameterError) as excinfo:
1092+
provider.get(mock_name, transform="binary")
10451093

1046-
assert isinstance(value, bytes)
1047-
assert value == mock_binary
1094+
assert "Incorrect padding" in str(excinfo)
10481095

10491096

10501097
def test_base_provider_get_multiple_transform_json(mock_name, mock_value):
@@ -1070,6 +1117,29 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
10701117
assert value["A"][mock_name] == mock_value
10711118

10721119

1120+
def test_base_provider_get_multiple_transform_json_exception(mock_name, mock_value):
1121+
"""
1122+
Test BaseProvider.get_multiple() with a json transform that raises an exception
1123+
"""
1124+
1125+
mock_data = json.dumps({mock_name: mock_value}) + "{"
1126+
1127+
class TestProvider(BaseProvider):
1128+
def _get(self, name: str, **kwargs) -> str:
1129+
raise NotImplementedError()
1130+
1131+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1132+
assert path == mock_name
1133+
return {"A": mock_data}
1134+
1135+
provider = TestProvider()
1136+
1137+
with pytest.raises(parameters.TransformParameterError) as excinfo:
1138+
provider.get_multiple(mock_name, transform="json")
1139+
1140+
assert "Extra data" in str(excinfo)
1141+
1142+
10731143
def test_base_provider_get_multiple_transform_binary(mock_name, mock_value):
10741144
"""
10751145
Test BaseProvider.get_multiple() with a binary transform
@@ -1094,6 +1164,29 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
10941164
assert value["A"] == mock_binary
10951165

10961166

1167+
def test_base_provider_get_multiple_transform_binary_exception(mock_name):
1168+
"""
1169+
Test BaseProvider.get_multiple() with a binary transform that raises an exception
1170+
"""
1171+
1172+
mock_data = "qw"
1173+
1174+
class TestProvider(BaseProvider):
1175+
def _get(self, name: str, **kwargs) -> str:
1176+
raise NotImplementedError()
1177+
1178+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1179+
assert path == mock_name
1180+
return {"A": mock_data}
1181+
1182+
provider = TestProvider()
1183+
1184+
with pytest.raises(parameters.TransformParameterError) as excinfo:
1185+
provider.get_multiple(mock_name, transform="binary")
1186+
1187+
assert "Incorrect padding" in str(excinfo)
1188+
1189+
10971190
def test_base_provider_get_multiple_cached(mock_name, mock_value):
10981191
"""
10991192
Test BaseProvider.get_multiple() with cached values

0 commit comments

Comments
 (0)