diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 274cd96aace..7ce0c9e4d2e 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -15,6 +15,9 @@ ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"]) # These providers will be dynamically initialized on first use of the helper functions DEFAULT_PROVIDERS = {} +TRANSFORM_METHOD_JSON = "json" +TRANSFORM_METHOD_BINARY = "binary" +SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY] class BaseProvider(ABC): @@ -115,8 +118,8 @@ def get_multiple( Maximum age of the cached value transform: str, optional Optional transformation of the parameter value. Supported values - are "json" for JSON strings and "binary" for base 64 encoded - values. + are "json" for JSON strings, "binary" for base 64 encoded + values or "auto" which looks at the attribute key to determine the type. raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed @@ -145,7 +148,11 @@ def get_multiple( if transform is not None: for (key, value) in values.items(): - values[key] = transform_value(value, transform, raise_on_transform_error) + _transform = get_transform_method(key, transform) + if _transform is None: + continue + + values[key] = transform_value(value, _transform, raise_on_transform_error) self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) @@ -159,6 +166,45 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: raise NotImplementedError() +def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]: + """ + Determine the transform method + + Examples + ------- + >>> get_transform_method("key", "any_other_value") + 'any_other_value' + >>> get_transform_method("key.json", "auto") + 'json' + >>> get_transform_method("key.binary", "auto") + 'binary' + >>> get_transform_method("key", "auto") + None + >>> get_transform_method("key", None) + None + + Parameters + --------- + key: str + Only used when the tranform is "auto". + transform: str, optional + Original transform method, only "auto" will try to detect the transform method by the key + + Returns + ------ + Optional[str]: + The transform method either when transform is "auto" then None, "json" or "binary" is returned + or the original transform method + """ + if transform != "auto": + return transform + + for transform_method in SUPPORTED_TRANSFORM_METHODS: + if key.endswith("." + transform_method): + return transform_method + return None + + def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]: """ Apply a transform to a value @@ -180,9 +226,9 @@ def transform_value(value: str, transform: str, raise_on_transform_error: bool = """ try: - if transform == "json": + if transform == TRANSFORM_METHOD_JSON: return json.loads(value) - elif transform == "binary": + elif transform == TRANSFORM_METHOD_BINARY: return base64.b64decode(value) else: raise ValueError(f"Invalid transform type '{transform}'") diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index abd121540a6..55f643924ad 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -233,6 +233,48 @@ def test_dynamodb_provider_get_multiple(mock_name, mock_value, config): stubber.deactivate() +def test_dynamodb_provider_get_multiple_auto(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with transform = "auto" + """ + mock_binary = mock_value.encode() + mock_binary_data = base64.b64encode(mock_binary).decode() + mock_json_data = json.dumps({mock_name: mock_value}) + mock_params = {"D.json": mock_json_data, "E.binary": mock_binary_data, "F": mock_value} + table_name = "TEST_TABLE_AUTO" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": value}} for (name, value) in mock_params.items() + ] + } + expected_params = {"TableName": table_name, "KeyConditionExpression": Key("id").eq(mock_name)} + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name, transform="auto") + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_params) + for key in mock_params.keys(): + assert key in values + if key.endswith(".json"): + assert values[key][mock_name] == mock_value + elif key.endswith(".binary"): + assert values[key] == mock_binary + else: + assert values[key] == mock_value + finally: + stubber.deactivate() + + def test_dynamodb_provider_get_multiple_next_token(mock_name, mock_value, config): """ Test DynamoDBProvider.get_multiple() with a non-cached path @@ -1481,3 +1523,34 @@ def test_transform_value_ignore_error(mock_value): value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False) assert value is None + + +@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None]) +def test_get_transform_method_preserve_original(original_transform): + """ + Check if original transform method is returned for anything other than "auto" + """ + transform = parameters.base.get_transform_method("key", original_transform) + + assert transform == original_transform + + +@pytest.mark.parametrize("extension", ["json", "binary"]) +def test_get_transform_method_preserve_auto(extension, mock_name): + """ + Check if we can auto detect the transform method by the support extensions json / binary + """ + transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto") + + assert transform == extension + + +@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"]) +def test_get_transform_method_preserve_auto_unhandled(key): + """ + Check if any key that does not end with a supported extension returns None when + using the transform="auto" + """ + transform = parameters.base.get_transform_method(key, "auto") + + assert transform is None