Skip to content

Commit a0c2da1

Browse files
committed
refactor: strict typing transform_value/method
1 parent 1c6c71c commit a0c2da1

File tree

5 files changed

+89
-72
lines changed

5 files changed

+89
-72
lines changed

Diff for: aws_lambda_powertools/utilities/feature_flags/appconfig.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from .base import StoreProvider
1616
from .exceptions import ConfigurationStoreError, StoreClientError
1717

18-
TRANSFORM_TYPE = "json"
19-
2018

2119
class AppConfigStore(StoreProvider):
2220
def __init__(
@@ -74,7 +72,7 @@ def get_raw_configuration(self) -> Dict[str, Any]:
7472
dict,
7573
self._conf_store.get(
7674
name=self.name,
77-
transform=TRANSFORM_TYPE,
75+
transform="json",
7876
max_age=self.cache_seconds,
7977
),
8078
)

Diff for: aws_lambda_powertools/utilities/parameters/appconfig.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import boto3
1010
from botocore.config import Config
1111

12+
from aws_lambda_powertools.utilities.parameters.types import TransformOptions
13+
1214
if TYPE_CHECKING:
1315
from mypy_boto3_appconfigdata import AppConfigDataClient
1416

@@ -132,7 +134,7 @@ def get_app_config(
132134
name: str,
133135
environment: str,
134136
application: Optional[str] = None,
135-
transform: Optional[str] = None,
137+
transform: TransformOptions = None,
136138
force_fetch: bool = False,
137139
max_age: int = DEFAULT_MAX_AGE_SECS,
138140
**sdk_options

Diff for: aws_lambda_powertools/utilities/parameters/base.py

+78-33
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,24 @@
77
from abc import ABC, abstractmethod
88
from collections import namedtuple
99
from datetime import datetime, timedelta
10-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
10+
from typing import (
11+
TYPE_CHECKING,
12+
Any,
13+
Callable,
14+
Dict,
15+
Optional,
16+
Tuple,
17+
Type,
18+
Union,
19+
cast,
20+
overload,
21+
)
1122

1223
import boto3
1324
from botocore.config import Config
1425

26+
from aws_lambda_powertools.utilities.parameters.types import TransformOptions
27+
1528
from .exceptions import GetParameterError, TransformParameterError
1629

1730
if TYPE_CHECKING:
@@ -30,6 +43,14 @@
3043
SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY]
3144
ParameterClients = Union["AppConfigDataClient", "SecretsManagerClient", "SSMClient"]
3245

46+
TRANSFORM_METHOD_MAPPING = {
47+
TRANSFORM_METHOD_JSON: json.loads,
48+
TRANSFORM_METHOD_BINARY: base64.b64decode,
49+
".json": json.loads,
50+
".binary": base64.b64decode,
51+
None: lambda x: x,
52+
}
53+
3354

