diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index fb4eedb7f36..884edb37e35 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -1,9 +1,10 @@ import base64 +import itertools import logging import os import warnings from binascii import Error as BinAsciiError -from typing import Optional, Union, overload +from typing import Dict, Generator, Optional, Union, overload from aws_lambda_powertools.shared import constants @@ -115,3 +116,8 @@ def powertools_debug_is_set() -> bool: return True return False + + +def slice_dictionary(data: Dict, chunk_size: int) -> Generator[Dict, None, None]: + for _ in range(0, len(data), chunk_size): + yield {dict_key: data[dict_key] for dict_key in itertools.islice(data, chunk_size)} diff --git a/aws_lambda_powertools/utilities/feature_flags/appconfig.py b/aws_lambda_powertools/utilities/feature_flags/appconfig.py index 8c8dbacc6c5..8695c1fd8c9 100644 --- a/aws_lambda_powertools/utilities/feature_flags/appconfig.py +++ b/aws_lambda_powertools/utilities/feature_flags/appconfig.py @@ -15,8 +15,6 @@ from .base import StoreProvider from .exceptions import ConfigurationStoreError, StoreClientError -TRANSFORM_TYPE = "json" - class AppConfigStore(StoreProvider): def __init__( @@ -74,7 +72,7 @@ def get_raw_configuration(self) -> Dict[str, Any]: dict, self._conf_store.get( name=self.name, - transform=TRANSFORM_TYPE, + transform="json", max_age=self.cache_seconds, ), ) diff --git a/aws_lambda_powertools/utilities/parameters/__init__.py b/aws_lambda_powertools/utilities/parameters/__init__.py index 7dce2ac4c9a..9fcaa4fa701 100644 --- a/aws_lambda_powertools/utilities/parameters/__init__.py +++ b/aws_lambda_powertools/utilities/parameters/__init__.py @@ -9,7 +9,7 @@ from .dynamodb import DynamoDBProvider from .exceptions import GetParameterError, TransformParameterError from .secrets import SecretsProvider, get_secret -from .ssm import SSMProvider, get_parameter, get_parameters +from .ssm import SSMProvider, get_parameter, get_parameters, get_parameters_by_name __all__ = [ "AppConfigProvider", @@ -22,6 +22,7 @@ "get_app_config", "get_parameter", "get_parameters", + "get_parameters_by_name", "get_secret", "clear_caches", ] diff --git a/aws_lambda_powertools/utilities/parameters/appconfig.py b/aws_lambda_powertools/utilities/parameters/appconfig.py index a3a340a62be..7884728024e 100644 --- a/aws_lambda_powertools/utilities/parameters/appconfig.py +++ b/aws_lambda_powertools/utilities/parameters/appconfig.py @@ -9,6 +9,8 @@ import boto3 from botocore.config import Config +from aws_lambda_powertools.utilities.parameters.types import TransformOptions + if TYPE_CHECKING: from mypy_boto3_appconfigdata import AppConfigDataClient @@ -132,7 +134,7 @@ def get_app_config( name: str, environment: str, application: Optional[str] = None, - transform: Optional[str] = None, + transform: TransformOptions = None, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, **sdk_options diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index b76b16e1dd8..8587d3b5f3f 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -1,17 +1,31 @@ """ Base for Parameter providers """ +from __future__ import annotations import base64 import json from abc import ABC, abstractmethod -from collections import namedtuple from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + NamedTuple, + Optional, + Tuple, + Type, + Union, + cast, + overload, +) import boto3 from botocore.config import Config +from aws_lambda_powertools.utilities.parameters.types import TransformOptions + from .exceptions import GetParameterError, TransformParameterError if TYPE_CHECKING: @@ -22,7 +36,6 @@ DEFAULT_MAX_AGE_SECS = 5 -ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"]) # These providers will be dynamically initialized on first use of the helper functions DEFAULT_PROVIDERS: Dict[str, Any] = {} TRANSFORM_METHOD_JSON = "json" @@ -30,29 +43,42 @@ SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY] ParameterClients = Union["AppConfigDataClient", "SecretsManagerClient", "SSMClient"] +TRANSFORM_METHOD_MAPPING = { + TRANSFORM_METHOD_JSON: json.loads, + TRANSFORM_METHOD_BINARY: base64.b64decode, + ".json": json.loads, + ".binary": base64.b64decode, + None: lambda x: x, +} + + +class ExpirableValue(NamedTuple): + value: str | bytes | Dict[str, Any] + ttl: datetime + class BaseProvider(ABC): """ Abstract Base Class for Parameter providers """ - store: Any = None + store: Dict[Tuple[str, TransformOptions], ExpirableValue] def __init__(self): """ Initialize the base provider """ - self.store = {} + self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {} - def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: + def has_not_expired_in_cache(self, key: Tuple[str, TransformOptions]) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, force_fetch: bool = False, **sdk_options, ) -> Optional[Union[str, dict, bytes]]: @@ -95,7 +121,7 @@ def get( value: Optional[Union[str, bytes, dict]] = None key = (name, transform) - if not force_fetch and self._has_not_expired(key): + if not force_fetch and self.has_not_expired_in_cache(key): return self.store[key].value try: @@ -105,11 +131,11 @@ def get( raise GetParameterError(str(exc)) if transform: - if isinstance(value, bytes): - value = value.decode("utf-8") - value = transform_value(value, transform) + value = transform_value(key=name, value=value, transform=transform, raise_on_transform_error=True) - self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + # NOTE: don't cache None, as they might've been failed transforms and may be corrected + if value is not None: + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) return value @@ -124,7 +150,7 @@ def get_multiple( self, path: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, raise_on_transform_error: bool = False, force_fetch: bool = False, **sdk_options, @@ -160,8 +186,8 @@ def get_multiple( """ key = (path, transform) - if not force_fetch and self._has_not_expired(key): - return self.store[key].value + if not force_fetch and self.has_not_expired_in_cache(key): + return self.store[key].value # type: ignore # need to revisit entire typing here try: values = self._get_multiple(path, **sdk_options) @@ -170,13 +196,8 @@ def get_multiple( raise GetParameterError(str(exc)) if transform: - transformed_values: dict = {} - for (item, value) in values.items(): - _transform = get_transform_method(item, transform) - if not _transform: - continue - transformed_values[item] = transform_value(value, _transform, raise_on_transform_error) - values.update(transformed_values) + values.update(transform_value(values, transform, raise_on_transform_error)) + self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age)) return values @@ -191,6 +212,12 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def clear_cache(self): self.store.clear() + def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): + if max_age <= 0: + return + + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + @staticmethod def _build_boto3_client( service_name: str, @@ -258,57 +285,81 @@ def _build_boto3_resource_client( return session.resource(service_name=service_name, config=config, endpoint_url=endpoint_url) -def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]: +def get_transform_method(value: str, transform: TransformOptions = None) -> Callable[..., Any]: """ Determine the transform method Examples ------- - >>> get_transform_method("key", "any_other_value") + >>> get_transform_method("key","any_other_value") 'any_other_value' - >>> get_transform_method("key.json", "auto") + >>> get_transform_method("key.json","auto") 'json' - >>> get_transform_method("key.binary", "auto") + >>> get_transform_method("key.binary","auto") 'binary' - >>> get_transform_method("key", "auto") + >>> get_transform_method("key","auto") None - >>> get_transform_method("key", None) + >>> get_transform_method("key",None) None Parameters --------- - key: str - Only used when the tranform is "auto". + value: str + Only used when the transform is "auto". transform: str, optional Original transform method, only "auto" will try to detect the transform method by the key Returns ------ - Optional[str]: - The transform method either when transform is "auto" then None, "json" or "binary" is returned - or the original transform method + Callable: + Transform function could be json.loads, base64.b64decode, or a lambda that echo the str value """ - if transform != "auto": - return transform + transform_method = TRANSFORM_METHOD_MAPPING.get(transform) + + if transform == "auto": + key_suffix = value.rsplit(".")[-1] + transform_method = TRANSFORM_METHOD_MAPPING.get(key_suffix, TRANSFORM_METHOD_MAPPING[None]) + + return cast(Callable, transform_method) # https://github.com/python/mypy/issues/10740 + + +@overload +def transform_value( + value: Dict[str, Any], + transform: TransformOptions, + raise_on_transform_error: bool = False, + key: str = "", +) -> Dict[str, Any]: + ... + - for transform_method in SUPPORTED_TRANSFORM_METHODS: - if key.endswith("." + transform_method): - return transform_method - return None +@overload +def transform_value( + value: Union[str, bytes, Dict[str, Any]], + transform: TransformOptions, + raise_on_transform_error: bool = False, + key: str = "", +) -> Optional[Union[str, bytes, Dict[str, Any]]]: + ... def transform_value( - value: str, transform: str, raise_on_transform_error: Optional[bool] = True -) -> Optional[Union[dict, bytes]]: + value: Union[str, bytes, Dict[str, Any]], + transform: TransformOptions, + raise_on_transform_error: bool = True, + key: str = "", +) -> Optional[Union[str, bytes, Dict[str, Any]]]: """ - Apply a transform to a value + Transform a value using one of the available options. Parameters --------- value: str Parameter value to transform transform: str - Type of transform, supported values are "json" and "binary" + Type of transform, supported values are "json", "binary", and "auto" based on suffix (.json, .binary) + key: str + Parameter key when transform is auto to infer its transform method raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed @@ -318,18 +369,41 @@ def transform_value( TransformParameterError: When the parameter value could not be transformed """ + # Maintenance: For v3, we should consider returning the original value for soft transform failures. + + err_msg = "Unable to transform value using '{transform}' transform: {exc}" + + if isinstance(value, bytes): + value = value.decode("utf-8") + + if isinstance(value, dict): + # NOTE: We must handle partial failures when receiving multiple values + # where one of the keys might fail during transform, e.g. `{"a": "valid", "b": "{"}` + # expected: `{"a": "valid", "b": None}` + + transformed_values: Dict[str, Any] = {} + for dict_key, dict_value in value.items(): + transform_method = get_transform_method(value=dict_key, transform=transform) + try: + transformed_values[dict_key] = transform_method(dict_value) + except Exception as exc: + if raise_on_transform_error: + raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc + transformed_values[dict_key] = None + return transformed_values + + if transform == "auto": + # key="a.json", value='{"a": "b"}', or key="a.binary", value="b64_encoded" + transform_method = get_transform_method(value=key, transform=transform) + else: + # value='{"key": "value"} + transform_method = get_transform_method(value=value, transform=transform) try: - if transform == TRANSFORM_METHOD_JSON: - return json.loads(value) - elif transform == TRANSFORM_METHOD_BINARY: - return base64.b64decode(value) - else: - raise ValueError(f"Invalid transform type '{transform}'") - + return transform_method(value) except Exception as exc: if raise_on_transform_error: - raise TransformParameterError(str(exc)) + raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc return None diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 3b3e782fd45..ae4a76dac4a 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,17 +1,23 @@ """ AWS SSM Parameter retrieval and caching utility """ +from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload import boto3 from botocore.config import Config +from typing_extensions import Literal + +from aws_lambda_powertools.shared.functions import slice_dictionary -from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider +from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider, transform_value +from .exceptions import GetParameterError +from .types import TransformOptions if TYPE_CHECKING: from mypy_boto3_ssm import SSMClient + from mypy_boto3_ssm.type_defs import GetParametersResultTypeDef class SSMProvider(BaseProvider): @@ -80,6 +86,8 @@ class SSMProvider(BaseProvider): """ client: Any = None + _MAX_GET_PARAMETERS_ITEM = 10 + _ERRORS_KEY = "_errors" def __init__( self, @@ -103,10 +111,10 @@ def get( # type: ignore[override] self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, decrypt: bool = False, force_fetch: bool = False, - **sdk_options + **sdk_options, ) -> Optional[Union[str, dict, bytes]]: """ Retrieve a parameter value or return the cached value @@ -187,7 +195,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals for page in self.client.get_paginator("get_parameters_by_path").paginate(**sdk_options): for parameter in page.get("Parameters", []): # Standardize the parameter name - # The parameter name returned by SSM will contained the full path. + # The parameter name returned by SSM will contain the full path. # However, for readability, we should return only the part after # the path. name = parameter["Name"] @@ -199,6 +207,282 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals return parameters + # NOTE: When bandwidth permits, allocate a week to refactor to lower cognitive load + def get_parameters_by_name( + self, + parameters: Dict[str, Dict], + transform: TransformOptions = None, + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, + ) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + """ + Retrieve multiple parameter values by name from SSM or cache. + + Raise_on_error decides on error handling strategy: + + - A) Default to fail-fast. Raises GetParameterError upon any error + - B) Gracefully aggregate all parameters that failed under "_errors" key + + It transparently uses GetParameter and/or GetParameters depending on decryption requirements. + + ┌────────────────────────┐ + ┌───▶ Decrypt entire batch │─────┐ + │ └────────────────────────┘ │ ┌────────────────────┐ + │ ├─────▶ GetParameters API │ + ┌──────────────────┐ │ ┌────────────────────────┐ │ └────────────────────┘ + │ Split batch │─── ┼──▶│ No decryption required │─────┘ + └──────────────────┘ │ └────────────────────────┘ + │ ┌────────────────────┐ + │ ┌────────────────────────┐ │ GetParameter API │ + └──▶│Decrypt some but not all│───────────▶────────────────────┤ + └────────────────────────┘ │ GetParameters API │ + └────────────────────┘ + + Parameters + ---------- + parameters: List[Dict[str, Dict]] + List of parameter names, and any optional overrides + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted + max_age: int + Maximum age of the cached value + raise_on_error: bool + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for a given name. + + When "_errors" reserved key is in parameters to be fetched from SSM. + """ + # Init potential batch/decrypt batch responses and errors + batch_ret: Dict[str, Any] = {} + decrypt_ret: Dict[str, Any] = {} + batch_err: List[str] = [] + decrypt_err: List[str] = [] + response: Dict[str, Any] = {} + + # NOTE: We fail early to avoid unintended graceful errors being replaced with their '_errors' param values + self._raise_if_errors_key_is_present(parameters, self._ERRORS_KEY, raise_on_error) + + batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) + + # NOTE: We need to find out whether all parameters must be decrypted or not to know which API to use + ## Logic: + ## + ## GetParameters API -> When decrypt is used for all parameters in the the batch + ## GetParameter API -> When decrypt is used for one or more in the batch + + if len(decrypt_params) != len(parameters): + decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) + batch_ret, batch_err = self._get_parameters_batch_by_name(batch_params, raise_on_error, decrypt=False) + else: + batch_ret, batch_err = self._get_parameters_batch_by_name(decrypt_params, raise_on_error, decrypt=True) + + # Fail-fast disabled, let's aggregate errors under "_errors" key so they can handle gracefully + if not raise_on_error: + response[self._ERRORS_KEY] = [*decrypt_err, *batch_err] + + return {**response, **batch_ret, **decrypt_ret} + + def _get_parameters_by_name_with_decrypt_option( + self, batch: Dict[str, Dict], raise_on_error: bool + ) -> Tuple[Dict, List]: + response: Dict[str, Any] = {} + errors: List[str] = [] + + # Decided for single-thread as it outperforms in 128M and 1G + reduce timeout risk + # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 + for parameter, options in batch.items(): + try: + response[parameter] = self.get(parameter, options["max_age"], options["transform"], options["decrypt"]) + except GetParameterError: + if raise_on_error: + raise + errors.append(parameter) + continue + + return response, errors + + def _get_parameters_batch_by_name( + self, batch: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False + ) -> Tuple[Dict, List]: + """Slice batch and fetch parameters using GetParameters by max permitted""" + errors: List[str] = [] + + # Fetch each possible batch param from cache and return if entire batch is cached + cached_params = self._get_parameters_by_name_from_cache(batch) + if len(cached_params) == len(batch): + return cached_params, errors + + # Slice batch by max permitted GetParameters call + batch_ret, errors = self._get_parameters_by_name_in_chunks(batch, cached_params, raise_on_error, decrypt) + + return {**cached_params, **batch_ret}, errors + + def _get_parameters_by_name_from_cache(self, batch: Dict[str, Dict]) -> Dict[str, Any]: + """Fetch each parameter from batch that hasn't been expired""" + cache = {} + for name, options in batch.items(): + cache_key = (name, options["transform"]) + if self.has_not_expired_in_cache(cache_key): + cache[name] = self.store[cache_key].value + + return cache + + def _get_parameters_by_name_in_chunks( + self, batch: Dict[str, Dict], cache: Dict[str, Any], raise_on_error: bool, decrypt: bool = False + ) -> Tuple[Dict, List]: + """Take out differences from cache and batch, slice it and fetch from SSM""" + response: Dict[str, Any] = {} + errors: List[str] = [] + + diff = {key: value for key, value in batch.items() if key not in cache} + + for chunk in slice_dictionary(data=diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): + response, possible_errors = self._get_parameters_by_name( + parameters=chunk, raise_on_error=raise_on_error, decrypt=decrypt + ) + response.update(response) + errors.extend(possible_errors) + + return response, errors + + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False + ) -> Tuple[Dict[str, Any], List[str]]: + """Use SSM GetParameters to fetch parameters, hydrate cache, and handle partial failure + + Parameters + ---------- + parameters : Dict[str, Dict] + Parameters to fetch + raise_on_error : bool, optional + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True + + Returns + ------- + Dict[str, Any] + Retrieved parameters as key names and their values + + Raises + ------ + GetParameterError + When one or more parameters failed on fetching, and raise_on_error is enabled + """ + ret: Dict[str, Any] = {} + batch_errors: List[str] = [] + parameter_names = list(parameters.keys()) + + # All params in the batch must be decrypted + # we return early if we hit an unrecoverable exception like InvalidKeyId/InternalServerError + # everything else should technically be recoverable as GetParameters is non-atomic + try: + if decrypt: + response = self.client.get_parameters(Names=parameter_names, WithDecryption=True) + else: + response = self.client.get_parameters(Names=parameter_names) + except (self.client.exceptions.InvalidKeyId, self.client.exceptions.InternalServerError): + return ret, parameter_names + + batch_errors = self._handle_any_invalid_get_parameter_errors(response, raise_on_error) + transformed_params = self._transform_and_cache_get_parameters_response(response, parameters, raise_on_error) + + return transformed_params, batch_errors + + def _transform_and_cache_get_parameters_response( + self, api_response: GetParametersResultTypeDef, parameters: Dict[str, Any], raise_on_error: bool = True + ) -> Dict[str, Any]: + response: Dict[str, Any] = {} + + for parameter in api_response["Parameters"]: + name = parameter["Name"] + value = parameter["Value"] + options = parameters[name] + transform = options.get("transform") + + # NOTE: If transform is set, we do it before caching to reduce number of operations + if transform: + value = transform_value(name, value, transform, raise_on_error) # type: ignore + + _cache_key = (name, options["transform"]) + self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) + + response[name] = value + + return response + + @staticmethod + def _handle_any_invalid_get_parameter_errors( + api_response: GetParametersResultTypeDef, raise_on_error: bool = True + ) -> List[str]: + """GetParameters is non-atomic. Failures don't always reflect in exceptions so we need to collect.""" + failed_parameters = api_response["InvalidParameters"] + if failed_parameters: + if raise_on_error: + raise GetParameterError(f"Failed to fetch parameters: {failed_parameters}") + + return failed_parameters + + return [] + + @staticmethod + def _split_batch_and_decrypt_parameters( + parameters: Dict[str, Dict], transform: TransformOptions, max_age: int, decrypt: bool + ) -> Tuple[Dict[str, Dict], Dict[str, Dict]]: + """Split parameters that can be fetched by GetParameters vs GetParameter + + Parameters + ---------- + parameters : Dict[str, Dict] + Parameters containing names as key and optional config override as value + transform : TransformOptions + Transform configuration + max_age : int + How long to cache a parameter for + decrypt : bool + Whether to use KMS to decrypt a parameter + + Returns + ------- + Tuple[Dict[str, Dict], Dict[str, Dict]] + GetParameters and GetParameter parameters dict along with their overrides/globals merged + """ + batch_parameters: Dict[str, Dict] = {} + decrypt_parameters: Dict[str, Any] = {} + + for parameter, options in parameters.items(): + # NOTE: TypeDict later + _overrides = options or {} + _overrides["transform"] = _overrides.get("transform") or transform + + # These values can be falsy (False, 0) + if "decrypt" not in _overrides: + _overrides["decrypt"] = decrypt + + if "max_age" not in _overrides: + _overrides["max_age"] = max_age + + # NOTE: Split parameters who have decrypt OR have it global + if _overrides["decrypt"]: + decrypt_parameters[parameter] = _overrides + else: + batch_parameters[parameter] = _overrides + + return batch_parameters, decrypt_parameters + + @staticmethod + def _raise_if_errors_key_is_present(parameters: Dict, reserved_parameter: str, raise_on_error: bool): + """Raise GetParameterError if fail-fast is disabled and '_errors' key is in parameters batch""" + if not raise_on_error and reserved_parameter in parameters: + raise GetParameterError( + f"You cannot fetch a parameter named '{reserved_parameter}' in graceful error mode." + ) + def get_parameter( name: str, @@ -206,8 +490,8 @@ def get_parameter( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - **sdk_options -) -> Union[str, list, dict, bytes]: + **sdk_options, +) -> Union[str, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -275,7 +559,7 @@ def get_parameters( force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, raise_on_transform_error: bool = False, - **sdk_options + **sdk_options, ) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: """ Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store @@ -342,5 +626,116 @@ def get_parameters( transform=transform, raise_on_transform_error=raise_on_transform_error, force_fetch=force_fetch, - **sdk_options + **sdk_options, + ) + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: None = None, + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, +) -> Dict[str, str]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["binary"], + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, +) -> Dict[str, bytes]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["json"], + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, +) -> Dict[str, Dict[str, Any]]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["auto"], + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, +) -> Union[Dict[str, str], Dict[str, dict]]: + ... + + +def get_parameters_by_name( + parameters: Dict[str, Any], + transform: TransformOptions = None, + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, +) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: + """ + Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store + + Parameters + ---------- + parameters: List[Dict[str, Dict]] + List of parameter names, and any optional overrides + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted + max_age: int + Maximum age of the cached value + raise_on_error: bool, optional + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True + + Example + ------- + + **Retrieves multiple parameters from distinct paths from Systems Manager Parameter Store** + + from aws_lambda_powertools.utilities.parameters import get_parameters_by_name + + params = { + "/param": {}, + "/json": {"transform": "json"}, + "/binary": {"transform": "binary"}, + "/no_cache": {"max_age": 0}, + "/api_key": {"decrypt": True}, + } + + values = get_parameters_by_name(parameters=params) + for param_name, value in values.items(): + print(f"{param_name}: {value}") + + # "/param": value + # "/json": value + # "/binary": value + # "/no_cache": value + # "/api_key": value + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + """ + + # NOTE: Decided against using multi-thread due to single-thread outperforming in 128M and 1G + timeout risk + # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 + + # Only create the provider if this function is called at least once + if "ssm" not in DEFAULT_PROVIDERS: + DEFAULT_PROVIDERS["ssm"] = SSMProvider() + + return DEFAULT_PROVIDERS["ssm"].get_parameters_by_name( + parameters=parameters, max_age=max_age, transform=transform, decrypt=decrypt, raise_on_error=raise_on_error ) diff --git a/aws_lambda_powertools/utilities/parameters/types.py b/aws_lambda_powertools/utilities/parameters/types.py new file mode 100644 index 00000000000..6a15873c496 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/types.py @@ -0,0 +1,3 @@ +from typing_extensions import Literal + +TransformOptions = Literal["json", "binary", "auto", None] diff --git a/docs/utilities/parameters.md b/docs/utilities/parameters.md index 6b7d64b66b9..9441d94fe12 100644 --- a/docs/utilities/parameters.md +++ b/docs/utilities/parameters.md @@ -24,34 +24,99 @@ This utility requires additional permissions to work as expected. ???+ note Different parameter providers require different permissions. -| Provider | Function/Method | IAM Permission | -| ------------------- | -----------------------------------------------------------------| -----------------------------------------------------------------------------| -| SSM Parameter Store | `get_parameter`, `SSMProvider.get` | `ssm:GetParameter` | -| SSM Parameter Store | `get_parameters`, `SSMProvider.get_multiple` | `ssm:GetParametersByPath` | -| SSM Parameter Store | If using `decrypt=True` | You must add an additional permission `kms:Decrypt` | -| Secrets Manager | `get_secret`, `SecretsManager.get` | `secretsmanager:GetSecretValue` | -| DynamoDB | `DynamoDBProvider.get` | `dynamodb:GetItem` | -| DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb:Query` | -| App Config | `get_app_config`, `AppConfigProvider.get_app_config` | `appconfig:GetLatestConfiguration` and `appconfig:StartConfigurationSession` | +| Provider | Function/Method | IAM Permission | +| --------- | ---------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | +| SSM | **`get_parameter`**, **`SSMProvider.get`** | **`ssm:GetParameter`** | +| SSM | **`get_parameters`**, **`SSMProvider.get_multiple`** | **`ssm:GetParametersByPath`** | +| SSM | **`get_parameters_by_name`**, **`SSMProvider.get_parameters_by_name`** | **`ssm:GetParameter`** and **`ssm:GetParameters`** | +| SSM | If using **`decrypt=True`** | You must add an additional permission **`kms:Decrypt`** | +| Secrets | **`get_secret`**, **`SecretsManager.get`** | **`secretsmanager:GetSecretValue`** | +| DynamoDB | **`DynamoDBProvider.get`** | **`dynamodb:GetItem`** | +| DynamoDB | **`DynamoDBProvider.get_multiple`** | **`dynamodb:Query`** | +| AppConfig | **`get_app_config`**, **`AppConfigProvider.get_app_config`** | **`appconfig:GetLatestConfiguration`** and **`appconfig:StartConfigurationSession`** | ### Fetching parameters You can retrieve a single parameter using `get_parameter` high-level function. -For multiple parameters, you can use `get_parameters` and pass a path to retrieve them recursively. - -```python hl_lines="1 5 9" title="Fetching multiple parameters recursively" +```python hl_lines="5" title="Fetching a single parameter" from aws_lambda_powertools.utilities import parameters def handler(event, context): # Retrieve a single parameter value = parameters.get_parameter("/my/parameter") - # Retrieve multiple parameters from a path prefix recursively - # This returns a dict with the parameter name as key - values = parameters.get_parameters("/my/path/prefix") - for k, v in values.items(): - print(f"{k}: {v}") +``` + +For multiple parameters, you can use either: + +* `get_parameters` to recursively fetch all parameters by path. +* `get_parameters_by_name` to fetch distinct parameters by their full name. It also accepts custom caching, transform, decrypt per parameter. + +=== "get_parameters" + + ```python hl_lines="1 6" + from aws_lambda_powertools.utilities import parameters + + def handler(event, context): + # Retrieve multiple parameters from a path prefix recursively + # This returns a dict with the parameter name as key + values = parameters.get_parameters("/my/path/prefix") + for parameter, value in values.items(): + print(f"{parameter}: {value}") + ``` + +=== "get_parameters_by_name" + + ```python hl_lines="3 5 14" + from typing import Any + + from aws_lambda_powertools.utilities import get_parameters_by_name + + parameters = { + "/develop/service/commons/telemetry/config": {"max_age": 300, "transform": "json"}, + "/no_cache_param": {"max_age": 0}, + # inherit default values + "/develop/service/payment/api/capture/url": {}, + } + + def handler(event, context): + # This returns a dict with the parameter name as key + response: dict[str, Any] = parameters.get_parameters_by_name(parameters=parameters, max_age=60) + for parameter, value in response.items(): + print(f"{parameter}: {value}") + ``` + +???+ tip "`get_parameters_by_name` supports graceful error handling" + By default, we will raise `GetParameterError` when any parameter fails to be fetched. You can override it by setting `raise_on_error=False`. + + When disabled, we take the following actions: + + * Add failed parameter name in the `_errors` key, _e.g._, `{_errors: ["/param1", "/param2"]}` + * Keep only successful parameter names and their values in the response + * Raise `GetParameterError` if any of your parameters is named `_errors` + +```python hl_lines="3 5 12-13 15" title="Graceful error handling" +from typing import Any + +from aws_lambda_powertools.utilities import get_parameters_by_name + +parameters = { + "/develop/service/commons/telemetry/config": {"max_age": 300, "transform": "json"}, + # it would fail by default + "/this/param/does/not/exist" +} + +def handler(event, context): + values: dict[str, Any] = parameters.get_parameters_by_name(parameters=parameters, raise_on_error=False) + errors: list[str] = values.get("_errors", []) + + # Handle gracefully, since '/this/param/does/not/exist' will only be available in `_errors` + if errors: + ... + + for parameter, value in values.items(): + print(f"{parameter}: {value}") ``` ### Fetching secrets diff --git a/pyproject.toml b/pyproject.toml index f3616c322e9..9490e83e5d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ aws-xray-sdk = { version = "^2.8.0", optional = true } fastjsonschema = { version = "^2.14.5", optional = true } pydantic = { version = "^1.8.2", optional = true } boto3 = { version = "^1.20.32", optional = true } +typing-extensions = "^4.4.0" [tool.poetry.dev-dependencies] coverage = {extras = ["toml"], version = "^6.2"} diff --git a/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py new file mode 100644 index 00000000000..948fad2aa12 --- /dev/null +++ b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py @@ -0,0 +1,15 @@ +import json +import os +from typing import Any, Dict, List, cast + +from aws_lambda_powertools.utilities.parameters.ssm import get_parameters_by_name +from aws_lambda_powertools.utilities.typing import LambdaContext + +parameters_list: List[str] = cast(List, json.loads(os.getenv("parameters", ""))) + + +def lambda_handler(event: dict, context: LambdaContext) -> Dict[str, Any]: + parameters_to_fetch: Dict[str, Any] = {param: {} for param in parameters_list} + + # response`{parameter:value}` + return get_parameters_by_name(parameters=parameters_to_fetch, max_age=0) diff --git a/tests/e2e/parameters/infrastructure.py b/tests/e2e/parameters/infrastructure.py index d0fb1b6c60c..e2cd5101ba7 100644 --- a/tests/e2e/parameters/infrastructure.py +++ b/tests/e2e/parameters/infrastructure.py @@ -1,18 +1,38 @@ -from pyclbr import Function +import json +from typing import List -from aws_cdk import CfnOutput +from aws_cdk import CfnOutput, Duration from aws_cdk import aws_appconfig as appconfig from aws_cdk import aws_iam as iam +from aws_cdk import aws_ssm as ssm +from aws_cdk.aws_lambda import Function -from tests.e2e.utils.data_builder import build_service_name +from tests.e2e.utils.data_builder import build_random_value, build_service_name from tests.e2e.utils.infrastructure import BaseInfrastructure class ParametersStack(BaseInfrastructure): def create_resources(self): - functions = self.create_lambda_functions() + parameters = self._create_ssm_parameters() + + env_vars = {"parameters": json.dumps(parameters)} + functions = self.create_lambda_functions( + function_props={"environment": env_vars, "timeout": Duration.seconds(30)} + ) + self._create_app_config(function=functions["ParameterAppconfigFreeformHandler"]) + # NOTE: Enforce least-privilege for our param tests only + functions["ParameterSsmGetParametersByName"].add_to_role_policy( + iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ssm:GetParameter", + ], + resources=[f"arn:aws:ssm:{self.region}:{self.account_id}:parameter/powertools/e2e/parameters/*"], + ) + ) + def _create_app_config(self, function: Function): service_name = build_service_name() @@ -106,3 +126,16 @@ def _create_app_config_freeform( resources=["*"], ) ) + + def _create_ssm_parameters(self) -> List[str]: + parameters: List[str] = [] + + for _ in range(10): + param = f"/powertools/e2e/parameters/{build_random_value()}" + rand = build_random_value() + ssm.StringParameter(self.stack, f"param-{rand}", parameter_name=param, string_value=rand) + parameters.append(param) + + CfnOutput(self.stack, "ParametersNameList", value=json.dumps(parameters)) + + return parameters diff --git a/tests/e2e/parameters/test_ssm.py b/tests/e2e/parameters/test_ssm.py new file mode 100644 index 00000000000..7e9614f8ea0 --- /dev/null +++ b/tests/e2e/parameters/test_ssm.py @@ -0,0 +1,34 @@ +import json +from typing import Any, Dict, List + +import pytest + +from tests.e2e.utils import data_fetcher + + +@pytest.fixture +def ssm_get_parameters_by_name_fn_arn(infrastructure: dict) -> str: + return infrastructure.get("ParameterSsmGetParametersByNameArn", "") + + +@pytest.fixture +def parameters_list(infrastructure: dict) -> List[str]: + param_list = infrastructure.get("ParametersNameList", "[]") + return json.loads(param_list) + + +# +def test_get_parameters_by_name( + ssm_get_parameters_by_name_fn_arn: str, + parameters_list: str, +): + # GIVEN/WHEN + function_response, _ = data_fetcher.get_lambda_response(lambda_arn=ssm_get_parameters_by_name_fn_arn) + parameter_values: Dict[str, Any] = json.loads(function_response["Payload"].read().decode("utf-8")) + + # THEN + for param in parameters_list: + try: + assert parameter_values[param] is not None + except (KeyError, TypeError): + pytest.fail(f"Parameter {param} not found in response") diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 123c2fdbcc2..c5e65c158be 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import base64 import json import random import string from datetime import datetime, timedelta from io import BytesIO -from typing import Dict +from typing import Any, Dict, List, Tuple import boto3 import pytest @@ -14,7 +16,12 @@ from botocore.response import StreamingBody from aws_lambda_powertools.utilities import parameters -from aws_lambda_powertools.utilities.parameters.base import BaseProvider, ExpirableValue +from aws_lambda_powertools.utilities.parameters.base import ( + TRANSFORM_METHOD_MAPPING, + BaseProvider, + ExpirableValue, +) +from aws_lambda_powertools.utilities.parameters.ssm import SSMProvider @pytest.fixture(scope="function") @@ -39,6 +46,29 @@ def config(): return Config(region_name="us-east-1") +def build_get_parameters_stub(params: Dict[str, Any], invalid_parameters: List[str] | None = None) -> Dict[str, List]: + invalid_parameters = invalid_parameters or [] + version = random.randrange(1, 1000) + return { + "Parameters": [ + { + "Name": param, + "Type": "String", + "Value": value, + "Version": version, + "Selector": f"{param}:{version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{param.lstrip('/')}", + "DataType": "string", + } + for param, value in params.items() + if param not in invalid_parameters + ], + "InvalidParameters": invalid_parameters, # official SDK stub fails validation here, need to raise an issue + } + + def test_dynamodb_provider_get(mock_name, mock_value, config): """ Test DynamoDBProvider.get() with a non-cached value @@ -610,6 +640,169 @@ def test_ssm_provider_clear_cache(mock_name, mock_value, config): assert provider.store == {} +def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_value, config): + # GIVEN two parameters are requested + provider = parameters.SSMProvider(config=config) + success = f"/dev/{mock_name}" + fail = f"/prod/{mock_name}" + + params = {success: {}, fail: {}} + param_names = list(params.keys()) + stub_params = {success: mock_value} + + expected_stub_response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail]) + expected_stub_params = {"Names": param_names} + + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN one of them fails to be retrieved + # THEN raise GetParameterError + with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{fail}.*"): + try: + provider.get_parameters_by_name(parameters=params) + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure(mock_name, mock_value, config): + # GIVEN two parameters are requested + success = f"/dev/{mock_name}" + fail = f"/prod/{mock_name}" + params = {success: {}, fail: {}} + param_names = list(params.keys()) + stub_params = {success: mock_value} + + expected_stub_response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail]) + expected_stub_params = {"Names": param_names} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN one of them fails to be retrieved + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + # THEN there should be no error raised + # and failed ones available within "_errors" key + stubber.assert_no_pending_responses() + assert ret["_errors"] + assert len(ret["_errors"]) == 1 + assert fail not in ret + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_with_decrypt(mock_name, config): + # GIVEN one parameter requires decryption and an arbitrary SDK error occurs + param = f"/{mock_name}" + params = {param: {"decrypt": True}} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_client_error("get_parameters", "InvalidKeyId") + stubber.activate() + + # WHEN fail-fast is disabled in get_parameters_by_name + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + stubber.assert_no_pending_responses() + + # THEN there should be no error raised but added under `_errors` key + assert ret["_errors"] + assert len(ret["_errors"]) == 1 + assert param not in ret + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_batch_decrypt_combined( + mock_value, mock_version, config +): + # GIVEN three parameters are requested + # one requires decryption, two can be batched + # and an arbitrary SDK error is injected + fail = "/fail" + success = "/success" + decrypt_fail = "/fail/decrypt" + params = {decrypt_fail: {"decrypt": True}, success: {}, fail: {}} + + expected_stub_params = {"Names": [success, fail]} + expected_stub_response = build_get_parameters_stub( + params={fail: mock_value, success: mock_value}, invalid_parameters=[fail] + ) + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_client_error("get_parameter") + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN fail-fast is disabled in get_parameters_by_name + # and only one parameter succeeds out of three + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + # THEN there should be no error raised + # successful params returned accordingly + # and failed ones available within "_errors" key + stubber.assert_no_pending_responses() + assert success in ret + assert ret["_errors"] + assert len(ret["_errors"]) == 2 + assert fail not in ret + assert decrypt_fail not in ret + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_raise_on_reserved_errors_key(mock_name, mock_value, config): + # GIVEN one of the parameters is named `_errors` + success = f"/dev/{mock_name}" + fail = "_errors" + + params = {success: {}, fail: {}} + provider = parameters.SSMProvider(config=config) + + # WHEN using get_parameters_by_name to fetch + # THEN raise GetParameterError + with pytest.raises(parameters.exceptions.GetParameterError, match="You cannot fetch a parameter named"): + provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + +def test_ssm_provider_get_parameters_by_name_all_decrypt_should_use_get_parameters_api(mock_name, mock_value, config): + # GIVEN all parameters require decryption + param_a = f"/a/{mock_name}" + param_b = f"/b/{mock_name}" + fail = "/does_not_exist" # stub model doesn't support all-success yet + + all_params = {param_a: {}, param_b: {}, fail: {}} + all_params_names = list(all_params.keys()) + + expected_param_values = {param_a: mock_value, param_b: mock_value} + expected_stub_response = build_get_parameters_stub(params=expected_param_values, invalid_parameters=[fail]) + expected_stub_params = {"Names": all_params_names, "WithDecryption": True} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN get_parameters_by_name is called + # THEN we should only use GetParameters WithDecryption=true to prevent throttling + try: + ret = provider.get_parameters_by_name(parameters=all_params, decrypt=True, raise_on_error=False) + stubber.assert_no_pending_responses() + + assert ret is not None + finally: + stubber.deactivate() + + def test_dynamodb_provider_clear_cache(mock_name, mock_value, config): # GIVEN a provider is initialized with a cached value provider = parameters.DynamoDBProvider(table_name="test", config=config) @@ -1518,6 +1711,167 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_get_parameters_by_name(monkeypatch, mock_name, mock_value, config): + params = {mock_name: {}} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + return {mock_name: mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + values = parameters.get_parameters_by_name(parameters=params) + + assert len(values) == 1 + assert values[mock_name] == mock_value + + +def test_get_parameters_by_name_with_decrypt_override(monkeypatch, mock_name, mock_value, config): + # GIVEN 2 out of 3 parameters have decrypt override + decrypt_param = "/api_key" + decrypt_param_two = "/another/secret" + decrypt_params = {decrypt_param: {"decrypt": True}, decrypt_param_two: {"decrypt": True}} + decrypted_response = "decrypted" + params = {mock_name: {}, **decrypt_params} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: + # THEN params with `decrypt` override should use GetParameter` (`_get`) + assert name in decrypt_params + assert decrypt + return decrypted_response + + def _get_parameters_by_name(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]: + return {mock_name: mock_value}, [] + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + values = parameters.get_parameters_by_name(parameters=params) + + # THEN all parameters should be merged in the response + assert len(values) == 3 + assert values[mock_name] == mock_value + assert values[decrypt_param] == decrypted_response + assert values[decrypt_param_two] == decrypted_response + + +def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value, config): + # GIVEN a parameter overrides a default setting + default_cache_period = 500 + params = {mock_name: {"max_age": 0}, "no-override": {}} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + # NOTE: By convention, we check at `_get_parameters_by_name` + # as that's right before we call SSM, and when options have been merged + # def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False + ) -> Tuple[Dict[str, Any], List[str]]: + # THEN max_age should use no_cache_param override + assert parameters[mock_name]["max_age"] == 0 + assert parameters["no-override"]["max_age"] == default_cache_period + + return {mock_name: mock_value}, [] + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called with max_age set to 500 as the default + parameters.get_parameters_by_name(parameters=params, max_age=default_cache_period) + + +def test_get_parameters_by_name_with_max_batch(monkeypatch, config): + # GIVEN a batch of 20 parameters + params = {f"param_{i}": {} for i in range(20)} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False + ) -> Tuple[Dict[str, Any], List[str]]: + # THEN we should always split to respect GetParameters max + assert len(parameters) == self._MAX_GET_PARAMETERS_ITEM + return {}, [] + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + parameters.get_parameters_by_name(parameters=params) + + +def test_get_parameters_by_name_cache(monkeypatch, mock_name, mock_value, config): + # GIVEN we have a parameter to fetch but is already in cache + params = {mock_name: {}} + cache_key = (mock_name, None) + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def _get_parameters_by_name(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]: + raise RuntimeError("Should not be called if it's in cache") + + provider = TestProvider() + provider.add_to_cache(key=(mock_name, None), value=mock_value, max_age=10) + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", provider) + + # WHEN get_parameters_by_name is called + provider.get_parameters_by_name(parameters=params) + + # THEN the cache should be used and _get_parameters_by_name should not be called + assert provider.has_not_expired_in_cache(key=cache_key) + + +def test_get_parameters_by_name_empty_batch(monkeypatch, config): + # GIVEN we have an empty dictionary + params = {} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + # THEN it should return an empty response + assert parameters.get_parameters_by_name(parameters=params) == {} + + +def test_get_parameters_by_name_cache_them_individually_not_batch(monkeypatch, mock_name, mock_version): + # GIVEN we have a parameter to fetch but is already in cache + dev_param = f"/dev/{mock_name}" + prod_param = f"/prod/{mock_name}" + params = {dev_param: {}, prod_param: {}} + + stub_params = {dev_param: mock_value, prod_param: mock_value} + stub_response = build_get_parameters_stub(params=stub_params) + + class FakeClient: + def get_parameters(self, *args, **kwargs): + return stub_response + + provider = SSMProvider(boto3_client=FakeClient()) + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", provider) + + # WHEN get_parameters_by_name is called + provider.get_parameters_by_name(parameters=params) + + # THEN the cache should be populated with each parameter + assert len(provider.store) == len(params) + + def test_get_parameter_new(monkeypatch, mock_name, mock_value): """ Test get_parameter() without a default provider @@ -1584,6 +1938,27 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_get_parameters_by_name_new(monkeypatch, mock_name, mock_value, config): + """ + Test get_parameters_by_name() without a default provider + """ + params = {mock_name: {}} + + class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + return {mock_name: mock_value} + + monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) + monkeypatch.setattr(parameters.ssm, "SSMProvider", TestProvider) + + value = parameters.get_parameters_by_name(params) + + assert value[mock_name] == mock_value + + def test_get_secret(monkeypatch, mock_name, mock_value): """ Test get_secret() @@ -1810,6 +2185,50 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_transform_value_auto(mock_value: str): + # GIVEN + json_data = json.dumps({"A": mock_value}) + mock_binary = mock_value.encode() + binary_data = base64.b64encode(mock_binary).decode() + + # WHEN + json_value = parameters.base.transform_value(key="/a.json", value=json_data, transform="auto") + binary_value = parameters.base.transform_value(key="/a.binary", value=binary_data, transform="auto") + + # THEN + assert isinstance(json_value, dict) + assert isinstance(binary_value, bytes) + assert json_value["A"] == mock_value + assert binary_value == mock_binary + + +def test_transform_value_auto_incorrect_key(mock_value: str): + # GIVEN + mock_key = "/missing/json/suffix" + json_data = json.dumps({"A": mock_value}) + + # WHEN + value = parameters.base.transform_value(key=mock_key, value=json_data, transform="auto") + + # THEN it should echo back its value + assert isinstance(value, str) + assert value == json_data + + +def test_transform_value_auto_unsupported_transform(mock_value: str): + # GIVEN + mock_key = "/a.does_not_exist" + mock_dict = {"hello": "world"} + + # WHEN + value = parameters.base.transform_value(key=mock_key, value=mock_value, transform="auto") + dict_value = parameters.base.transform_value(key=mock_key, value=mock_dict, transform="auto") + + # THEN it should echo back its value + assert value == mock_value + assert dict_value == mock_dict + + def test_transform_value_json(mock_value): """ Test transform_value() with a json transform @@ -1863,17 +2282,6 @@ def test_transform_value_binary_exception(): assert "Incorrect padding" in str(excinfo) -def test_transform_value_wrong(mock_value): - """ - Test transform_value() with an incorrect transform - """ - - with pytest.raises(parameters.TransformParameterError) as excinfo: - parameters.base.transform_value(mock_value, "INCORRECT") - - assert "Invalid transform type" in str(excinfo) - - def test_transform_value_ignore_error(mock_value): """ Test transform_value() does not raise errors when raise_on_transform_error is False @@ -1884,16 +2292,6 @@ def test_transform_value_ignore_error(mock_value): assert value is None -@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None]) -def test_get_transform_method_preserve_original(original_transform): - """ - Check if original transform method is returned for anything other than "auto" - """ - transform = parameters.base.get_transform_method("key", original_transform) - - assert transform == original_transform - - @pytest.mark.parametrize("extension", ["json", "binary"]) def test_get_transform_method_preserve_auto(extension, mock_name): """ @@ -1901,18 +2299,7 @@ def test_get_transform_method_preserve_auto(extension, mock_name): """ transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto") - assert transform == extension - - -@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"]) -def test_get_transform_method_preserve_auto_unhandled(key): - """ - Check if any key that does not end with a supported extension returns None when - using the transform="auto" - """ - transform = parameters.base.get_transform_method(key, "auto") - - assert transform is None + assert transform == TRANSFORM_METHOD_MAPPING[extension] def test_base_provider_get_multiple_force_update(mock_name, mock_value): @@ -1958,3 +2345,18 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert isinstance(value, str) assert value == mock_value + + +def test_cache_ignores_max_age_zero_or_negative(mock_value, config): + # GIVEN we have two parameters that shouldn't be cached + param = "/no_cache" + provider = SSMProvider(config=config) + cache_key = (param, None) + + # WHEN a provider adds them into the cache + provider.add_to_cache(key=cache_key, value=mock_value, max_age=0) + provider.add_to_cache(key=cache_key, value=mock_value, max_age=-10) + + # THEN they should not be added to the cache + assert len(provider.store) == 0 + assert provider.has_not_expired_in_cache(cache_key) is False