Skip to content

feat(parameters): transform = "auto" #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"])
# These providers will be dynamically initialized on first use of the helper functions
DEFAULT_PROVIDERS = {}
TRANSFORM_METHOD_JSON = "json"
TRANSFORM_METHOD_BINARY = "binary"
SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY]


class BaseProvider(ABC):
Expand Down Expand Up @@ -115,8 +118,8 @@ def get_multiple(
Maximum age of the cached value
transform: str, optional
Optional transformation of the parameter value. Supported values
are "json" for JSON strings and "binary" for base 64 encoded
values.
are "json" for JSON strings, "binary" for base 64 encoded
values or "auto" which looks at the attribute key to determine the type.
raise_on_transform_error: bool, optional
Raises an exception if any transform fails, otherwise this will
return a None value for each transform that failed
Expand Down Expand Up @@ -145,7 +148,11 @@ def get_multiple(

if transform is not None:
for (key, value) in values.items():
values[key] = transform_value(value, transform, raise_on_transform_error)
_transform = get_transform_method(key, transform)
if _transform is None:
continue

values[key] = transform_value(value, _transform, raise_on_transform_error)

self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),)

Expand All @@ -159,6 +166,45 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
raise NotImplementedError()


def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]:
"""
Determine the transform method

Examples
-------
>>> get_transform_method("key", "any_other_value")
'any_other_value'
>>> get_transform_method("key.json", "auto")
'json'
>>> get_transform_method("key.binary", "auto")
'binary'
>>> get_transform_method("key", "auto")
None
>>> get_transform_method("key", None)
None

Parameters
---------
key: str
Only used when the tranform is "auto".
transform: str, optional
Original transform method, only "auto" will try to detect the transform method by the key

Returns
------
Optional[str]:
The transform method either when transform is "auto" then None, "json" or "binary" is returned
or the original transform method
"""
if transform != "auto":
return transform

for transform_method in SUPPORTED_TRANSFORM_METHODS:
if key.endswith("." + transform_method):
return transform_method
return None


def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]:
"""
Apply a transform to a value
Expand All @@ -180,9 +226,9 @@ def transform_value(value: str, transform: str, raise_on_transform_error: bool =
"""

try:
if transform == "json":
if transform == TRANSFORM_METHOD_JSON:
return json.loads(value)
elif transform == "binary":
elif transform == TRANSFORM_METHOD_BINARY:
return base64.b64decode(value)
else:
raise ValueError(f"Invalid transform type '{transform}'")
Expand Down
73 changes: 73 additions & 0 deletions tests/functional/test_utilities_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,48 @@ def test_dynamodb_provider_get_multiple(mock_name, mock_value, config):
stubber.deactivate()


def test_dynamodb_provider_get_multiple_auto(mock_name, mock_value, config):
"""
Test DynamoDBProvider.get_multiple() with transform = "auto"
"""
mock_binary = mock_value.encode()
mock_binary_data = base64.b64encode(mock_binary).decode()
mock_json_data = json.dumps({mock_name: mock_value})
mock_params = {"D.json": mock_json_data, "E.binary": mock_binary_data, "F": mock_value}
table_name = "TEST_TABLE_AUTO"

# Create a new provider
provider = parameters.DynamoDBProvider(table_name, config=config)

# Stub the boto3 client
stubber = stub.Stubber(provider.table.meta.client)
response = {
"Items": [
{"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": value}} for (name, value) in mock_params.items()
]
}
expected_params = {"TableName": table_name, "KeyConditionExpression": Key("id").eq(mock_name)}
stubber.add_response("query", response, expected_params)
stubber.activate()

try:
values = provider.get_multiple(mock_name, transform="auto")

stubber.assert_no_pending_responses()

assert len(values) == len(mock_params)
for key in mock_params.keys():
assert key in values
if key.endswith(".json"):
assert values[key][mock_name] == mock_value
elif key.endswith(".binary"):
assert values[key] == mock_binary
else:
assert values[key] == mock_value
finally:
stubber.deactivate()


def test_dynamodb_provider_get_multiple_next_token(mock_name, mock_value, config):
"""
Test DynamoDBProvider.get_multiple() with a non-cached path
Expand Down Expand Up @@ -1481,3 +1523,34 @@ def test_transform_value_ignore_error(mock_value):
value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False)

assert value is None


@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None])
def test_get_transform_method_preserve_original(original_transform):
"""
Check if original transform method is returned for anything other than "auto"
"""
transform = parameters.base.get_transform_method("key", original_transform)

assert transform == original_transform


@pytest.mark.parametrize("extension", ["json", "binary"])
def test_get_transform_method_preserve_auto(extension, mock_name):
"""
Check if we can auto detect the transform method by the support extensions json / binary
"""
transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto")

assert transform == extension


@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"])
def test_get_transform_method_preserve_auto_unhandled(key):
"""
Check if any key that does not end with a supported extension returns None when
using the transform="auto"
"""
transform = parameters.base.get_transform_method(key, "auto")

assert transform is None