3455
class BaseProvider(ABC):
3556
"""
@@ -52,7 +73,7 @@ def get(
5273
self,
5374
name: str,
5475
max_age: int = DEFAULT_MAX_AGE_SECS,
55-
transform: Optional[str] = None,
76+
transform: TransformOptions = None,
5677
force_fetch: bool = False,
5778
**sdk_options,
5879
) -> Optional[Union[str, dict, bytes]]:
@@ -124,7 +145,7 @@ def get_multiple(
124145
self,
125146
path: str,
126147
max_age: int = DEFAULT_MAX_AGE_SECS,
127-
transform: Optional[str] = None,
148+
transform: TransformOptions = None,
128149
raise_on_transform_error: bool = False,
129150
force_fetch: bool = False,
130151
**sdk_options,
@@ -170,13 +191,7 @@ def get_multiple(
170191
raise GetParameterError(str(exc))
171192

172193
if transform:
173-
transformed_values: dict = {}
174-
for (item, value) in values.items():
175-
_transform = get_transform_method(item, transform)
176-
if not _transform:
177-
continue
178-
transformed_values[item] = transform_value(value, _transform, raise_on_transform_error)
179-
values.update(transformed_values)
194+
values.update(transform_value(values, transform, raise_on_transform_error))
180195
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
181196

182197
return values
@@ -258,7 +273,7 @@ def _build_boto3_resource_client(
258273
return session.resource(service_name=service_name, config=config, endpoint_url=endpoint_url)
259274

260275

261-
def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]:
276+
def get_transform_method(key: str, transform: TransformOptions = None) -> Callable[..., Any]:
262277
"""
263278
Determine the transform method
264279
@@ -278,37 +293,50 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
278293
Parameters
279294
---------
280295
key: str
281-
Only used when the tranform is "auto".
296+
Only used when the transform is "auto".
282297
transform: str, optional
283298
Original transform method, only "auto" will try to detect the transform method by the key
284299
285300
Returns
286301
------
287-
Optional[str]:
288-
The transform method either when transform is "auto" then None, "json" or "binary" is returned
289-
or the original transform method
302+
Callable:
303+
Transform function could be json.loads, base64.b64decode, or a lambda that echo the str value
290304
"""
291-
if transform != "auto":
292-
return transform
305+
transform_method = TRANSFORM_METHOD_MAPPING.get(transform)
306+
307+
if transform == "auto":
308+
key_suffix = key.rsplit(".")[-1]
309+
transform_method = TRANSFORM_METHOD_MAPPING.get(key_suffix, TRANSFORM_METHOD_MAPPING[None])
310+
311+
return cast(Callable, transform_method) # https://github.com/python/mypy/issues/10740
312+
313+
314+
@overload
315+
def transform_value(
316+
value: Dict[str, Any], transform: TransformOptions, raise_on_transform_error: bool = False
317+
) -> Dict[str, Any]:
318+
...
293319

294-
for transform_method in SUPPORTED_TRANSFORM_METHODS:
295-
if key.endswith("." + transform_method):
296-
return transform_method
297-
return None
320+
321+
@overload
322+
def transform_value(
323+
value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = False
324+
) -> Optional[Union[str, bytes, Dict[str, Any]]]:
325+
...
298326

299327

300328
def transform_value(
301-
value: str, transform: str, raise_on_transform_error: Optional[bool] = True
302-
) -> Optional[Union[dict, bytes]]:
329+
value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = False
330+
) -> Optional[Union[str, bytes, Dict[str, Any]]]:
303331
"""
304-
Apply a transform to a value
332+
Transform a value using one of the available options.
305333
306334
Parameters
307335
---------
308336
value: str
309337
Parameter value to transform
310338
transform: str
311-
Type of transform, supported values are "json" and "binary"
339+
Type of transform, supported values are "json", "binary", and "auto" based on suffix (.json, .binary)
312340
raise_on_transform_error: bool, optional
313341
Raises an exception if any transform fails, otherwise this will
314342
return a None value for each transform that failed
@@ -318,18 +346,35 @@ def transform_value(
318346
TransformParameterError:
319347
When the parameter value could not be transformed
320348
"""
349+
# Maintenance: For v3, we should consider returning the original value for soft transform failures.
321350

322-
try:
323-
if transform == TRANSFORM_METHOD_JSON:
324-
return json.loads(value)
325-
elif transform == TRANSFORM_METHOD_BINARY:
326-
return base64.b64decode(value)
327-
else:
328-
raise ValueError(f"Invalid transform type '{transform}'")
351+
err_msg = "Unable to transform value using '{transform}' transform: {exc}"
352+
353+
if isinstance(value, bytes):
354+
value = value.decode("utf-8")
329355

356+
if isinstance(value, dict):
357+
# NOTE: We must handle partial failures when receiving multiple values
358+
# where one of the keys might fail during transform, e.g. `{"a": "valid", "b": "{"}`
359+
# expected: `{"a": "valid", "b": None}`
360+
361+
transformed_values: Dict[str, Any] = {}
362+
for dict_key, dict_value in value.items():
363+
transform_method = get_transform_method(key=dict_key, transform=transform)
364+
try:
365+
transformed_values[dict_key] = transform_method(dict_value)
366+
except Exception as exc:
367+
if raise_on_transform_error:
368+
raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc
369+
transformed_values[dict_key] = None
370+
return transformed_values
371+
372+
try:
373+
transform_method = get_transform_method(key=value, transform=transform)
374+
return transform_method(value)
330375
except Exception as exc:
331376
if raise_on_transform_error:
332-
raise TransformParameterError(str(exc))
377+
raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc
333378
return None
334379

335380

Diff for: aws_lambda_powertools/utilities/parameters/ssm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get( # type: ignore[override]
103103
self,
104104
name: str,
105105
max_age: int = DEFAULT_MAX_AGE_SECS,
106-
transform: Optional[str] = None,
106+
transform: TransformOptions = None,
107107
decrypt: bool = False,
108108
force_fetch: bool = False,
109109
**sdk_options,

Diff for: tests/functional/test_utilities_parameters.py

+6-34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from botocore.response import StreamingBody
1515

1616
from aws_lambda_powertools.utilities import parameters
17-
from aws_lambda_powertools.utilities.parameters.base import BaseProvider, ExpirableValue
17+
from aws_lambda_powertools.utilities.parameters.base import (
18+
TRANSFORM_METHOD_MAPPING,
19+
BaseProvider,
20+
ExpirableValue,
21+
)
1822

1923

2024
@pytest.fixture(scope="function")
@@ -1863,17 +1867,6 @@ def test_transform_value_binary_exception():
18631867
assert "Incorrect padding" in str(excinfo)
18641868

18651869

1866-
def test_transform_value_wrong(mock_value):
1867-
"""
1868-
Test transform_value() with an incorrect transform
1869-
"""
1870-
1871-
with pytest.raises(parameters.TransformParameterError) as excinfo:
1872-
parameters.base.transform_value(mock_value, "INCORRECT")
1873-
1874-
assert "Invalid transform type" in str(excinfo)
1875-
1876-
18771870
def test_transform_value_ignore_error(mock_value):
18781871
"""
18791872
Test transform_value() does not raise errors when raise_on_transform_error is False
@@ -1884,35 +1877,14 @@ def test_transform_value_ignore_error(mock_value):
18841877
assert value is None
18851878

18861879

1887-
@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None])
1888-
def test_get_transform_method_preserve_original(original_transform):
1889-
"""
1890-
Check if original transform method is returned for anything other than "auto"
1891-
"""
1892-
transform = parameters.base.get_transform_method("key", original_transform)
1893-
1894-
assert transform == original_transform
1895-
1896-
18971880
@pytest.mark.parametrize("extension", ["json", "binary"])
18981881
def test_get_transform_method_preserve_auto(extension, mock_name):
18991882
"""
19001883
Check if we can auto detect the transform method by the support extensions json / binary
19011884
"""
19021885
transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto")
19031886

1904-
assert transform == extension
1905-
1906-
1907-
@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"])
1908-
def test_get_transform_method_preserve_auto_unhandled(key):
1909-
"""
1910-
Check if any key that does not end with a supported extension returns None when
1911-
using the transform="auto"
1912-
"""
1913-
transform = parameters.base.get_transform_method(key, "auto")
1914-
1915-
assert transform is None
1887+
assert transform == TRANSFORM_METHOD_MAPPING[extension]
19161888

19171889

19181890
def test_base_provider_get_multiple_force_update(mock_name, mock_value):

0 commit comments

Comments
 (0)