From fff76dbfe68602e79179106f914d4eee25a0f282 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Tue, 1 Nov 2022 18:12:07 +0100 Subject: [PATCH 01/28] chore(deps): add typing_extensions dep --- poetry.lock | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 1ef3728253a..c5b06506d58 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1511,7 +1511,7 @@ validation = ["fastjsonschema"] [metadata] lock-version = "1.1" python-versions = "^3.7.4" -content-hash = "48a6c11b4ef71716e88efa7ffa474aa73fd7fcb02553ffd49c0d03fe72c1f838" +content-hash = "e5d33eff22737a00153da107e050d151e6405684d8c4c71f6b8c618152731eb7" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index cddceb2388d..f8e7b91325d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ aws-xray-sdk = { version = "^2.8.0", optional = true } fastjsonschema = { version = "^2.14.5", optional = true } pydantic = { version = "^1.8.2", optional = true } boto3 = { version = "^1.20.32", optional = true } +typing-extensions = "^4.4.0" [tool.poetry.dev-dependencies] coverage = {extras = ["toml"], version = "^6.2"} From d238286ed72e057b5d0bfac3f2939b068f1a8081 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Tue, 1 Nov 2022 18:12:29 +0100 Subject: [PATCH 02/28] feat(parameters): initial prototype for get_parameters_by_name --- .../utilities/parameters/ssm.py | 111 +++++++++++++++++- .../utilities/parameters/types.py | 3 + 2 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 aws_lambda_powertools/utilities/parameters/types.py diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 3b3e782fd45..84f417555e2 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -2,13 +2,14 @@ AWS SSM Parameter retrieval and caching utility """ - -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, overload import boto3 from botocore.config import Config +from typing_extensions import Literal from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider +from .types import TransformOptions if TYPE_CHECKING: from mypy_boto3_ssm import SSMClient @@ -207,7 +208,7 @@ def get_parameter( force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, **sdk_options -) -> Union[str, list, dict, bytes]: +) -> Union[str, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -344,3 +345,107 @@ def get_parameters( force_fetch=force_fetch, **sdk_options ) + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: None = None, + decrypt: bool = False, + force_fetch: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, +) -> Dict[str, str]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["binary"], + decrypt: bool = False, + force_fetch: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, +) -> Dict[str, bytes]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["json"], + decrypt: bool = False, + force_fetch: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, +) -> Dict[str, Dict[str, Any]]: + ... + + +@overload +def get_parameters_by_name( + parameters: Dict[str, Dict], + transform: Literal["auto"], + decrypt: bool = False, + force_fetch: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, +) -> Union[Dict[str, str], Dict[str, dict]]: + ... + + +def get_parameters_by_name( + parameters: Dict[str, Any], + transform: TransformOptions = None, + decrypt: bool = False, + force_fetch: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, +) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: + """ + Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store + + Parameters + ---------- + parameters: List[Dict[str, Dict]] + List of parameter names, and any optional overrides + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted + force_fetch: bool, optional + Force update even before a cached item has expired, defaults to False + max_age: int + Maximum age of the cached value + sdk_options: dict, optional + Dictionary of options that will be passed to the Parameter Store get_parameter API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + """ + + # NOTE: Need a param for hard failure mode on parameter retrieval + # by default, we should return an empty string on failure (ask customer for desired behaviour) + + # NOTE: Check costs of threads to assess when it's worth the overhead. + # for threads, assess failure mode to absorb OR raise/cancel futures + + ret: Dict[str, Any] = {} + + for parameter, options in parameters.items(): + if isinstance(options, dict): + transform = options.get("transform") or transform + decrypt = options.get("decrypt") or decrypt + max_age = options.get("max_age") or max_age + force_fetch = options.get("force_fetch") or force_fetch + + ret[parameter] = get_parameter( + name=parameter, + transform=transform, + decrypt=decrypt, + max_age=max_age, + force_fetch=force_fetch, + ) + + return ret diff --git a/aws_lambda_powertools/utilities/parameters/types.py b/aws_lambda_powertools/utilities/parameters/types.py new file mode 100644 index 00000000000..6a15873c496 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/types.py @@ -0,0 +1,3 @@ +from typing_extensions import Literal + +TransformOptions = Literal["json", "binary", "auto", None] From a22f7db87e7d5e44b1f60f45665462247904d3ae Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Wed, 2 Nov 2022 10:12:25 +0100 Subject: [PATCH 03/28] feat(parameters): add multi-thread option --- .../utilities/parameters/ssm.py | 80 ++++++++++++++----- 1 file changed, 58 insertions(+), 22 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 84f417555e2..d1f2e5278e4 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,7 +1,9 @@ """ AWS SSM Parameter retrieval and caching utility """ - +import concurrent.futures +import functools +from concurrent.futures import Future from typing import TYPE_CHECKING, Any, Dict, Optional, Union, overload import boto3 @@ -107,7 +109,7 @@ def get( # type: ignore[override] transform: Optional[str] = None, decrypt: bool = False, force_fetch: bool = False, - **sdk_options + **sdk_options, ) -> Optional[Union[str, dict, bytes]]: """ Retrieve a parameter value or return the cached value @@ -207,7 +209,7 @@ def get_parameter( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - **sdk_options + **sdk_options, ) -> Union[str, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -276,7 +278,7 @@ def get_parameters( force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, raise_on_transform_error: bool = False, - **sdk_options + **sdk_options, ) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: """ Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store @@ -343,7 +345,7 @@ def get_parameters( transform=transform, raise_on_transform_error=raise_on_transform_error, force_fetch=force_fetch, - **sdk_options + **sdk_options, ) @@ -354,6 +356,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + parallel: bool = False, ) -> Dict[str, str]: ... @@ -365,6 +368,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + parallel: bool = False, ) -> Dict[str, bytes]: ... @@ -376,6 +380,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + parallel: bool = False, ) -> Dict[str, Dict[str, Any]]: ... @@ -387,6 +392,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + parallel: bool = False, ) -> Union[Dict[str, str], Dict[str, dict]]: ... @@ -397,6 +403,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + parallel: bool = True, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store @@ -428,24 +435,53 @@ def get_parameters_by_name( # NOTE: Need a param for hard failure mode on parameter retrieval # by default, we should return an empty string on failure (ask customer for desired behaviour) - # NOTE: Check costs of threads to assess when it's worth the overhead. - # for threads, assess failure mode to absorb OR raise/cancel futures + # NOTE: Decide whether to leave multi-threaded option or not due to slower results (throttling+fork cost) ret: Dict[str, Any] = {} - - for parameter, options in parameters.items(): - if isinstance(options, dict): - transform = options.get("transform") or transform - decrypt = options.get("decrypt") or decrypt - max_age = options.get("max_age") or max_age - force_fetch = options.get("force_fetch") or force_fetch - - ret[parameter] = get_parameter( - name=parameter, - transform=transform, - decrypt=decrypt, - max_age=max_age, - force_fetch=force_fetch, - ) + future_to_param: Dict[Future, str] = {} + + if parallel: + with concurrent.futures.ThreadPoolExecutor(max_workers=len(parameters)) as pool: + for parameter, options in parameters.items(): + if isinstance(options, dict): + transform = options.get("transform") or transform + decrypt = options.get("decrypt") or decrypt + max_age = options.get("max_age") or max_age + force_fetch = options.get("force_fetch") or force_fetch + + fetch_parameter_callable = functools.partial( + get_parameter, + name=parameter, + transform=transform, + decrypt=decrypt, + max_age=max_age, + force_fetch=force_fetch, + ) + + future = pool.submit(fetch_parameter_callable) + future_to_param[future] = parameter + + for future in concurrent.futures.as_completed(future_to_param): + try: + # "parameter": "future result" + ret[future_to_param[future]] = future.result() + except Exception as exc: + print(f"Uh oh, failed to fetch '{future_to_param[future]}': {exc}") + + else: + for parameter, options in parameters.items(): + if isinstance(options, dict): + transform = options.get("transform") or transform + decrypt = options.get("decrypt") or decrypt + max_age = options.get("max_age") or max_age + force_fetch = options.get("force_fetch") or force_fetch + + ret[parameter] = get_parameter( + name=parameter, + transform=transform, + decrypt=decrypt, + max_age=max_age, + force_fetch=force_fetch, + ) return ret From e13bbeff270b422e4d7c488e27dc9438122f47a1 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Wed, 2 Nov 2022 14:22:40 +0100 Subject: [PATCH 04/28] docs(parameters): add get_parameters_by_name Signed-off-by: heitorlessa --- docs/utilities/parameters.md | 65 ++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/docs/utilities/parameters.md b/docs/utilities/parameters.md index 6b7d64b66b9..9bc0c71e2c8 100644 --- a/docs/utilities/parameters.md +++ b/docs/utilities/parameters.md @@ -24,36 +24,67 @@ This utility requires additional permissions to work as expected. ???+ note Different parameter providers require different permissions. -| Provider | Function/Method | IAM Permission | -| ------------------- | -----------------------------------------------------------------| -----------------------------------------------------------------------------| -| SSM Parameter Store | `get_parameter`, `SSMProvider.get` | `ssm:GetParameter` | -| SSM Parameter Store | `get_parameters`, `SSMProvider.get_multiple` | `ssm:GetParametersByPath` | -| SSM Parameter Store | If using `decrypt=True` | You must add an additional permission `kms:Decrypt` | -| Secrets Manager | `get_secret`, `SecretsManager.get` | `secretsmanager:GetSecretValue` | -| DynamoDB | `DynamoDBProvider.get` | `dynamodb:GetItem` | -| DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb:Query` | -| App Config | `get_app_config`, `AppConfigProvider.get_app_config` | `appconfig:GetLatestConfiguration` and `appconfig:StartConfigurationSession` | +| Provider | Function/Method | IAM Permission | +| ------------------- | ---------------------------------------------------- | ---------------------------------------------------------------------------- | +| SSM Parameter Store | `get_parameter`, `SSMProvider.get` | `ssm:GetParameter` | +| SSM Parameter Store | `get_parameters`, `SSMProvider.get_multiple` | `ssm:GetParametersByPath` | +| SSM Parameter Store | If using `decrypt=True` | You must add an additional permission `kms:Decrypt` | +| Secrets Manager | `get_secret`, `SecretsManager.get` | `secretsmanager:GetSecretValue` | +| DynamoDB | `DynamoDBProvider.get` | `dynamodb:GetItem` | +| DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb:Query` | +| App Config | `get_app_config`, `AppConfigProvider.get_app_config` | `appconfig:GetLatestConfiguration` and `appconfig:StartConfigurationSession` | ### Fetching parameters You can retrieve a single parameter using `get_parameter` high-level function. -For multiple parameters, you can use `get_parameters` and pass a path to retrieve them recursively. - -```python hl_lines="1 5 9" title="Fetching multiple parameters recursively" +```python hl_lines="5" title="Fetching a single parameter" from aws_lambda_powertools.utilities import parameters def handler(event, context): # Retrieve a single parameter value = parameters.get_parameter("/my/parameter") - # Retrieve multiple parameters from a path prefix recursively - # This returns a dict with the parameter name as key - values = parameters.get_parameters("/my/path/prefix") - for k, v in values.items(): - print(f"{k}: {v}") ``` +For multiple parameters, you can use either: + +* `get_parameters` to recursively fetch all parameters by path +* `get_parameters_by_name` to fetch distinct parameters by their full name. It also accepts custom caching, transform, and decrypt per parameter. + +=== "get_parameters" + + ```python hl_lines="1 6" + from aws_lambda_powertools.utilities import parameters + + def handler(event, context): + # Retrieve multiple parameters from a path prefix recursively + # This returns a dict with the parameter name as key + values = parameters.get_parameters("/my/path/prefix") + for parameter, value in values.items(): + print(f"{parameter}: {value}") + ``` + +=== "get_parameters_by_name" + + ```python hl_lines="1 3 13" + from aws_lambda_powertools.utilities import get_parameters_by_name + + parameters = { + "/develop/service/commons/telemetry/config": {"max_age": 300, "transform": "json"}, + "/develop/service/amplify/auth/userpool/arn": {"max_age": 300}, + # inherit default values + "/develop/service/payment/api/capture/url": {}, + "/develop/service/payment/api/charge/url": {}, + } + + def handler(event, context): + # This returns a dict with the parameter name as key + values = parameters.get_parameters_by_name(parameters=parameters, max_age=60) + for parameter, value in values.items(): + print(f"{parameter}: {value}") + ``` + ### Fetching secrets You can fetch secrets stored in Secrets Manager using `get_secrets`. From 9dc4ce645cc9f03e74632e4bbd4539d223475277 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Wed, 2 Nov 2022 17:16:39 +0100 Subject: [PATCH 05/28] chore(tests): add end-to-end test --- .../utilities/parameters/ssm.py | 8 ++-- .../parameter_ssm_get_parameters_by_name.py | 15 +++++++ tests/e2e/parameters/infrastructure.py | 41 +++++++++++++++++-- tests/e2e/parameters/test_ssm.py | 33 +++++++++++++++ 4 files changed, 88 insertions(+), 9 deletions(-) create mode 100644 tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py create mode 100644 tests/e2e/parameters/test_ssm.py diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index d1f2e5278e4..8804a1d053b 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -403,7 +403,7 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = True, + parallel: bool = False, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store @@ -432,10 +432,8 @@ def get_parameters_by_name( When the parameter provider fails to transform a parameter value. """ - # NOTE: Need a param for hard failure mode on parameter retrieval - # by default, we should return an empty string on failure (ask customer for desired behaviour) - - # NOTE: Decide whether to leave multi-threaded option or not due to slower results (throttling+fork cost) + # NOTE: Need a param for hard failure mode on parameter retrieval (asked feature request author) + # NOTE: Decide whether to leave multi-threaded option or not due to slower results (throttling+LWP cost) ret: Dict[str, Any] = {} future_to_param: Dict[Future, str] = {} diff --git a/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py new file mode 100644 index 00000000000..4d941bf93d3 --- /dev/null +++ b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py @@ -0,0 +1,15 @@ +import json +import os +from typing import Any, Dict, List, cast + +from aws_lambda_powertools.utilities.parameters.ssm import get_parameters_by_name +from aws_lambda_powertools.utilities.typing import LambdaContext + +parameters_list: List[str] = cast(List, json.loads(os.getenv("parameters", ""))) + + +def lambda_handler(event: dict, context: LambdaContext) -> Dict[str, Any]: + parameters_to_fetch: Dict[str, Any] = {param: {} for param in parameters_list} + + # response`{parameter:value}` + return get_parameters_by_name(parameters=parameters_to_fetch, max_age=0, parallel=True) diff --git a/tests/e2e/parameters/infrastructure.py b/tests/e2e/parameters/infrastructure.py index d0fb1b6c60c..e2cd5101ba7 100644 --- a/tests/e2e/parameters/infrastructure.py +++ b/tests/e2e/parameters/infrastructure.py @@ -1,18 +1,38 @@ -from pyclbr import Function +import json +from typing import List -from aws_cdk import CfnOutput +from aws_cdk import CfnOutput, Duration from aws_cdk import aws_appconfig as appconfig from aws_cdk import aws_iam as iam +from aws_cdk import aws_ssm as ssm +from aws_cdk.aws_lambda import Function -from tests.e2e.utils.data_builder import build_service_name +from tests.e2e.utils.data_builder import build_random_value, build_service_name from tests.e2e.utils.infrastructure import BaseInfrastructure class ParametersStack(BaseInfrastructure): def create_resources(self): - functions = self.create_lambda_functions() + parameters = self._create_ssm_parameters() + + env_vars = {"parameters": json.dumps(parameters)} + functions = self.create_lambda_functions( + function_props={"environment": env_vars, "timeout": Duration.seconds(30)} + ) + self._create_app_config(function=functions["ParameterAppconfigFreeformHandler"]) + # NOTE: Enforce least-privilege for our param tests only + functions["ParameterSsmGetParametersByName"].add_to_role_policy( + iam.PolicyStatement( + effect=iam.Effect.ALLOW, + actions=[ + "ssm:GetParameter", + ], + resources=[f"arn:aws:ssm:{self.region}:{self.account_id}:parameter/powertools/e2e/parameters/*"], + ) + ) + def _create_app_config(self, function: Function): service_name = build_service_name() @@ -106,3 +126,16 @@ def _create_app_config_freeform( resources=["*"], ) ) + + def _create_ssm_parameters(self) -> List[str]: + parameters: List[str] = [] + + for _ in range(10): + param = f"/powertools/e2e/parameters/{build_random_value()}" + rand = build_random_value() + ssm.StringParameter(self.stack, f"param-{rand}", parameter_name=param, string_value=rand) + parameters.append(param) + + CfnOutput(self.stack, "ParametersNameList", value=json.dumps(parameters)) + + return parameters diff --git a/tests/e2e/parameters/test_ssm.py b/tests/e2e/parameters/test_ssm.py new file mode 100644 index 00000000000..fce3a4b458b --- /dev/null +++ b/tests/e2e/parameters/test_ssm.py @@ -0,0 +1,33 @@ +import json +from typing import Any, Dict, List + +import pytest + +from tests.e2e.utils import data_fetcher + + +@pytest.fixture +def ssm_get_parameters_by_name_fn_arn(infrastructure: dict) -> str: + return infrastructure.get("ParameterSsmGetParametersByNameArn", "") + + +@pytest.fixture +def parameters_list(infrastructure: dict) -> List[str]: + param_list = infrastructure.get("ParametersNameList", "[]") + return json.loads(param_list) + + +def test_get_parameter_appconfig_freeform( + ssm_get_parameters_by_name_fn_arn: str, + parameters_list: str, +): + # GIVEN/WHEN + function_response, _ = data_fetcher.get_lambda_response(lambda_arn=ssm_get_parameters_by_name_fn_arn) + parameter_values: Dict[str, Any] = json.loads(function_response["Payload"].read().decode("utf-8")) + + # THEN + for param in parameters_list: + try: + assert parameter_values[param] is not None + except (KeyError, TypeError): + pytest.fail(f"Parameter {param} not found in response") From 1c6c71c170d06ca35307be5af00c196f63eae499 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Wed, 2 Nov 2022 17:59:28 +0100 Subject: [PATCH 06/28] chore(parameters): remove parallel option due to timeout risk --- .../utilities/parameters/ssm.py | 74 +++++-------------- .../parameter_ssm_get_parameters_by_name.py | 2 +- 2 files changed, 18 insertions(+), 58 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 8804a1d053b..8f3a2bb6030 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,9 +1,6 @@ """ AWS SSM Parameter retrieval and caching utility """ -import concurrent.futures -import functools -from concurrent.futures import Future from typing import TYPE_CHECKING, Any, Dict, Optional, Union, overload import boto3 @@ -356,7 +353,6 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = False, ) -> Dict[str, str]: ... @@ -368,7 +364,6 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = False, ) -> Dict[str, bytes]: ... @@ -380,7 +375,6 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = False, ) -> Dict[str, Dict[str, Any]]: ... @@ -392,7 +386,6 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = False, ) -> Union[Dict[str, str], Dict[str, dict]]: ... @@ -403,7 +396,6 @@ def get_parameters_by_name( decrypt: bool = False, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - parallel: bool = False, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store @@ -420,8 +412,6 @@ def get_parameters_by_name( Force update even before a cached item has expired, defaults to False max_age: int Maximum age of the cached value - sdk_options: dict, optional - Dictionary of options that will be passed to the Parameter Store get_parameter API call Raises ------ @@ -432,54 +422,24 @@ def get_parameters_by_name( When the parameter provider fails to transform a parameter value. """ - # NOTE: Need a param for hard failure mode on parameter retrieval (asked feature request author) - # NOTE: Decide whether to leave multi-threaded option or not due to slower results (throttling+LWP cost) + # NOTE: Decided against using multi-thread due to single-thread outperforming in 128M and 1G + timeout risk + # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 ret: Dict[str, Any] = {} - future_to_param: Dict[Future, str] = {} - - if parallel: - with concurrent.futures.ThreadPoolExecutor(max_workers=len(parameters)) as pool: - for parameter, options in parameters.items(): - if isinstance(options, dict): - transform = options.get("transform") or transform - decrypt = options.get("decrypt") or decrypt - max_age = options.get("max_age") or max_age - force_fetch = options.get("force_fetch") or force_fetch - - fetch_parameter_callable = functools.partial( - get_parameter, - name=parameter, - transform=transform, - decrypt=decrypt, - max_age=max_age, - force_fetch=force_fetch, - ) - - future = pool.submit(fetch_parameter_callable) - future_to_param[future] = parameter - - for future in concurrent.futures.as_completed(future_to_param): - try: - # "parameter": "future result" - ret[future_to_param[future]] = future.result() - except Exception as exc: - print(f"Uh oh, failed to fetch '{future_to_param[future]}': {exc}") - - else: - for parameter, options in parameters.items(): - if isinstance(options, dict): - transform = options.get("transform") or transform - decrypt = options.get("decrypt") or decrypt - max_age = options.get("max_age") or max_age - force_fetch = options.get("force_fetch") or force_fetch - - ret[parameter] = get_parameter( - name=parameter, - transform=transform, - decrypt=decrypt, - max_age=max_age, - force_fetch=force_fetch, - ) + + for parameter, options in parameters.items(): + if isinstance(options, dict): + transform = options.get("transform") or transform + decrypt = options.get("decrypt") or decrypt + max_age = options.get("max_age") or max_age + force_fetch = options.get("force_fetch") or force_fetch + + ret[parameter] = get_parameter( + name=parameter, + transform=transform, + decrypt=decrypt, + max_age=max_age, + force_fetch=force_fetch, + ) return ret diff --git a/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py index 4d941bf93d3..948fad2aa12 100644 --- a/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py +++ b/tests/e2e/parameters/handlers/parameter_ssm_get_parameters_by_name.py @@ -12,4 +12,4 @@ def lambda_handler(event: dict, context: LambdaContext) -> Dict[str, Any]: parameters_to_fetch: Dict[str, Any] = {param: {} for param in parameters_list} # response`{parameter:value}` - return get_parameters_by_name(parameters=parameters_to_fetch, max_age=0, parallel=True) + return get_parameters_by_name(parameters=parameters_to_fetch, max_age=0) From 215b2d72b6c1f4b910b4d28e5c675edf7473ec30 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Thu, 3 Nov 2022 10:41:30 +0100 Subject: [PATCH 07/28] refactor: strict typing transform_value/method --- .../utilities/feature_flags/appconfig.py | 4 +- .../utilities/parameters/appconfig.py | 4 +- .../utilities/parameters/base.py | 114 ++++++++++++------ .../utilities/parameters/ssm.py | 2 +- tests/functional/test_utilities_parameters.py | 40 +----- 5 files changed, 91 insertions(+), 73 deletions(-) diff --git a/aws_lambda_powertools/utilities/feature_flags/appconfig.py b/aws_lambda_powertools/utilities/feature_flags/appconfig.py index 8c8dbacc6c5..8695c1fd8c9 100644 --- a/aws_lambda_powertools/utilities/feature_flags/appconfig.py +++ b/aws_lambda_powertools/utilities/feature_flags/appconfig.py @@ -15,8 +15,6 @@ from .base import StoreProvider from .exceptions import ConfigurationStoreError, StoreClientError -TRANSFORM_TYPE = "json" - class AppConfigStore(StoreProvider): def __init__( @@ -74,7 +72,7 @@ def get_raw_configuration(self) -> Dict[str, Any]: dict, self._conf_store.get( name=self.name, - transform=TRANSFORM_TYPE, + transform="json", max_age=self.cache_seconds, ), ) diff --git a/aws_lambda_powertools/utilities/parameters/appconfig.py b/aws_lambda_powertools/utilities/parameters/appconfig.py index a3a340a62be..7884728024e 100644 --- a/aws_lambda_powertools/utilities/parameters/appconfig.py +++ b/aws_lambda_powertools/utilities/parameters/appconfig.py @@ -9,6 +9,8 @@ import boto3 from botocore.config import Config +from aws_lambda_powertools.utilities.parameters.types import TransformOptions + if TYPE_CHECKING: from mypy_boto3_appconfigdata import AppConfigDataClient @@ -132,7 +134,7 @@ def get_app_config( name: str, environment: str, application: Optional[str] = None, - transform: Optional[str] = None, + transform: TransformOptions = None, force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, **sdk_options diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index b76b16e1dd8..44d4be1d88e 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -7,11 +7,24 @@ from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Optional, + Tuple, + Type, + Union, + cast, + overload, +) import boto3 from botocore.config import Config +from aws_lambda_powertools.utilities.parameters.types import TransformOptions + from .exceptions import GetParameterError, TransformParameterError if TYPE_CHECKING: @@ -30,6 +43,14 @@ SUPPORTED_TRANSFORM_METHODS = [TRANSFORM_METHOD_JSON, TRANSFORM_METHOD_BINARY] ParameterClients = Union["AppConfigDataClient", "SecretsManagerClient", "SSMClient"] +TRANSFORM_METHOD_MAPPING = { + TRANSFORM_METHOD_JSON: json.loads, + TRANSFORM_METHOD_BINARY: base64.b64decode, + ".json": json.loads, + ".binary": base64.b64decode, + None: lambda x: x, +} + class BaseProvider(ABC): """ @@ -52,7 +73,7 @@ def get( self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, force_fetch: bool = False, **sdk_options, ) -> Optional[Union[str, dict, bytes]]: @@ -107,7 +128,7 @@ def get( if transform: if isinstance(value, bytes): value = value.decode("utf-8") - value = transform_value(value, transform) + value = transform_value(value, transform, raise_on_transform_error=True) self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) @@ -124,7 +145,7 @@ def get_multiple( self, path: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, raise_on_transform_error: bool = False, force_fetch: bool = False, **sdk_options, @@ -170,13 +191,8 @@ def get_multiple( raise GetParameterError(str(exc)) if transform: - transformed_values: dict = {} - for (item, value) in values.items(): - _transform = get_transform_method(item, transform) - if not _transform: - continue - transformed_values[item] = transform_value(value, _transform, raise_on_transform_error) - values.update(transformed_values) + values.update(transform_value(values, transform, raise_on_transform_error)) + self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age)) return values @@ -258,7 +274,7 @@ def _build_boto3_resource_client( return session.resource(service_name=service_name, config=config, endpoint_url=endpoint_url) -def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[str]: +def get_transform_method(key: str, transform: TransformOptions = None) -> Callable[..., Any]: """ Determine the transform method @@ -278,37 +294,50 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[ Parameters --------- key: str - Only used when the tranform is "auto". + Only used when the transform is "auto". transform: str, optional Original transform method, only "auto" will try to detect the transform method by the key Returns ------ - Optional[str]: - The transform method either when transform is "auto" then None, "json" or "binary" is returned - or the original transform method + Callable: + Transform function could be json.loads, base64.b64decode, or a lambda that echo the str value """ - if transform != "auto": - return transform + transform_method = TRANSFORM_METHOD_MAPPING.get(transform) + + if transform == "auto": + key_suffix = key.rsplit(".")[-1] + transform_method = TRANSFORM_METHOD_MAPPING.get(key_suffix, TRANSFORM_METHOD_MAPPING[None]) - for transform_method in SUPPORTED_TRANSFORM_METHODS: - if key.endswith("." + transform_method): - return transform_method - return None + return cast(Callable, transform_method) # https://github.com/python/mypy/issues/10740 +@overload def transform_value( - value: str, transform: str, raise_on_transform_error: Optional[bool] = True -) -> Optional[Union[dict, bytes]]: + value: Dict[str, Any], transform: TransformOptions, raise_on_transform_error: bool = False +) -> Dict[str, Any]: + ... + + +@overload +def transform_value( + value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = False +) -> Optional[Union[str, bytes, Dict[str, Any]]]: + ... + + +def transform_value( + value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = True +) -> Optional[Union[str, bytes, Dict[str, Any]]]: """ - Apply a transform to a value + Transform a value using one of the available options. Parameters --------- value: str Parameter value to transform transform: str - Type of transform, supported values are "json" and "binary" + Type of transform, supported values are "json", "binary", and "auto" based on suffix (.json, .binary) raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed @@ -318,18 +347,35 @@ def transform_value( TransformParameterError: When the parameter value could not be transformed """ + # Maintenance: For v3, we should consider returning the original value for soft transform failures. - try: - if transform == TRANSFORM_METHOD_JSON: - return json.loads(value) - elif transform == TRANSFORM_METHOD_BINARY: - return base64.b64decode(value) - else: - raise ValueError(f"Invalid transform type '{transform}'") + err_msg = "Unable to transform value using '{transform}' transform: {exc}" + + if isinstance(value, bytes): + value = value.decode("utf-8") + if isinstance(value, dict): + # NOTE: We must handle partial failures when receiving multiple values + # where one of the keys might fail during transform, e.g. `{"a": "valid", "b": "{"}` + # expected: `{"a": "valid", "b": None}` + + transformed_values: Dict[str, Any] = {} + for dict_key, dict_value in value.items(): + transform_method = get_transform_method(key=dict_key, transform=transform) + try: + transformed_values[dict_key] = transform_method(dict_value) + except Exception as exc: + if raise_on_transform_error: + raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc + transformed_values[dict_key] = None + return transformed_values + + try: + transform_method = get_transform_method(key=value, transform=transform) + return transform_method(value) except Exception as exc: if raise_on_transform_error: - raise TransformParameterError(str(exc)) + raise TransformParameterError(err_msg.format(transform=transform, exc=exc)) from exc return None diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 8f3a2bb6030..567d470f3c5 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -103,7 +103,7 @@ def get( # type: ignore[override] self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, - transform: Optional[str] = None, + transform: TransformOptions = None, decrypt: bool = False, force_fetch: bool = False, **sdk_options, diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 123c2fdbcc2..ff9a2767183 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -14,7 +14,11 @@ from botocore.response import StreamingBody from aws_lambda_powertools.utilities import parameters -from aws_lambda_powertools.utilities.parameters.base import BaseProvider, ExpirableValue +from aws_lambda_powertools.utilities.parameters.base import ( + TRANSFORM_METHOD_MAPPING, + BaseProvider, + ExpirableValue, +) @pytest.fixture(scope="function") @@ -1863,17 +1867,6 @@ def test_transform_value_binary_exception(): assert "Incorrect padding" in str(excinfo) -def test_transform_value_wrong(mock_value): - """ - Test transform_value() with an incorrect transform - """ - - with pytest.raises(parameters.TransformParameterError) as excinfo: - parameters.base.transform_value(mock_value, "INCORRECT") - - assert "Invalid transform type" in str(excinfo) - - def test_transform_value_ignore_error(mock_value): """ Test transform_value() does not raise errors when raise_on_transform_error is False @@ -1884,16 +1877,6 @@ def test_transform_value_ignore_error(mock_value): assert value is None -@pytest.mark.parametrize("original_transform", ["json", "binary", "other", "Auto", None]) -def test_get_transform_method_preserve_original(original_transform): - """ - Check if original transform method is returned for anything other than "auto" - """ - transform = parameters.base.get_transform_method("key", original_transform) - - assert transform == original_transform - - @pytest.mark.parametrize("extension", ["json", "binary"]) def test_get_transform_method_preserve_auto(extension, mock_name): """ @@ -1901,18 +1884,7 @@ def test_get_transform_method_preserve_auto(extension, mock_name): """ transform = parameters.base.get_transform_method(f"{mock_name}.{extension}", "auto") - assert transform == extension - - -@pytest.mark.parametrize("key", ["json", "binary", "example", "example.jsonp"]) -def test_get_transform_method_preserve_auto_unhandled(key): - """ - Check if any key that does not end with a supported extension returns None when - using the transform="auto" - """ - transform = parameters.base.get_transform_method(key, "auto") - - assert transform is None + assert transform == TRANSFORM_METHOD_MAPPING[extension] def test_base_provider_get_multiple_force_update(mock_name, mock_value): From 56df26f5d65d91fcb27ee01311f09ccc1c0c180e Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Thu, 3 Nov 2022 18:00:37 +0100 Subject: [PATCH 08/28] refactor: move to GetParameters, use GetParameter upon decrypt --- aws_lambda_powertools/shared/functions.py | 11 + .../utilities/parameters/base.py | 23 +- .../utilities/parameters/ssm.py | 197 +++++++++++++++--- 3 files changed, 199 insertions(+), 32 deletions(-) diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index fb4eedb7f36..6ab9d1442e5 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -1,4 +1,5 @@ import base64 +import itertools import logging import os import warnings @@ -115,3 +116,13 @@ def powertools_debug_is_set() -> bool: return True return False + + +def slice_dictionary(data, chunk_size: int): + # save CPU cycles if input is already small than chunk_size + if len(data) <= chunk_size: + yield data + + data_iterator = iter(data) # we don't know how big this is + for _ in range(0, len(data), chunk_size): + yield {dict_key: data[dict_key] for dict_key in itertools.islice(data_iterator, chunk_size)} diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 44d4be1d88e..a277be772cd 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -5,13 +5,13 @@ import base64 import json from abc import ABC, abstractmethod -from collections import namedtuple from datetime import datetime, timedelta from typing import ( TYPE_CHECKING, Any, Callable, Dict, + NamedTuple, Optional, Tuple, Type, @@ -35,7 +35,6 @@ DEFAULT_MAX_AGE_SECS = 5 -ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"]) # These providers will be dynamically initialized on first use of the helper functions DEFAULT_PROVIDERS: Dict[str, Any] = {} TRANSFORM_METHOD_JSON = "json" @@ -52,21 +51,26 @@ } +class ExpirableValue(NamedTuple): + value: Union[str, bytes, Dict[str, Any]] + ttl: datetime + + class BaseProvider(ABC): """ Abstract Base Class for Parameter providers """ - store: Any = None + store: Dict[Tuple[str, TransformOptions], ExpirableValue] def __init__(self): """ Initialize the base provider """ - self.store = {} + self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {} - def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: + def _has_not_expired(self, key: Tuple[str, TransformOptions]) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( @@ -130,7 +134,9 @@ def get( value = value.decode("utf-8") value = transform_value(value, transform, raise_on_transform_error=True) - self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + # NOTE: don't cache None, as they might've been failed transforms and may be corrected + if value is not None: + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) return value @@ -182,7 +188,7 @@ def get_multiple( key = (path, transform) if not force_fetch and self._has_not_expired(key): - return self.store[key].value + return self.store[key].value # type: ignore # need to revisit entire typing here try: values = self._get_multiple(path, **sdk_options) @@ -207,6 +213,9 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def clear_cache(self): self.store.clear() + def _add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + @staticmethod def _build_boto3_client( service_name: str, diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 567d470f3c5..99375ef2710 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,13 +1,16 @@ """ AWS SSM Parameter retrieval and caching utility """ -from typing import TYPE_CHECKING, Any, Dict, Optional, Union, overload +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, overload import boto3 from botocore.config import Config from typing_extensions import Literal +from aws_lambda_powertools.shared.functions import slice_dictionary + from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider +from .exceptions import GetParameterError from .types import TransformOptions if TYPE_CHECKING: @@ -80,6 +83,7 @@ class SSMProvider(BaseProvider): """ client: Any = None + _MAX_GET_PARAMETERS_ITEM = 10 def __init__( self, @@ -199,6 +203,162 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals return parameters + def get_parameters_by_name( + self, + parameters: Dict[str, Dict], + transform: TransformOptions = None, + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, + ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: + """ + Retrieve multiple parameter values by name from SSM or cache. + + Parameters + ---------- + parameters: List[Dict[str, Dict]] + List of parameter names, and any optional overrides + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted + max_age: int + Maximum age of the cached value + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + """ + + ret: Dict[str, Any] = {} + + # Tasks: + # 1. [DONE] Move to GetParameters + # 2. [DONE] Slice parameters in 10 if more than 10 + # 3. [DONE] Split batch and decrypt parameters + # 4. [DONE] Use GetParameters for batch parameters + # 5. [DONE] Cache successful ones individually as they might have overrides + # 6. [DONE] Use GetParameter for those using `decrypt` + # 7. [DONE] Introduce raise_on_error + # 8. [DONE] Return from cache + # 9. [DONE] Migrate high-level function get_parameters_by_name to use new class get_parameters_by_name + # 10. Handle soft error with "_errors" key upon raise_on_error being False + # 11. Include raise_on_transform in inner functions too + + batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) + + # Decided for single-thread as it outperforms in 128M and 1G + reduce timeout risk + # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 + for parameter, options in decrypt_params.items(): + ret[parameter] = self.get( + parameter, max_age=options["max_age"], transform=options["transform"], decrypt=options["decrypt"] + ) + + # Merge both batched parameters and those that required decryption + return {**self._get_parameters_from_batch(batch=batch_params, raise_on_failure=raise_on_failure), **ret} + + def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_failure: bool = True) -> Dict[str, Any]: + ret: Dict[str, Any] = {} + + # Check if it's in cache first to prevent unnecessary calls + # also confirm whether the incoming batch matches our cached + for name, options in batch.items(): + cache_key = (name, options["transform"]) + if self._has_not_expired(cache_key): + ret[name] = self.store[cache_key].value + + if len(ret) == len(batch): + return ret + + # Take out the differences to prevent over-fetching + # since there could be parameters with distinct max_age override + batch_diff = {key: value for key, value in batch.items() if key not in ret} + + for chunk in slice_dictionary(data=batch_diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): + ret.update(**self._get_parameters_by_name(parameters=chunk, raise_on_failure=raise_on_failure)) + + return ret + + def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_failure: bool = True) -> Dict[str, Any]: + """Use SSM GetParameters to fetch parameters, hydrate cache, and handle partial failure + + Parameters + ---------- + parameters : Dict[str, Dict] + Parameters to fetch + raise_on_failure : bool, optional + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True + + Returns + ------- + Dict[str, Any] + Retrieved parameters as key names and their values + + Raises + ------ + GetParameterError + When one or more parameters failed on fetching, and raise_on_failure is enabled + """ + ret = {} + response = self.client.get_parameters(Names=list(parameters.keys())) + if response["InvalidParameters"] and raise_on_failure: + raise GetParameterError(f"Failed to fetch parameters: {response['InvalidParameters']}") + + # Built up cache_key, hydrate cache, and return `{name:value}` + for parameter in response["Parameters"]: + name = parameter["Name"] + value = parameter["Value"] + options = parameters[name] + + _cache_key = (name, options["transform"]) + self._add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) + + ret[name] = value + + return ret + + @staticmethod + def _split_batch_and_decrypt_parameters( + parameters: Dict[str, Dict], transform: TransformOptions, max_age: int, decrypt: bool + ) -> Tuple[Dict[str, Dict], Dict[str, Dict]]: + """Split parameters that can be fetched by GetParameters vs GetParameter + + Parameters + ---------- + parameters : Dict[str, Dict] + Parameters containing names as key and optional config override as value + transform : TransformOptions + Transform configuration + max_age : int + How long to cache a parameter for + decrypt : bool + Whether to use KMS to decrypt a parameter + + Returns + ------- + Tuple[Dict[str, Dict], Dict[str, Dict]] + GetParameters and GetParameter parameters dict along with their overrides/globals merged + """ + batch_parameters: Dict[str, Dict] = {} + decrypt_parameters: Dict[str, Any] = {} + + for parameter, options in parameters.items(): + # NOTE: TypeDict later + _overrides = options or {} + _overrides["transform"] = _overrides.get("transform") or transform + _overrides["decrypt"] = _overrides.get("decrypt") or decrypt + _overrides["max_age"] = _overrides.get("max_age") or max_age + + # NOTE: Split parameters who have decrypt OR have it global + if _overrides["decrypt"]: + decrypt_parameters[parameter] = _overrides + else: + batch_parameters[parameter] = _overrides + + return batch_parameters, decrypt_parameters + def get_parameter( name: str, @@ -351,8 +511,8 @@ def get_parameters_by_name( parameters: Dict[str, Dict], transform: None = None, decrypt: bool = False, - force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, ) -> Dict[str, str]: ... @@ -362,8 +522,8 @@ def get_parameters_by_name( parameters: Dict[str, Dict], transform: Literal["binary"], decrypt: bool = False, - force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, ) -> Dict[str, bytes]: ... @@ -373,8 +533,8 @@ def get_parameters_by_name( parameters: Dict[str, Dict], transform: Literal["json"], decrypt: bool = False, - force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, ) -> Dict[str, Dict[str, Any]]: ... @@ -384,8 +544,8 @@ def get_parameters_by_name( parameters: Dict[str, Dict], transform: Literal["auto"], decrypt: bool = False, - force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, ) -> Union[Dict[str, str], Dict[str, dict]]: ... @@ -394,8 +554,8 @@ def get_parameters_by_name( parameters: Dict[str, Any], transform: TransformOptions = None, decrypt: bool = False, - force_fetch: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_failure: bool = True, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store @@ -408,8 +568,6 @@ def get_parameters_by_name( Transforms the content from a JSON object ('json') or base64 binary string ('binary') decrypt: bool, optional If the parameter values should be decrypted - force_fetch: bool, optional - Force update even before a cached item has expired, defaults to False max_age: int Maximum age of the cached value @@ -425,21 +583,10 @@ def get_parameters_by_name( # NOTE: Decided against using multi-thread due to single-thread outperforming in 128M and 1G + timeout risk # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 - ret: Dict[str, Any] = {} - - for parameter, options in parameters.items(): - if isinstance(options, dict): - transform = options.get("transform") or transform - decrypt = options.get("decrypt") or decrypt - max_age = options.get("max_age") or max_age - force_fetch = options.get("force_fetch") or force_fetch - - ret[parameter] = get_parameter( - name=parameter, - transform=transform, - decrypt=decrypt, - max_age=max_age, - force_fetch=force_fetch, - ) + # Only create the provider if this function is called at least once + if "ssm" not in DEFAULT_PROVIDERS: + DEFAULT_PROVIDERS["ssm"] = SSMProvider() - return ret + return DEFAULT_PROVIDERS["ssm"].get_parameters_by_name( + parameters=parameters, max_age=max_age, transform=transform, decrypt=decrypt, raise_on_failure=raise_on_failure + ) From 843ee9eedf8a6f81d0c2026ee5da6df63c7e3e0c Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 4 Nov 2022 09:42:17 +0100 Subject: [PATCH 09/28] fix(parameters): transform_value auto should work for both single and multiple params --- .../utilities/parameters/base.py | 47 ++++++++++++------- .../utilities/parameters/ssm.py | 40 +++++++++------- tests/e2e/parameters/test_ssm.py | 3 +- tests/functional/test_utilities_parameters.py | 44 +++++++++++++++++ 4 files changed, 101 insertions(+), 33 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index a277be772cd..dbd19d9861f 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -130,9 +130,7 @@ def get( raise GetParameterError(str(exc)) if transform: - if isinstance(value, bytes): - value = value.decode("utf-8") - value = transform_value(value, transform, raise_on_transform_error=True) + value = transform_value(key=name, value=value, transform=transform, raise_on_transform_error=True) # NOTE: don't cache None, as they might've been failed transforms and may be corrected if value is not None: @@ -283,26 +281,26 @@ def _build_boto3_resource_client( return session.resource(service_name=service_name, config=config, endpoint_url=endpoint_url) -def get_transform_method(key: str, transform: TransformOptions = None) -> Callable[..., Any]: +def get_transform_method(value: str, transform: TransformOptions = None) -> Callable[..., Any]: """ Determine the transform method Examples ------- - >>> get_transform_method("key", "any_other_value") + >>> get_transform_method("key","any_other_value") 'any_other_value' - >>> get_transform_method("key.json", "auto") + >>> get_transform_method("key.json","auto") 'json' - >>> get_transform_method("key.binary", "auto") + >>> get_transform_method("key.binary","auto") 'binary' - >>> get_transform_method("key", "auto") + >>> get_transform_method("key","auto") None - >>> get_transform_method("key", None) + >>> get_transform_method("key",None) None Parameters --------- - key: str + value: str Only used when the transform is "auto". transform: str, optional Original transform method, only "auto" will try to detect the transform method by the key @@ -315,7 +313,7 @@ def get_transform_method(key: str, transform: TransformOptions = None) -> Callab transform_method = TRANSFORM_METHOD_MAPPING.get(transform) if transform == "auto": - key_suffix = key.rsplit(".")[-1] + key_suffix = value.rsplit(".")[-1] transform_method = TRANSFORM_METHOD_MAPPING.get(key_suffix, TRANSFORM_METHOD_MAPPING[None]) return cast(Callable, transform_method) # https://github.com/python/mypy/issues/10740 @@ -323,20 +321,29 @@ def get_transform_method(key: str, transform: TransformOptions = None) -> Callab @overload def transform_value( - value: Dict[str, Any], transform: TransformOptions, raise_on_transform_error: bool = False + value: Dict[str, Any], + transform: TransformOptions, + raise_on_transform_error: bool = False, + key: str = "", ) -> Dict[str, Any]: ... @overload def transform_value( - value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = False + value: Union[str, bytes, Dict[str, Any]], + transform: TransformOptions, + raise_on_transform_error: bool = False, + key: str = "", ) -> Optional[Union[str, bytes, Dict[str, Any]]]: ... def transform_value( - value: Union[str, bytes, Dict[str, Any]], transform: TransformOptions, raise_on_transform_error: bool = True + value: Union[str, bytes, Dict[str, Any]], + transform: TransformOptions, + raise_on_transform_error: bool = True, + key: str = "", ) -> Optional[Union[str, bytes, Dict[str, Any]]]: """ Transform a value using one of the available options. @@ -347,6 +354,8 @@ def transform_value( Parameter value to transform transform: str Type of transform, supported values are "json", "binary", and "auto" based on suffix (.json, .binary) + key: str + Parameter key when transform is auto to infer its transform method raise_on_transform_error: bool, optional Raises an exception if any transform fails, otherwise this will return a None value for each transform that failed @@ -370,7 +379,7 @@ def transform_value( transformed_values: Dict[str, Any] = {} for dict_key, dict_value in value.items(): - transform_method = get_transform_method(key=dict_key, transform=transform) + transform_method = get_transform_method(value=dict_key, transform=transform) try: transformed_values[dict_key] = transform_method(dict_value) except Exception as exc: @@ -379,8 +388,14 @@ def transform_value( transformed_values[dict_key] = None return transformed_values + if transform == "auto": + # key="a.json", value='{"a": "b"}', or key="a.binary", value="b64_encoded" + transform_method = get_transform_method(value=key, transform=transform) + else: + # value='{"key": "value"} + transform_method = get_transform_method(value=value, transform=transform) + try: - transform_method = get_transform_method(key=value, transform=transform) return transform_method(value) except Exception as exc: if raise_on_transform_error: diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 99375ef2710..7400a87aebf 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -9,7 +9,7 @@ from aws_lambda_powertools.shared.functions import slice_dictionary -from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider +from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider, transform_value from .exceptions import GetParameterError from .types import TransformOptions @@ -209,7 +209,7 @@ def get_parameters_by_name( transform: TransformOptions = None, decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from SSM or cache. @@ -224,6 +224,8 @@ def get_parameters_by_name( If the parameter values should be decrypted max_age: int Maximum age of the cached value + raise_on_error: bool + Whether to raise GetParameterError if a parameter fails to be fetched or not Raises ------ @@ -245,7 +247,6 @@ def get_parameters_by_name( # 8. [DONE] Return from cache # 9. [DONE] Migrate high-level function get_parameters_by_name to use new class get_parameters_by_name # 10. Handle soft error with "_errors" key upon raise_on_error being False - # 11. Include raise_on_transform in inner functions too batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) @@ -257,9 +258,9 @@ def get_parameters_by_name( ) # Merge both batched parameters and those that required decryption - return {**self._get_parameters_from_batch(batch=batch_params, raise_on_failure=raise_on_failure), **ret} + return {**self._get_parameters_from_batch(batch=batch_params, raise_on_error=raise_on_error), **ret} - def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_failure: bool = True) -> Dict[str, Any]: + def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: ret: Dict[str, Any] = {} # Check if it's in cache first to prevent unnecessary calls @@ -277,18 +278,18 @@ def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_failure: b batch_diff = {key: value for key, value in batch.items() if key not in ret} for chunk in slice_dictionary(data=batch_diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): - ret.update(**self._get_parameters_by_name(parameters=chunk, raise_on_failure=raise_on_failure)) + ret.update(**self._get_parameters_by_name(parameters=chunk, raise_on_error=raise_on_error)) return ret - def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_failure: bool = True) -> Dict[str, Any]: + def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: """Use SSM GetParameters to fetch parameters, hydrate cache, and handle partial failure Parameters ---------- parameters : Dict[str, Dict] Parameters to fetch - raise_on_failure : bool, optional + raise_on_error : bool, optional Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True Returns @@ -299,11 +300,11 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_failure: Raises ------ GetParameterError - When one or more parameters failed on fetching, and raise_on_failure is enabled + When one or more parameters failed on fetching, and raise_on_error is enabled """ ret = {} response = self.client.get_parameters(Names=list(parameters.keys())) - if response["InvalidParameters"] and raise_on_failure: + if response["InvalidParameters"] and raise_on_error: raise GetParameterError(f"Failed to fetch parameters: {response['InvalidParameters']}") # Built up cache_key, hydrate cache, and return `{name:value}` @@ -311,6 +312,13 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_failure: name = parameter["Name"] value = parameter["Value"] options = parameters[name] + transform = options.get("transform", "") + + # NOTE: If transform is set, we do it before caching to reduce number of operations + if transform: + value = transform_value( + key=name, value=value, transform=transform, raise_on_transform_error=raise_on_error + ) _cache_key = (name, options["transform"]) self._add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) @@ -512,7 +520,7 @@ def get_parameters_by_name( transform: None = None, decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Dict[str, str]: ... @@ -523,7 +531,7 @@ def get_parameters_by_name( transform: Literal["binary"], decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Dict[str, bytes]: ... @@ -534,7 +542,7 @@ def get_parameters_by_name( transform: Literal["json"], decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Dict[str, Dict[str, Any]]: ... @@ -545,7 +553,7 @@ def get_parameters_by_name( transform: Literal["auto"], decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Union[Dict[str, str], Dict[str, dict]]: ... @@ -555,7 +563,7 @@ def get_parameters_by_name( transform: TransformOptions = None, decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_failure: bool = True, + raise_on_error: bool = True, ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: """ Retrieve multiple parameter values by name from AWS Systems Manager (SSM) Parameter Store @@ -588,5 +596,5 @@ def get_parameters_by_name( DEFAULT_PROVIDERS["ssm"] = SSMProvider() return DEFAULT_PROVIDERS["ssm"].get_parameters_by_name( - parameters=parameters, max_age=max_age, transform=transform, decrypt=decrypt, raise_on_failure=raise_on_failure + parameters=parameters, max_age=max_age, transform=transform, decrypt=decrypt, raise_on_error=raise_on_error ) diff --git a/tests/e2e/parameters/test_ssm.py b/tests/e2e/parameters/test_ssm.py index fce3a4b458b..7e9614f8ea0 100644 --- a/tests/e2e/parameters/test_ssm.py +++ b/tests/e2e/parameters/test_ssm.py @@ -17,7 +17,8 @@ def parameters_list(infrastructure: dict) -> List[str]: return json.loads(param_list) -def test_get_parameter_appconfig_freeform( +# +def test_get_parameters_by_name( ssm_get_parameters_by_name_fn_arn: str, parameters_list: str, ): diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index ff9a2767183..3ce827880ec 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1814,6 +1814,50 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_transform_value_auto(mock_value: str): + # GIVEN + json_data = json.dumps({"A": mock_value}) + mock_binary = mock_value.encode() + binary_data = base64.b64encode(mock_binary).decode() + + # WHEN + json_value = parameters.base.transform_value(key="/a.json", value=json_data, transform="auto") + binary_value = parameters.base.transform_value(key="/a.binary", value=binary_data, transform="auto") + + # THEN + assert isinstance(json_value, dict) + assert isinstance(binary_value, bytes) + assert json_value["A"] == mock_value + assert binary_value == mock_binary + + +def test_transform_value_auto_incorrect_key(mock_value: str): + # GIVEN + mock_key = "/missing/json/suffix" + json_data = json.dumps({"A": mock_value}) + + # WHEN + value = parameters.base.transform_value(key=mock_key, value=json_data, transform="auto") + + # THEN it should echo back its value + assert isinstance(value, str) + assert value == json_data + + +def test_transform_value_auto_unsupported_transform(mock_value: str): + # GIVEN + mock_key = "/a.does_not_exist" + mock_dict = {"hello": "world"} + + # WHEN + value = parameters.base.transform_value(key=mock_key, value=mock_value, transform="auto") + dict_value = parameters.base.transform_value(key=mock_key, value=mock_dict, transform="auto") + + # THEN it should echo back its value + assert value == mock_value + assert dict_value == mock_dict + + def test_transform_value_json(mock_value): """ Test transform_value() with a json transform From 0a294f5d1f8f2bbf403bca6ebe3f6e9d56b05a16 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 4 Nov 2022 16:18:25 +0100 Subject: [PATCH 10/28] chore(tests): add functional test for decrypt, batch split, and overrides --- .../utilities/parameters/__init__.py | 3 +- .../utilities/parameters/ssm.py | 9 +- tests/functional/test_utilities_parameters.py | 164 +++++++++++++++++- 3 files changed, 172 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/__init__.py b/aws_lambda_powertools/utilities/parameters/__init__.py index 7dce2ac4c9a..9fcaa4fa701 100644 --- a/aws_lambda_powertools/utilities/parameters/__init__.py +++ b/aws_lambda_powertools/utilities/parameters/__init__.py @@ -9,7 +9,7 @@ from .dynamodb import DynamoDBProvider from .exceptions import GetParameterError, TransformParameterError from .secrets import SecretsProvider, get_secret -from .ssm import SSMProvider, get_parameter, get_parameters +from .ssm import SSMProvider, get_parameter, get_parameters, get_parameters_by_name __all__ = [ "AppConfigProvider", @@ -22,6 +22,7 @@ "get_app_config", "get_parameter", "get_parameters", + "get_parameters_by_name", "get_secret", "clear_caches", ] diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 7400a87aebf..12efa8a519a 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -356,8 +356,13 @@ def _split_batch_and_decrypt_parameters( # NOTE: TypeDict later _overrides = options or {} _overrides["transform"] = _overrides.get("transform") or transform - _overrides["decrypt"] = _overrides.get("decrypt") or decrypt - _overrides["max_age"] = _overrides.get("max_age") or max_age + + # These values can be falsy (False, 0) + if "decrypt" not in _overrides: + _overrides["decrypt"] = decrypt + + if "max_age" not in _overrides: + _overrides["max_age"] = max_age # NOTE: Split parameters who have decrypt OR have it global if _overrides["decrypt"]: diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 3ce827880ec..d049bb7d05f 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import base64 import json import random import string from datetime import datetime, timedelta from io import BytesIO -from typing import Dict +from typing import Any, Dict import boto3 import pytest @@ -19,6 +21,11 @@ BaseProvider, ExpirableValue, ) +from aws_lambda_powertools.utilities.parameters.ssm import ( + DEFAULT_MAX_AGE_SECS, + SSMProvider, +) +from aws_lambda_powertools.utilities.parameters.types import TransformOptions @pytest.fixture(scope="function") @@ -614,6 +621,47 @@ def test_ssm_provider_clear_cache(mock_name, mock_value, config): assert provider.store == {} +def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_value, mock_version, config): + # GIVEN two parameters are requested + provider = parameters.SSMProvider(config=config) + dev_param = f"/dev/{mock_name}" + prod_param = f"/prod/{mock_name}" + + params = {dev_param: {}, prod_param: {}} + param_names = list(params.keys()) + + stubber = stub.Stubber(provider.client) + response = { + "Parameters": [ + { + "Name": dev_param, + "Type": "String", + "Value": "string", + "Version": mock_version, + "Selector": f"{dev_param}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{dev_param.removeprefix('/')}", + "DataType": "string", + }, + ], + "InvalidParameters": [prod_param], + } + + expected_params = {"Names": param_names} + stubber.add_response("get_parameters", response, expected_params) + stubber.activate() + + # WHEN one of them fails to be retrieved + # THEN raise GetParameterError + with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{prod_param}.*"): + try: + provider.get_parameters_by_name(parameters=params) + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + def test_dynamodb_provider_clear_cache(mock_name, mock_value, config): # GIVEN a provider is initialized with a cached value provider = parameters.DynamoDBProvider(table_name="test", config=config) @@ -1522,6 +1570,120 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_get_parameters_by_name(monkeypatch, mock_name, mock_value): + params = {mock_name: {}} + + class TestProvider(SSMProvider): + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: + return mock_value + + def get_parameters_by_name( + self, + parameters: Dict[str, Dict], + transform: TransformOptions = None, + decrypt: bool = False, + max_age: int = DEFAULT_MAX_AGE_SECS, + raise_on_error: bool = True, + ) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + return {mock_name: mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + values = parameters.get_parameters_by_name(parameters=params) + + assert len(values) == 1 + assert values[mock_name] == mock_value + + +def test_get_parameters_by_name_with_decrypt_override(monkeypatch, mock_name, mock_value): + # GIVEN 2 out of 3 parameters have decrypt override + decrypt_param = "/api_key" + decrypt_param_two = "/another/secret" + decrypt_params = {decrypt_param: {"decrypt": True}, decrypt_param_two: {"decrypt": True}} + decrypted_response = "decrypted" + params = {mock_name: {}, **decrypt_params} + + class TestProvider(SSMProvider): + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: + # THEN params with `decrypt` override should use GetParameter` (`_get`) + assert name in decrypt_params + assert decrypt + return decrypted_response + + def _get_parameters_by_name(self, *args, **kwargs) -> Dict[str, Any]: + return {mock_name: mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + values = parameters.get_parameters_by_name(parameters=params) + + # THEN all parameters should be merged in the response + assert len(values) == 3 + assert values[mock_name] == mock_value + assert values[decrypt_param] == decrypted_response + assert values[decrypt_param_two] == decrypted_response + + +def test_get_parameters_by_name_with_max_age_override(monkeypatch, mock_name, mock_value): + # GIVEN 1 out of 2 parameters overrides max_age to 0 + no_cache_param = "/no_cache" + params = {mock_name: {}, no_cache_param: {"max_age": 0}} + + class TestProvider(SSMProvider): + # NOTE: By convention, we check at `_get_parameters_by_name` + # as that's right before we call SSM, and when options have been merged + def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + # THEN max_age should use no_cache_param override + assert parameters[no_cache_param]["max_age"] == 0 + + return {mock_name: mock_value, no_cache_param: mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + parameters.get_parameters_by_name(parameters=params) + + +def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value): + # GIVEN a parameter overrides a default setting + default_cache_period = 500 + params = {mock_name: {"max_age": 0}, "no-override": {}} + + class TestProvider(SSMProvider): + # NOTE: By convention, we check at `_get_parameters_by_name` + # as that's right before we call SSM, and when options have been merged + def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + # THEN max_age should use no_cache_param override + assert parameters[mock_name]["max_age"] == 0 + assert parameters["no-override"]["max_age"] == default_cache_period + + return {mock_name: mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called with max_age set to 500 as the default + parameters.get_parameters_by_name(parameters=params, max_age=default_cache_period) + + +def test_get_parameters_by_name_with_max_batch(monkeypatch, mock_value): + # GIVEN a batch of 20 parameters + params = {f"param_{i}": {} for i in range(20)} + + class TestProvider(SSMProvider): + # NOTE: By convention, we check at `_get_parameters_by_name` + # as that's right before we call SSM, and when options have been merged + def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + # THEN we should always split to respect GetParameters max + assert len(parameters) == self._MAX_GET_PARAMETERS_ITEM + return {} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + parameters.get_parameters_by_name(parameters=params) + + def test_get_parameter_new(monkeypatch, mock_name, mock_value): """ Test get_parameter() without a default provider From dd3b2b3bbe5241f91661f2f54f0d3c2035e5d298 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 4 Nov 2022 19:17:53 +0100 Subject: [PATCH 11/28] chore(tests): add functional test for cache batch, transform override --- .../utilities/parameters/base.py | 5 +- .../utilities/parameters/ssm.py | 3 +- tests/functional/test_utilities_parameters.py | 148 ++++++++++++++---- 3 files changed, 123 insertions(+), 33 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index dbd19d9861f..309117f5212 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -1,6 +1,7 @@ """ Base for Parameter providers """ +from __future__ import annotations import base64 import json @@ -52,7 +53,7 @@ class ExpirableValue(NamedTuple): - value: Union[str, bytes, Dict[str, Any]] + value: str | bytes | Dict[str, Any] ttl: datetime @@ -211,7 +212,7 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def clear_cache(self): self.store.clear() - def _add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): + def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) @staticmethod diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 12efa8a519a..29dc51c74eb 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -270,6 +270,7 @@ def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_error: boo if self._has_not_expired(cache_key): ret[name] = self.store[cache_key].value + # Return early if all parameters were in cache OR batch was empty if len(ret) == len(batch): return ret @@ -321,7 +322,7 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: b ) _cache_key = (name, options["transform"]) - self._add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) + self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) ret[name] = value diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index d049bb7d05f..5acb5768186 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -6,7 +6,7 @@ import string from datetime import datetime, timedelta from io import BytesIO -from typing import Any, Dict +from typing import Any, Dict, List import boto3 import pytest @@ -50,6 +50,29 @@ def config(): return Config(region_name="us-east-1") +def build_get_parameters_stub(params: Dict[str, Any], invalid_parameters: List[str] | None = None) -> Dict[str, List]: + invalid_parameters = invalid_parameters or [] + version = random.randrange(1, 1000) + return { + "Parameters": [ + { + "Name": param, + "Type": "String", + "Value": value, + "Version": version, + "Selector": f"{param}:{version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{param.removeprefix('/')}", + "DataType": "string", + } + for param, value in params.items() + if param not in invalid_parameters + ], + "InvalidParameters": invalid_parameters, # official SDK stub fails validation here, need to raise an issue + } + + def test_dynamodb_provider_get(mock_name, mock_value, config): """ Test DynamoDBProvider.get() with a non-cached value @@ -629,24 +652,10 @@ def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_va params = {dev_param: {}, prod_param: {}} param_names = list(params.keys()) + stub_params = {dev_param: mock_value} stubber = stub.Stubber(provider.client) - response = { - "Parameters": [ - { - "Name": dev_param, - "Type": "String", - "Value": "string", - "Version": mock_version, - "Selector": f"{dev_param}:{mock_version}", - "SourceResult": "string", - "LastModifiedDate": datetime(2015, 1, 1), - "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{dev_param.removeprefix('/')}", - "DataType": "string", - }, - ], - "InvalidParameters": [prod_param], - } + response = build_get_parameters_stub(params=stub_params, invalid_parameters=[prod_param]) expected_params = {"Names": param_names} stubber.add_response("get_parameters", response, expected_params) @@ -1574,9 +1583,6 @@ def test_get_parameters_by_name(monkeypatch, mock_name, mock_value): params = {mock_name: {}} class TestProvider(SSMProvider): - def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: - return mock_value - def get_parameters_by_name( self, parameters: Dict[str, Dict], @@ -1625,24 +1631,32 @@ def _get_parameters_by_name(self, *args, **kwargs) -> Dict[str, Any]: assert values[decrypt_param_two] == decrypted_response -def test_get_parameters_by_name_with_max_age_override(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_with_overrides(monkeypatch, mock_value): # GIVEN 1 out of 2 parameters overrides max_age to 0 no_cache_param = "/no_cache" - params = {mock_name: {}, no_cache_param: {"max_age": 0}} + json_param = "/json" + params = {no_cache_param: {"max_age": 0}, json_param: {"transform": "json"}} + + stub_params = {no_cache_param: mock_value, json_param: '{"a":"b"}'} + stub_response = build_get_parameters_stub(params=stub_params) + + class FakeClient: + def get_parameters(self, *args, **kwargs): + return stub_response class TestProvider(SSMProvider): - # NOTE: By convention, we check at `_get_parameters_by_name` - # as that's right before we call SSM, and when options have been merged def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: # THEN max_age should use no_cache_param override assert parameters[no_cache_param]["max_age"] == 0 + return super()._get_parameters_by_name(parameters, raise_on_error) - return {mock_name: mock_value, no_cache_param: mock_value} - - monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + provider = TestProvider(boto3_client=FakeClient()) + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider(boto3_client=FakeClient())) # WHEN get_parameters_by_name is called - parameters.get_parameters_by_name(parameters=params) + ret = provider.get_parameters_by_name(parameters=params) + # THEN json_param should be transformed + assert isinstance(ret[json_param], dict) def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value): @@ -1671,8 +1685,6 @@ def test_get_parameters_by_name_with_max_batch(monkeypatch, mock_value): params = {f"param_{i}": {} for i in range(20)} class TestProvider(SSMProvider): - # NOTE: By convention, we check at `_get_parameters_by_name` - # as that's right before we call SSM, and when options have been merged def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: # THEN we should always split to respect GetParameters max assert len(parameters) == self._MAX_GET_PARAMETERS_ITEM @@ -1684,6 +1696,64 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: b parameters.get_parameters_by_name(parameters=params) +def test_get_parameters_by_name_cache(monkeypatch, mock_name, mock_value): + # GIVEN we have a parameter to fetch but is already in cache + params = {mock_name: {}} + + class TestProvider(SSMProvider): + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True, **kwargs + ) -> Dict[str, Any]: + # THEN this should never be called + raise RuntimeError("Should not be called if it's in cache") + + provider = TestProvider() + provider.add_to_cache(key=(mock_name, None), value=mock_value, max_age=10) + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", provider) + + # WHEN get_parameters_by_name is called + provider.get_parameters_by_name(parameters=params) + + +def test_get_parameters_by_name_empty_batch(monkeypatch, mock_name, mock_value): + # GIVEN we have an empty dictionary + params = {} + + class TestProvider(SSMProvider): + ... + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + # WHEN get_parameters_by_name is called + assert parameters.get_parameters_by_name(parameters=params) == {} + + +def test_get_parameters_by_name_cache_them_individually_not_batch(monkeypatch, mock_name, mock_version): + # GIVEN we have a parameter to fetch but is already in cache + dev_param = f"/dev/{mock_name}" + prod_param = f"/prod/{mock_name}" + params = {dev_param: {}, prod_param: {}} + + stub_params = {dev_param: mock_value, prod_param: mock_value} + stub_response = build_get_parameters_stub(params=stub_params) + + class FakeClient: + def get_parameters(self, *args, **kwargs): + return stub_response + + class TestProvider(SSMProvider): + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + return super().get_parameters_by_name(*args, **kwargs) + + provider = TestProvider(boto3_client=FakeClient()) + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", provider) + + # WHEN get_parameters_by_name is called + provider.get_parameters_by_name(parameters=params) + assert len(provider.store) == len(params) + + def test_get_parameter_new(monkeypatch, mock_name, mock_value): """ Test get_parameter() without a default provider @@ -1750,6 +1820,24 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value +def test_get_parameters_by_name_new(monkeypatch, mock_name, mock_value): + """ + Test get_parameters_by_name() without a default provider + """ + params = {mock_name: {}} + + class TestProvider(SSMProvider): + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + return {mock_name: mock_value} + + monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) + monkeypatch.setattr(parameters.ssm, "SSMProvider", TestProvider) + + value = parameters.get_parameters_by_name(params) + + assert value[mock_name] == mock_value + + def test_get_secret(monkeypatch, mock_name, mock_value): """ Test get_secret() From 0faad699b4bd360e1e3d4ce1248b5210c05b94b6 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Sun, 6 Nov 2022 20:49:03 +0100 Subject: [PATCH 12/28] feat(parameters): graceful error handling for raise_on_failure; cleanup --- aws_lambda_powertools/shared/functions.py | 11 +- .../utilities/parameters/ssm.py | 76 ++++++----- tests/functional/test_utilities_parameters.py | 123 ++++++++++++++++-- 3 files changed, 161 insertions(+), 49 deletions(-) diff --git a/aws_lambda_powertools/shared/functions.py b/aws_lambda_powertools/shared/functions.py index 6ab9d1442e5..884edb37e35 100644 --- a/aws_lambda_powertools/shared/functions.py +++ b/aws_lambda_powertools/shared/functions.py @@ -4,7 +4,7 @@ import os import warnings from binascii import Error as BinAsciiError -from typing import Optional, Union, overload +from typing import Dict, Generator, Optional, Union, overload from aws_lambda_powertools.shared import constants @@ -118,11 +118,6 @@ def powertools_debug_is_set() -> bool: return False -def slice_dictionary(data, chunk_size: int): - # save CPU cycles if input is already small than chunk_size - if len(data) <= chunk_size: - yield data - - data_iterator = iter(data) # we don't know how big this is +def slice_dictionary(data: Dict, chunk_size: int) -> Generator[Dict, None, None]: for _ in range(0, len(data), chunk_size): - yield {dict_key: data[dict_key] for dict_key in itertools.islice(data_iterator, chunk_size)} + yield {dict_key: data[dict_key] for dict_key in itertools.islice(data, chunk_size)} diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 29dc51c74eb..789f59fcd3e 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,7 +1,7 @@ """ AWS SSM Parameter retrieval and caching utility """ -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload import boto3 from botocore.config import Config @@ -191,7 +191,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals for page in self.client.get_paginator("get_parameters_by_path").paginate(**sdk_options): for parameter in page.get("Parameters", []): # Standardize the parameter name - # The parameter name returned by SSM will contained the full path. + # The parameter name returned by SSM will contain the full path. # However, for readability, we should return only the part after # the path. name = parameter["Name"] @@ -225,7 +225,7 @@ def get_parameters_by_name( max_age: int Maximum age of the cached value raise_on_error: bool - Whether to raise GetParameterError if a parameter fails to be fetched or not + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True Raises ------ @@ -235,36 +235,38 @@ def get_parameters_by_name( """ ret: Dict[str, Any] = {} - - # Tasks: - # 1. [DONE] Move to GetParameters - # 2. [DONE] Slice parameters in 10 if more than 10 - # 3. [DONE] Split batch and decrypt parameters - # 4. [DONE] Use GetParameters for batch parameters - # 5. [DONE] Cache successful ones individually as they might have overrides - # 6. [DONE] Use GetParameter for those using `decrypt` - # 7. [DONE] Introduce raise_on_error - # 8. [DONE] Return from cache - # 9. [DONE] Migrate high-level function get_parameters_by_name to use new class get_parameters_by_name - # 10. Handle soft error with "_errors" key upon raise_on_error being False + decrypt_errors: List[str] = [] batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) # Decided for single-thread as it outperforms in 128M and 1G + reduce timeout risk # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 for parameter, options in decrypt_params.items(): - ret[parameter] = self.get( - parameter, max_age=options["max_age"], transform=options["transform"], decrypt=options["decrypt"] - ) + try: + ret[parameter] = self.get( + parameter, max_age=options["max_age"], transform=options["transform"], decrypt=options["decrypt"] + ) + except GetParameterError: + if raise_on_error: + raise + decrypt_errors.append(parameter) + continue + + batch_ret, batch_err = self._get_parameters_from_batch(batch=batch_params, raise_on_error=raise_on_error) + if not raise_on_error: + # merge batch and decrypt errors if instructed + ret["_errors"] = [*batch_err, *decrypt_errors] # Merge both batched parameters and those that required decryption - return {**self._get_parameters_from_batch(batch=batch_params, raise_on_error=raise_on_error), **ret} + return {**batch_ret, **ret} - def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + def _get_parameters_from_batch( + self, batch: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: ret: Dict[str, Any] = {} + errors: List[str] = [] - # Check if it's in cache first to prevent unnecessary calls - # also confirm whether the incoming batch matches our cached + # Check if it's in cache to prevent unnecessary calls for name, options in batch.items(): cache_key = (name, options["transform"]) if self._has_not_expired(cache_key): @@ -272,18 +274,24 @@ def _get_parameters_from_batch(self, batch: Dict[str, Dict], raise_on_error: boo # Return early if all parameters were in cache OR batch was empty if len(ret) == len(batch): - return ret + return ret, errors # Take out the differences to prevent over-fetching - # since there could be parameters with distinct max_age override + # there could be parameters with cache expired batch_diff = {key: value for key, value in batch.items() if key not in ret} for chunk in slice_dictionary(data=batch_diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): - ret.update(**self._get_parameters_by_name(parameters=chunk, raise_on_error=raise_on_error)) + response, possible_errors = self._get_parameters_by_name(parameters=chunk, raise_on_error=raise_on_error) + ret.update(response) - return ret + if not raise_on_error: + errors = possible_errors - def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + return ret, errors + + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: """Use SSM GetParameters to fetch parameters, hydrate cache, and handle partial failure Parameters @@ -304,9 +312,15 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: b When one or more parameters failed on fetching, and raise_on_error is enabled """ ret = {} + batch_errors: List[str] = [] + response = self.client.get_parameters(Names=list(parameters.keys())) - if response["InvalidParameters"] and raise_on_error: - raise GetParameterError(f"Failed to fetch parameters: {response['InvalidParameters']}") + failed_parameters = response["InvalidParameters"] + if failed_parameters: + if raise_on_error: + raise GetParameterError(f"Failed to fetch parameters: {failed_parameters}") + else: + batch_errors = failed_parameters # Built up cache_key, hydrate cache, and return `{name:value}` for parameter in response["Parameters"]: @@ -326,7 +340,7 @@ def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: b ret[name] = value - return ret + return ret, batch_errors @staticmethod def _split_batch_and_decrypt_parameters( @@ -584,6 +598,8 @@ def get_parameters_by_name( If the parameter values should be decrypted max_age: int Maximum age of the cached value + raise_on_error: bool, optional + Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True Raises ------ diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 5acb5768186..1d58375f0a9 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -6,7 +6,7 @@ import string from datetime import datetime, timedelta from io import BytesIO -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import boto3 import pytest @@ -648,14 +648,14 @@ def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_va # GIVEN two parameters are requested provider = parameters.SSMProvider(config=config) dev_param = f"/dev/{mock_name}" - prod_param = f"/prod/{mock_name}" + fail_param = f"/prod/{mock_name}" - params = {dev_param: {}, prod_param: {}} + params = {dev_param: {}, fail_param: {}} param_names = list(params.keys()) stub_params = {dev_param: mock_value} stubber = stub.Stubber(provider.client) - response = build_get_parameters_stub(params=stub_params, invalid_parameters=[prod_param]) + response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail_param]) expected_params = {"Names": param_names} stubber.add_response("get_parameters", response, expected_params) @@ -663,7 +663,7 @@ def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_va # WHEN one of them fails to be retrieved # THEN raise GetParameterError - with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{prod_param}.*"): + with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{fail_param}.*"): try: provider.get_parameters_by_name(parameters=params) stubber.assert_no_pending_responses() @@ -671,6 +671,99 @@ def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_va stubber.deactivate() +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure(mock_name, mock_value, mock_version, config): + # GIVEN two parameters are requested + success = f"/dev/{mock_name}" + fail = f"/prod/{mock_name}" + params = {success: {}, fail: {}} + param_names = list(params.keys()) + stub_params = {success: mock_value} + + expected_stub_response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail]) + expected_stub_params = {"Names": param_names} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN one of them fails to be retrieved + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + # THEN there should be no error raised + # and failed ones available within "_errors" key + stubber.assert_no_pending_responses() + assert ret["_errors"] + assert len(ret["_errors"]) == 1 + assert fail not in ret + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_with_decrypt(mock_name, mock_version, config): + # GIVEN one parameter requires decryption and an arbitrary SDK error occurs + param = f"/{mock_name}" + params = {param: {"decrypt": True}} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_client_error("get_parameter") + stubber.activate() + + # WHEN fail-fast is disabled in get_parameters_by_name + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + stubber.assert_no_pending_responses() + + # THEN there should be no error raised but added under `_errors` key + assert ret["_errors"] + assert len(ret["_errors"]) == 1 + assert param not in ret + finally: + stubber.deactivate() + + +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_batch_decrypt_combined( + mock_value, mock_version, config +): + # GIVEN three parameters are requested + # one requires decryption, two can be batched + fail = "/fail" + success = "/success" + decrypt_fail = "/fail/decrypt" + params = {decrypt_fail: {"decrypt": True}, success: {}, fail: {}} + + expected_stub_params = {"Names": [success, fail]} + expected_stub_response = build_get_parameters_stub( + params={fail: mock_value, success: mock_value}, invalid_parameters=[fail] + ) + + # GIVEN an arbitrary SDK error is injected + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_client_error("get_parameter") + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN fail-fast is disabled in get_parameters_by_name + # and only one parameter succeeds out of three + try: + ret = provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + # THEN there should be no error raised + # successful params returned accordingly + # and failed ones available within "_errors" key + stubber.assert_no_pending_responses() + assert success in ret + assert ret["_errors"] + assert len(ret["_errors"]) == 2 + assert fail not in ret + assert decrypt_fail not in ret + finally: + stubber.deactivate() + + def test_dynamodb_provider_clear_cache(mock_name, mock_value, config): # GIVEN a provider is initialized with a cached value provider = parameters.DynamoDBProvider(table_name="test", config=config) @@ -1616,8 +1709,11 @@ def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: assert decrypt return decrypted_response - def _get_parameters_by_name(self, *args, **kwargs) -> Dict[str, Any]: - return {mock_name: mock_value} + # def _get_parameters_by_name(self, *args, **kwargs) -> Dict[str, Any]: + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: + return {mock_name: mock_value}, [] monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) @@ -1667,12 +1763,15 @@ def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, m class TestProvider(SSMProvider): # NOTE: By convention, we check at `_get_parameters_by_name` # as that's right before we call SSM, and when options have been merged - def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + # def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: # THEN max_age should use no_cache_param override assert parameters[mock_name]["max_age"] == 0 assert parameters["no-override"]["max_age"] == default_cache_period - return {mock_name: mock_value} + return {mock_name: mock_value}, [] monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) @@ -1685,10 +1784,12 @@ def test_get_parameters_by_name_with_max_batch(monkeypatch, mock_value): params = {f"param_{i}": {} for i in range(20)} class TestProvider(SSMProvider): - def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: # THEN we should always split to respect GetParameters max assert len(parameters) == self._MAX_GET_PARAMETERS_ITEM - return {} + return {}, [] monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) From 8e3a7a9bb2f6490972947c7097d0ae82546187cc Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 09:12:00 +0100 Subject: [PATCH 13/28] feat(parameters): expose has_not_expired_in_cache method to ease tests --- .../utilities/parameters/base.py | 6 +-- .../utilities/parameters/ssm.py | 2 +- tests/functional/test_utilities_parameters.py | 39 ++++++++++--------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 309117f5212..c6a483baf42 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -71,7 +71,7 @@ def __init__(self): self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {} - def _has_not_expired(self, key: Tuple[str, TransformOptions]) -> bool: + def has_not_expired_in_cache(self, key: Tuple[str, TransformOptions]) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( @@ -121,7 +121,7 @@ def get( value: Optional[Union[str, bytes, dict]] = None key = (name, transform) - if not force_fetch and self._has_not_expired(key): + if not force_fetch and self.has_not_expired_in_cache(key): return self.store[key].value try: @@ -186,7 +186,7 @@ def get_multiple( """ key = (path, transform) - if not force_fetch and self._has_not_expired(key): + if not force_fetch and self.has_not_expired_in_cache(key): return self.store[key].value # type: ignore # need to revisit entire typing here try: diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 789f59fcd3e..9561f3f62ca 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -269,7 +269,7 @@ def _get_parameters_from_batch( # Check if it's in cache to prevent unnecessary calls for name, options in batch.items(): cache_key = (name, options["transform"]) - if self._has_not_expired(cache_key): + if self.has_not_expired_in_cache(cache_key): ret[name] = self.store[cache_key].value # Return early if all parameters were in cache OR batch was empty diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 1d58375f0a9..f0a3756babf 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -647,23 +647,23 @@ def test_ssm_provider_clear_cache(mock_name, mock_value, config): def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_value, mock_version, config): # GIVEN two parameters are requested provider = parameters.SSMProvider(config=config) - dev_param = f"/dev/{mock_name}" - fail_param = f"/prod/{mock_name}" + success = f"/dev/{mock_name}" + fail = f"/prod/{mock_name}" - params = {dev_param: {}, fail_param: {}} + params = {success: {}, fail: {}} param_names = list(params.keys()) - stub_params = {dev_param: mock_value} + stub_params = {success: mock_value} - stubber = stub.Stubber(provider.client) - response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail_param]) + expected_stub_response = build_get_parameters_stub(params=stub_params, invalid_parameters=[fail]) + expected_stub_params = {"Names": param_names} - expected_params = {"Names": param_names} - stubber.add_response("get_parameters", response, expected_params) + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) stubber.activate() # WHEN one of them fails to be retrieved # THEN raise GetParameterError - with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{fail_param}.*"): + with pytest.raises(parameters.exceptions.GetParameterError, match=f"Failed to fetch parameters: .*{fail}.*"): try: provider.get_parameters_by_name(parameters=params) stubber.assert_no_pending_responses() @@ -729,6 +729,7 @@ def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_batch_decry ): # GIVEN three parameters are requested # one requires decryption, two can be batched + # and an arbitrary SDK error is injected fail = "/fail" success = "/success" decrypt_fail = "/fail/decrypt" @@ -739,7 +740,6 @@ def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_batch_decry params={fail: mock_value, success: mock_value}, invalid_parameters=[fail] ) - # GIVEN an arbitrary SDK error is injected provider = parameters.SSMProvider(config=config) stubber = stub.Stubber(provider.client) stubber.add_client_error("get_parameter") @@ -1709,7 +1709,6 @@ def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: assert decrypt return decrypted_response - # def _get_parameters_by_name(self, *args, **kwargs) -> Dict[str, Any]: def _get_parameters_by_name( self, parameters: Dict[str, Dict], raise_on_error: bool = True ) -> Tuple[Dict[str, Any], List[str]]: @@ -1741,7 +1740,9 @@ def get_parameters(self, *args, **kwargs): return stub_response class TestProvider(SSMProvider): - def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: + def _get_parameters_by_name( + self, parameters: Dict[str, Dict], raise_on_error: bool = True + ) -> Tuple[Dict[str, Any], List[str]]: # THEN max_age should use no_cache_param override assert parameters[no_cache_param]["max_age"] == 0 return super()._get_parameters_by_name(parameters, raise_on_error) @@ -1800,12 +1801,12 @@ def _get_parameters_by_name( def test_get_parameters_by_name_cache(monkeypatch, mock_name, mock_value): # GIVEN we have a parameter to fetch but is already in cache params = {mock_name: {}} + cache_key = (mock_name, None) class TestProvider(SSMProvider): def _get_parameters_by_name( self, parameters: Dict[str, Dict], raise_on_error: bool = True, **kwargs ) -> Dict[str, Any]: - # THEN this should never be called raise RuntimeError("Should not be called if it's in cache") provider = TestProvider() @@ -1816,6 +1817,9 @@ def _get_parameters_by_name( # WHEN get_parameters_by_name is called provider.get_parameters_by_name(parameters=params) + # THEN the cache should be used and _get_parameters_by_name should not be called + assert provider.has_not_expired_in_cache(key=cache_key) + def test_get_parameters_by_name_empty_batch(monkeypatch, mock_name, mock_value): # GIVEN we have an empty dictionary @@ -1827,6 +1831,7 @@ class TestProvider(SSMProvider): monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) # WHEN get_parameters_by_name is called + # THEN it should return an empty response assert parameters.get_parameters_by_name(parameters=params) == {} @@ -1843,15 +1848,13 @@ class FakeClient: def get_parameters(self, *args, **kwargs): return stub_response - class TestProvider(SSMProvider): - def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: - return super().get_parameters_by_name(*args, **kwargs) - - provider = TestProvider(boto3_client=FakeClient()) + provider = SSMProvider(boto3_client=FakeClient()) monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", provider) # WHEN get_parameters_by_name is called provider.get_parameters_by_name(parameters=params) + + # THEN the cache should be populated with each parameter assert len(provider.store) == len(params) From 4d63a22caef3d973890b77f25e21ce293fa75dd1 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 09:20:47 +0100 Subject: [PATCH 14/28] chore(tests): ensure null or negative max_age params aren't cached --- .../utilities/parameters/base.py | 3 +++ tests/functional/test_utilities_parameters.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index c6a483baf42..8587d3b5f3f 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -213,6 +213,9 @@ def clear_cache(self): self.store.clear() def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): + if max_age <= 0: + return + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) @staticmethod diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index f0a3756babf..f194d3c5c96 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -2328,3 +2328,18 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert isinstance(value, str) assert value == mock_value + + +def test_cache_ignores_max_age_zero_or_negative(mock_value): + # GIVEN we have two parameters that shouldn't be cached + param = "/no_cache" + provider = SSMProvider() + cache_key = (param, None) + + # WHEN a provider adds them into the cache + provider.add_to_cache(key=cache_key, value=mock_value, max_age=0) + provider.add_to_cache(key=cache_key, value=mock_value, max_age=-10) + + # THEN they should not be added to the cache + assert len(provider.store) == 0 + assert provider.has_not_expired_in_cache(cache_key) is False From 6613a23ab52dc831a994ae651c64c85eac30ac95 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 11:19:15 +0100 Subject: [PATCH 15/28] refactor: break logic in multiple methods to ease maintenance --- .../utilities/parameters/ssm.py | 80 ++++++++++++------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 9561f3f62ca..327f70c4aca 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -233,61 +233,79 @@ def get_parameters_by_name( When the parameter provider fails to retrieve a parameter value for a given name. """ - - ret: Dict[str, Any] = {} - decrypt_errors: List[str] = [] + response: Dict[str, Any] = {} batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) + decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) + batch_ret, batch_err = self._get_parameters_by_name_batch(batch=batch_params, raise_on_error=raise_on_error) + + response.update(**batch_ret, **decrypt_ret) + if not raise_on_error: + response["_errors"] = [*decrypt_err, *batch_err] + + return response + + def _get_parameters_by_name_with_decrypt_option( + self, batch: Dict[str, Dict], raise_on_error: bool + ) -> Tuple[Dict, List]: + response: Dict[str, Any] = {} + errors: List[str] = [] # Decided for single-thread as it outperforms in 128M and 1G + reduce timeout risk # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 - for parameter, options in decrypt_params.items(): + for parameter, options in batch.items(): try: - ret[parameter] = self.get( + response[parameter] = self.get( parameter, max_age=options["max_age"], transform=options["transform"], decrypt=options["decrypt"] ) except GetParameterError: if raise_on_error: raise - decrypt_errors.append(parameter) + errors.append(parameter) continue - batch_ret, batch_err = self._get_parameters_from_batch(batch=batch_params, raise_on_error=raise_on_error) - if not raise_on_error: - # merge batch and decrypt errors if instructed - ret["_errors"] = [*batch_err, *decrypt_errors] - - # Merge both batched parameters and those that required decryption - return {**batch_ret, **ret} + return response, errors - def _get_parameters_from_batch( - self, batch: Dict[str, Dict], raise_on_error: bool = True - ) -> Tuple[Dict[str, Any], List[str]]: - ret: Dict[str, Any] = {} + def _get_parameters_by_name_batch(self, batch: Dict[str, Dict], raise_on_error: bool = True) -> Tuple[Dict, List]: errors: List[str] = [] - # Check if it's in cache to prevent unnecessary calls + # Fetch each possible batch param from cache and return if entire batch is cached + cached_params = self._get_parameters_by_name_from_cache(batch) + if len(cached_params) == len(batch): + return cached_params, errors + + # Slice batch by max permitted GetParameters call + batch_ret, errors = self._get_parameters_by_name_in_chunks( + batch=batch, cache=cached_params, raise_on_error=raise_on_error + ) + + return {**cached_params, **batch_ret}, errors + + def _get_parameters_by_name_from_cache(self, batch: Dict[str, Dict]) -> Dict[str, Any]: + """Fetch each parameter from batch that hasn't been expired""" + cache = {} for name, options in batch.items(): cache_key = (name, options["transform"]) if self.has_not_expired_in_cache(cache_key): - ret[name] = self.store[cache_key].value + cache[name] = self.store[cache_key].value + + return cache - # Return early if all parameters were in cache OR batch was empty - if len(ret) == len(batch): - return ret, errors + def _get_parameters_by_name_in_chunks( + self, batch: Dict[str, Dict], cache: Dict[str, Any], raise_on_error: bool + ) -> Tuple[Dict, List]: + """Take out differences from cache and batch, slice it and fetch from SSM""" + response: Dict[str, Any] = {} + errors: List[str] = [] - # Take out the differences to prevent over-fetching - # there could be parameters with cache expired - batch_diff = {key: value for key, value in batch.items() if key not in ret} + diff = {key: value for key, value in batch.items() if key not in cache} - for chunk in slice_dictionary(data=batch_diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): + for chunk in slice_dictionary(data=diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): response, possible_errors = self._get_parameters_by_name(parameters=chunk, raise_on_error=raise_on_error) - ret.update(response) - - if not raise_on_error: - errors = possible_errors + response.update(response) + errors.extend(possible_errors) - return ret, errors + return response, errors def _get_parameters_by_name( self, parameters: Dict[str, Dict], raise_on_error: bool = True From f213dea44cbf5c44053d648c4173c4bf79e441c5 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 11:37:00 +0100 Subject: [PATCH 16/28] chore: add docstring with example --- .../utilities/parameters/ssm.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 327f70c4aca..6d431656fc5 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -1,6 +1,8 @@ """ AWS SSM Parameter retrieval and caching utility """ +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, overload import boto3 @@ -210,10 +212,12 @@ def get_parameters_by_name( decrypt: bool = False, max_age: int = DEFAULT_MAX_AGE_SECS, raise_on_error: bool = True, - ) -> Union[Dict[str, str], Dict[str, bytes], Dict[str, dict]]: + ) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: """ Retrieve multiple parameter values by name from SSM or cache. + It uses GetParameter if a param requires decryption, otherwise GetParameters. + Parameters ---------- parameters: List[Dict[str, Dict]] @@ -237,7 +241,7 @@ def get_parameters_by_name( batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) - batch_ret, batch_err = self._get_parameters_by_name_batch(batch=batch_params, raise_on_error=raise_on_error) + batch_ret, batch_err = self._get_parameters_by_name_batch(batch_params, raise_on_error) response.update(**batch_ret, **decrypt_ret) if not raise_on_error: @@ -619,13 +623,36 @@ def get_parameters_by_name( raise_on_error: bool, optional Whether to fail-fast or fail gracefully by including "_errors" key in the response, by default True + Example + ------- + + **Retrieves multiple parameters from distinct paths from Systems Manager Parameter Store** + + from aws_lambda_powertools.utilities.parameters import get_parameters_by_name + + params = { + "/param": {}, + "/json": {"transform": "json"}, + "/binary": {"transform": "binary"}, + "/no_cache": {"max_age": 0}, + "/api_key": {"decrypt": True}, + } + + values = get_parameters_by_name(parameters=params) + for param_name, value in values.items(): + print(f"{param_name}: {value}") + + # "/param": value + # "/json": value + # "/binary": value + # "/no_cache": value + # "/api_key": value + Raises ------ GetParameterError When the parameter provider fails to retrieve a parameter value for a given name. - TransformParameterError - When the parameter provider fails to transform a parameter value. """ # NOTE: Decided against using multi-thread due to single-thread outperforming in 128M and 1G + timeout risk From 859171b86adb93517b2410a37882cea731ed1f80 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 12:11:40 +0100 Subject: [PATCH 17/28] docs(parameters): document graceful error handling for get_parameters_by_name --- docs/utilities/parameters.md | 47 ++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/docs/utilities/parameters.md b/docs/utilities/parameters.md index 9bc0c71e2c8..48aec3697f4 100644 --- a/docs/utilities/parameters.md +++ b/docs/utilities/parameters.md @@ -49,8 +49,8 @@ def handler(event, context): For multiple parameters, you can use either: -* `get_parameters` to recursively fetch all parameters by path -* `get_parameters_by_name` to fetch distinct parameters by their full name. It also accepts custom caching, transform, and decrypt per parameter. +* `get_parameters` to recursively fetch all parameters by path. +* `get_parameters_by_name` to fetch distinct parameters by their full name. It also accepts custom caching, transform, decrypt per parameter. === "get_parameters" @@ -67,24 +67,57 @@ For multiple parameters, you can use either: === "get_parameters_by_name" - ```python hl_lines="1 3 13" + ```python hl_lines="3 5 14" + from typing import Any + from aws_lambda_powertools.utilities import get_parameters_by_name parameters = { "/develop/service/commons/telemetry/config": {"max_age": 300, "transform": "json"}, - "/develop/service/amplify/auth/userpool/arn": {"max_age": 300}, + "/no_cache_param": {"max_age": 0}, # inherit default values "/develop/service/payment/api/capture/url": {}, - "/develop/service/payment/api/charge/url": {}, } def handler(event, context): # This returns a dict with the parameter name as key - values = parameters.get_parameters_by_name(parameters=parameters, max_age=60) - for parameter, value in values.items(): + response: dict[str, Any] = parameters.get_parameters_by_name(parameters=parameters, max_age=60) + for parameter, value in response.items(): print(f"{parameter}: {value}") ``` +???+ tip "`get_parameters_by_name` supports graceful error handling" + By default, we will raise `GetParameterError` when any parameter fails to be fetched. You can override it by setting `raise_on_error=False`. + + When disabled, we take the following actions: + + * Add failed parameter name in the `_errors` key, _e.g._, `{_errors: ["/param1", "/param2"]}` + * Keep only successful parameter names and their values in the response + * Raise `GetParameterError` if any of your parameters is named `_errors` + +```python hl_lines="3 5 12-13 15" title="Graceful error handling" +from typing import Any + +from aws_lambda_powertools.utilities import get_parameters_by_name + +parameters = { + "/develop/service/commons/telemetry/config": {"max_age": 300, "transform": "json"}, + # it would fail by default + "/this/param/does/not/exist" +} + +def handler(event, context): + values: dict[str, Any] = parameters.get_parameters_by_name(parameters=parameters, raise_on_error=False) + errors: list[str] = values.get("_errors", []) + + # Handle gracefully, since '/this/param/does/not/exist' will only be available in `_errors` + if errors: + ... + + for parameter, value in values.items(): + print(f"{parameter}: {value}") +``` + ### Fetching secrets You can fetch secrets stored in Secrets Manager using `get_secrets`. From e0a63ae83d668e3fd1b7ef9db32c7361cf0ddabe Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 12:21:26 +0100 Subject: [PATCH 18/28] feat(parameters): add guardrail for param also named _errors in graceful mode --- .../utilities/parameters/ssm.py | 4 ++++ tests/functional/test_utilities_parameters.py | 20 ++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 6d431656fc5..25f1ecd32a1 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -239,6 +239,10 @@ def get_parameters_by_name( """ response: Dict[str, Any] = {} + # NOTE: We fail early to avoid unintended graceful errors being replaced with their param values + if "_errors" in parameters and not raise_on_error: + raise GetParameterError("You cannot fetch a parameter named '_errors' in graceful error mode.") + batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) batch_ret, batch_err = self._get_parameters_by_name_batch(batch_params, raise_on_error) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index f194d3c5c96..c95a1798856 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -644,7 +644,7 @@ def test_ssm_provider_clear_cache(mock_name, mock_value, config): assert provider.store == {} -def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_value, mock_version, config): +def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_value, config): # GIVEN two parameters are requested provider = parameters.SSMProvider(config=config) success = f"/dev/{mock_name}" @@ -671,7 +671,7 @@ def test_ssm_provider_get_parameters_by_name_raise_on_failure(mock_name, mock_va stubber.deactivate() -def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure(mock_name, mock_value, mock_version, config): +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure(mock_name, mock_value, config): # GIVEN two parameters are requested success = f"/dev/{mock_name}" fail = f"/prod/{mock_name}" @@ -701,7 +701,7 @@ def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure(mock_name, stubber.deactivate() -def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_with_decrypt(mock_name, mock_version, config): +def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_with_decrypt(mock_name, config): # GIVEN one parameter requires decryption and an arbitrary SDK error occurs param = f"/{mock_name}" params = {param: {"decrypt": True}} @@ -764,6 +764,20 @@ def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_batch_decry stubber.deactivate() +def test_ssm_provider_get_parameters_by_name_raise_on_reserved_errors_key(mock_name, mock_value, config): + # GIVEN one of the parameters is named `_errors` + success = f"/dev/{mock_name}" + fail = "_errors" + + params = {success: {}, fail: {}} + provider = parameters.SSMProvider(config=config) + + # WHEN using get_parameters_by_name to fetch + # THEN raise GetParameterError + with pytest.raises(parameters.exceptions.GetParameterError, match="You cannot fetch a parameter named"): + provider.get_parameters_by_name(parameters=params, raise_on_error=False) + + def test_dynamodb_provider_clear_cache(mock_name, mock_value, config): # GIVEN a provider is initialized with a cached value provider = parameters.DynamoDBProvider(table_name="test", config=config) From 386e813b59a78a4c39415dbb2043a4b854be35da Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 13:35:46 +0100 Subject: [PATCH 19/28] docs: add IAM permission for get_parameters_by_name --- docs/utilities/parameters.md | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/utilities/parameters.md b/docs/utilities/parameters.md index 48aec3697f4..9441d94fe12 100644 --- a/docs/utilities/parameters.md +++ b/docs/utilities/parameters.md @@ -24,15 +24,16 @@ This utility requires additional permissions to work as expected. ???+ note Different parameter providers require different permissions. -| Provider | Function/Method | IAM Permission | -| ------------------- | ---------------------------------------------------- | ---------------------------------------------------------------------------- | -| SSM Parameter Store | `get_parameter`, `SSMProvider.get` | `ssm:GetParameter` | -| SSM Parameter Store | `get_parameters`, `SSMProvider.get_multiple` | `ssm:GetParametersByPath` | -| SSM Parameter Store | If using `decrypt=True` | You must add an additional permission `kms:Decrypt` | -| Secrets Manager | `get_secret`, `SecretsManager.get` | `secretsmanager:GetSecretValue` | -| DynamoDB | `DynamoDBProvider.get` | `dynamodb:GetItem` | -| DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb:Query` | -| App Config | `get_app_config`, `AppConfigProvider.get_app_config` | `appconfig:GetLatestConfiguration` and `appconfig:StartConfigurationSession` | +| Provider | Function/Method | IAM Permission | +| --------- | ---------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | +| SSM | **`get_parameter`**, **`SSMProvider.get`** | **`ssm:GetParameter`** | +| SSM | **`get_parameters`**, **`SSMProvider.get_multiple`** | **`ssm:GetParametersByPath`** | +| SSM | **`get_parameters_by_name`**, **`SSMProvider.get_parameters_by_name`** | **`ssm:GetParameter`** and **`ssm:GetParameters`** | +| SSM | If using **`decrypt=True`** | You must add an additional permission **`kms:Decrypt`** | +| Secrets | **`get_secret`**, **`SecretsManager.get`** | **`secretsmanager:GetSecretValue`** | +| DynamoDB | **`DynamoDBProvider.get`** | **`dynamodb:GetItem`** | +| DynamoDB | **`DynamoDBProvider.get_multiple`** | **`dynamodb:Query`** | +| AppConfig | **`get_app_config`**, **`AppConfigProvider.get_app_config`** | **`appconfig:GetLatestConfiguration`** and **`appconfig:StartConfigurationSession`** | ### Fetching parameters From 27d05c4161c8466dabe38ec57563d6abab9ada0a Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 15:33:10 +0100 Subject: [PATCH 20/28] feat: use GetParameters if entire batch needs decryption --- .../utilities/parameters/ssm.py | 117 +++++++++++++----- tests/functional/test_utilities_parameters.py | 41 +++++- 2 files changed, 120 insertions(+), 38 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 25f1ecd32a1..fee90093ac9 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from mypy_boto3_ssm import SSMClient + from mypy_boto3_ssm.type_defs import GetParametersResultTypeDef class SSMProvider(BaseProvider): @@ -86,6 +87,7 @@ class SSMProvider(BaseProvider): client: Any = None _MAX_GET_PARAMETERS_ITEM = 10 + _ERRORS_KEY = "_errors" def __init__( self, @@ -205,6 +207,7 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals return parameters + # NOTE: When bandwidth permits, allocate a week to refactor to lower cognitive load def get_parameters_by_name( self, parameters: Dict[str, Dict], @@ -234,24 +237,39 @@ def get_parameters_by_name( Raises ------ GetParameterError - When the parameter provider fails to retrieve a parameter value for - a given name. + When the parameter provider fails to retrieve a parameter value for a given name. + + When "_errors" reserved key is in parameters to be fetched from SSM. """ + # Init potential batch/decrypt batch responses and errors + batch_ret: Dict[str, Any] = {} + decrypt_ret: Dict[str, Any] = {} + batch_err: List[str] = [] + decrypt_err: List[str] = [] response: Dict[str, Any] = {} - # NOTE: We fail early to avoid unintended graceful errors being replaced with their param values - if "_errors" in parameters and not raise_on_error: - raise GetParameterError("You cannot fetch a parameter named '_errors' in graceful error mode.") + # NOTE: We fail early to avoid unintended graceful errors being replaced with their '_errors' param values + self._raise_if_errors_key_is_present(parameters, self._ERRORS_KEY, raise_on_error) batch_params, decrypt_params = self._split_batch_and_decrypt_parameters(parameters, transform, max_age, decrypt) - decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) - batch_ret, batch_err = self._get_parameters_by_name_batch(batch_params, raise_on_error) - response.update(**batch_ret, **decrypt_ret) + # NOTE: We need to find out whether all parameters must be decrypted or not to know which API to use + ## Logic: + ## + ## GetParameters API -> When decrypt is used for all parameters in the the batch + ## GetParameter API -> When decrypt is used for one or more in the batch + + if len(decrypt_params) != len(parameters): + decrypt_ret, decrypt_err = self._get_parameters_by_name_with_decrypt_option(decrypt_params, raise_on_error) + batch_ret, batch_err = self._get_parameters_batch_by_name(batch_params, raise_on_error, decrypt=False) + else: + batch_ret, batch_err = self._get_parameters_batch_by_name(decrypt_params, raise_on_error, decrypt=True) + + # Fail-fast disabled, let's aggregate errors under "_errors" key so they can handle gracefully if not raise_on_error: - response["_errors"] = [*decrypt_err, *batch_err] + response[self._ERRORS_KEY] = [*decrypt_err, *batch_err] - return response + return {**response, **batch_ret, **decrypt_ret} def _get_parameters_by_name_with_decrypt_option( self, batch: Dict[str, Dict], raise_on_error: bool @@ -263,9 +281,7 @@ def _get_parameters_by_name_with_decrypt_option( # see: https://github.com/awslabs/aws-lambda-powertools-python/issues/1040#issuecomment-1299954613 for parameter, options in batch.items(): try: - response[parameter] = self.get( - parameter, max_age=options["max_age"], transform=options["transform"], decrypt=options["decrypt"] - ) + response[parameter] = self.get(parameter, options["max_age"], options["transform"], options["decrypt"]) except GetParameterError: if raise_on_error: raise @@ -274,7 +290,10 @@ def _get_parameters_by_name_with_decrypt_option( return response, errors - def _get_parameters_by_name_batch(self, batch: Dict[str, Dict], raise_on_error: bool = True) -> Tuple[Dict, List]: + def _get_parameters_batch_by_name( + self, batch: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False + ) -> Tuple[Dict, List]: + """Slice batch and fetch parameters using GetParameters by max permitted""" errors: List[str] = [] # Fetch each possible batch param from cache and return if entire batch is cached @@ -283,9 +302,7 @@ def _get_parameters_by_name_batch(self, batch: Dict[str, Dict], raise_on_error: return cached_params, errors # Slice batch by max permitted GetParameters call - batch_ret, errors = self._get_parameters_by_name_in_chunks( - batch=batch, cache=cached_params, raise_on_error=raise_on_error - ) + batch_ret, errors = self._get_parameters_by_name_in_chunks(batch, cached_params, raise_on_error, decrypt) return {**cached_params, **batch_ret}, errors @@ -300,7 +317,7 @@ def _get_parameters_by_name_from_cache(self, batch: Dict[str, Dict]) -> Dict[str return cache def _get_parameters_by_name_in_chunks( - self, batch: Dict[str, Dict], cache: Dict[str, Any], raise_on_error: bool + self, batch: Dict[str, Dict], cache: Dict[str, Any], raise_on_error: bool, decrypt: bool = False ) -> Tuple[Dict, List]: """Take out differences from cache and batch, slice it and fetch from SSM""" response: Dict[str, Any] = {} @@ -309,14 +326,16 @@ def _get_parameters_by_name_in_chunks( diff = {key: value for key, value in batch.items() if key not in cache} for chunk in slice_dictionary(data=diff, chunk_size=self._MAX_GET_PARAMETERS_ITEM): - response, possible_errors = self._get_parameters_by_name(parameters=chunk, raise_on_error=raise_on_error) + response, possible_errors = self._get_parameters_by_name( + parameters=chunk, raise_on_error=raise_on_error, decrypt=decrypt + ) response.update(response) errors.extend(possible_errors) return response, errors def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: """Use SSM GetParameters to fetch parameters, hydrate cache, and handle partial failure @@ -337,19 +356,31 @@ def _get_parameters_by_name( GetParameterError When one or more parameters failed on fetching, and raise_on_error is enabled """ - ret = {} + ret: Dict[str, Any] = {} batch_errors: List[str] = [] - - response = self.client.get_parameters(Names=list(parameters.keys())) - failed_parameters = response["InvalidParameters"] - if failed_parameters: - if raise_on_error: - raise GetParameterError(f"Failed to fetch parameters: {failed_parameters}") + parameter_names = list(parameters.keys()) + + # All params in the batch must be decrypted + # we return early if we hit an unrecoverable exception like InvalidKeyId/InternalServerError + # everything else should technically be recoverable as GetParameters is non-atomic + try: + if decrypt: + response = self.client.get_parameters(Names=parameter_names, WithDecryption=True) else: - batch_errors = failed_parameters + response = self.client.get_parameters(Names=parameter_names) + except (self.client.exceptions.InvalidKeyId, self.client.exceptions.InternalServerError): + return ret, parameter_names - # Built up cache_key, hydrate cache, and return `{name:value}` - for parameter in response["Parameters"]: + batch_errors = self._handle_any_invalid_get_parameter_errors(response, raise_on_error) + transformed_params = self._transform_and_cache_get_parameters_response(response, parameters, raise_on_error) + + return transformed_params, batch_errors + + def _transform_and_cache_get_parameters_response( + self, api_response: GetParametersResultTypeDef, parameters: Dict[str, Any], raise_on_error: bool = True + ) -> Dict[str, Any]: + response: Dict[str, Any] = {} + for parameter in api_response["Parameters"]: name = parameter["Name"] value = parameter["Value"] options = parameters[name] @@ -364,9 +395,23 @@ def _get_parameters_by_name( _cache_key = (name, options["transform"]) self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) - ret[name] = value + response[name] = value - return ret, batch_errors + return response + + @staticmethod + def _handle_any_invalid_get_parameter_errors( + api_response: GetParametersResultTypeDef, raise_on_error: bool = True + ) -> List[str]: + """GetParameters is non-atomic. Failures don't always reflect in exceptions so we need to collect.""" + failed_parameters = api_response["InvalidParameters"] + if failed_parameters: + if raise_on_error: + raise GetParameterError(f"Failed to fetch parameters: {failed_parameters}") + + return failed_parameters + + return [] @staticmethod def _split_batch_and_decrypt_parameters( @@ -413,6 +458,14 @@ def _split_batch_and_decrypt_parameters( return batch_parameters, decrypt_parameters + @staticmethod + def _raise_if_errors_key_is_present(parameters: Dict, reserved_parameter: str, raise_on_error: bool): + """Raise GetParameterError if fail-fast is disabled and '_errors' key is in parameters batch""" + if not raise_on_error and reserved_parameter in parameters: + raise GetParameterError( + f"You cannot fetch a parameter named '{reserved_parameter}' in graceful error mode." + ) + def get_parameter( name: str, diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index c95a1798856..9528664d5c2 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -708,7 +708,7 @@ def test_ssm_provider_get_parameters_by_name_do_not_raise_on_failure_with_decryp provider = parameters.SSMProvider(config=config) stubber = stub.Stubber(provider.client) - stubber.add_client_error("get_parameter") + stubber.add_client_error("get_parameters", "InvalidKeyId") stubber.activate() # WHEN fail-fast is disabled in get_parameters_by_name @@ -778,6 +778,35 @@ def test_ssm_provider_get_parameters_by_name_raise_on_reserved_errors_key(mock_n provider.get_parameters_by_name(parameters=params, raise_on_error=False) +def test_ssm_provider_get_parameters_by_name_all_decrypt_should_use_get_parameters_api(mock_name, mock_value, config): + # GIVEN all parameters require decryption + param_a = f"/a/{mock_name}" + param_b = f"/b/{mock_name}" + fail = "/does_not_exist" # stub model doesn't support all-success yet + + all_params = {param_a: {}, param_b: {}, fail: {}} + all_params_names = list(all_params.keys()) + + expected_param_values = {param_a: mock_value, param_b: mock_value} + expected_stub_response = build_get_parameters_stub(params=expected_param_values, invalid_parameters=[fail]) + expected_stub_params = {"Names": all_params_names, "WithDecryption": True} + + provider = parameters.SSMProvider(config=config) + stubber = stub.Stubber(provider.client) + stubber.add_response("get_parameters", expected_stub_response, expected_stub_params) + stubber.activate() + + # WHEN get_parameters_by_name is called + # THEN we should only use GetParameters WithDecryption=true to prevent throttling + try: + ret = provider.get_parameters_by_name(parameters=all_params, decrypt=True, raise_on_error=False) + stubber.assert_no_pending_responses() + + assert ret is not None + finally: + stubber.deactivate() + + def test_dynamodb_provider_clear_cache(mock_name, mock_value, config): # GIVEN a provider is initialized with a cached value provider = parameters.DynamoDBProvider(table_name="test", config=config) @@ -1724,7 +1753,7 @@ def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: return decrypted_response def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: return {mock_name: mock_value}, [] @@ -1755,11 +1784,11 @@ def get_parameters(self, *args, **kwargs): class TestProvider(SSMProvider): def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: # THEN max_age should use no_cache_param override assert parameters[no_cache_param]["max_age"] == 0 - return super()._get_parameters_by_name(parameters, raise_on_error) + return super()._get_parameters_by_name(parameters, raise_on_error, decrypt) provider = TestProvider(boto3_client=FakeClient()) monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider(boto3_client=FakeClient())) @@ -1780,7 +1809,7 @@ class TestProvider(SSMProvider): # as that's right before we call SSM, and when options have been merged # def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: # THEN max_age should use no_cache_param override assert parameters[mock_name]["max_age"] == 0 @@ -1800,7 +1829,7 @@ def test_get_parameters_by_name_with_max_batch(monkeypatch, mock_value): class TestProvider(SSMProvider): def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True + self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: # THEN we should always split to respect GetParameters max assert len(parameters) == self._MAX_GET_PARAMETERS_ITEM From aa5670cf5ad30d352192ebe6a5bba511e561a44f Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:00:34 +0100 Subject: [PATCH 21/28] docs: add ascii diagram to ease understanding of batch split API --- .../utilities/parameters/ssm.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index fee90093ac9..91f64ded739 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -219,7 +219,25 @@ def get_parameters_by_name( """ Retrieve multiple parameter values by name from SSM or cache. - It uses GetParameter if a param requires decryption, otherwise GetParameters. + Raise_on_error decides on error handling strategy: + + - A) Default to fail-fast. Raises GetParameterError upon any error + - B) Gracefully aggregate all parameters that failed under "_errors" key + + It transparently uses GetParameter and/or GetParameters depending on decryption requirements. + + ┌────────────────────────┐ + ┌───▶ Decrypt entire batch │─────┐ + │ └────────────────────────┘ │ ┌────────────────────┐ + │ ├─────▶ GetParameters API │ + ┌──────────────────┐ │ ┌────────────────────────┐ │ └────────────────────┘ + │ Split batch │─── ┼──▶│ No decryption required │─────┘ + └──────────────────┘ │ └────────────────────────┘ + │ ┌────────────────────┐ + │ ┌────────────────────────┐ │ GetParameter API │ + └──▶│Decrypt some but not all│───────────▶────────────────────┤ + └────────────────────────┘ │ GetParameters API │ + └────────────────────┘ Parameters ---------- From fbffe4803f0395bff1f5f81c5a5261792553a3ed Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:13:22 +0100 Subject: [PATCH 22/28] chore: ignore assignment mypy due to dinamism of transform --- aws_lambda_powertools/utilities/parameters/ssm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 91f64ded739..d28dbac5ee2 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -398,17 +398,18 @@ def _transform_and_cache_get_parameters_response( self, api_response: GetParametersResultTypeDef, parameters: Dict[str, Any], raise_on_error: bool = True ) -> Dict[str, Any]: response: Dict[str, Any] = {} + for parameter in api_response["Parameters"]: name = parameter["Name"] value = parameter["Value"] options = parameters[name] - transform = options.get("transform", "") + transform = options.get("transform") # NOTE: If transform is set, we do it before caching to reduce number of operations if transform: value = transform_value( key=name, value=value, transform=transform, raise_on_transform_error=raise_on_error - ) + ) # type: ignore[assignment] # transform dynamism challenge _cache_key = (name, options["transform"]) self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) From d205ca48d1d0b8f51b19b48558135435ad015313 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:23:50 +0100 Subject: [PATCH 23/28] chore(tests): remove redundant override test --- tests/functional/test_utilities_parameters.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 9528664d5c2..aa95934a3b6 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1769,36 +1769,6 @@ def _get_parameters_by_name( assert values[decrypt_param_two] == decrypted_response -def test_get_parameters_by_name_with_overrides(monkeypatch, mock_value): - # GIVEN 1 out of 2 parameters overrides max_age to 0 - no_cache_param = "/no_cache" - json_param = "/json" - params = {no_cache_param: {"max_age": 0}, json_param: {"transform": "json"}} - - stub_params = {no_cache_param: mock_value, json_param: '{"a":"b"}'} - stub_response = build_get_parameters_stub(params=stub_params) - - class FakeClient: - def get_parameters(self, *args, **kwargs): - return stub_response - - class TestProvider(SSMProvider): - def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False - ) -> Tuple[Dict[str, Any], List[str]]: - # THEN max_age should use no_cache_param override - assert parameters[no_cache_param]["max_age"] == 0 - return super()._get_parameters_by_name(parameters, raise_on_error, decrypt) - - provider = TestProvider(boto3_client=FakeClient()) - monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider(boto3_client=FakeClient())) - - # WHEN get_parameters_by_name is called - ret = provider.get_parameters_by_name(parameters=params) - # THEN json_param should be transformed - assert isinstance(ret[json_param], dict) - - def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value): # GIVEN a parameter overrides a default setting default_cache_period = 500 From b9521526355f20dd3bd14016007be05c05fa5a08 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:41:49 +0100 Subject: [PATCH 24/28] fix(tests): boto3 client side effect on super class --- tests/functional/test_utilities_parameters.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index aa95934a3b6..f0f7c36fc94 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -21,11 +21,7 @@ BaseProvider, ExpirableValue, ) -from aws_lambda_powertools.utilities.parameters.ssm import ( - DEFAULT_MAX_AGE_SECS, - SSMProvider, -) -from aws_lambda_powertools.utilities.parameters.types import TransformOptions +from aws_lambda_powertools.utilities.parameters.ssm import SSMProvider @pytest.fixture(scope="function") @@ -1715,18 +1711,14 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value -def test_get_parameters_by_name(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name(monkeypatch, mock_name, mock_value, config): params = {mock_name: {}} class TestProvider(SSMProvider): - def get_parameters_by_name( - self, - parameters: Dict[str, Dict], - transform: TransformOptions = None, - decrypt: bool = False, - max_age: int = DEFAULT_MAX_AGE_SECS, - raise_on_error: bool = True, - ) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: return {mock_name: mock_value} monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) @@ -1737,7 +1729,7 @@ def get_parameters_by_name( assert values[mock_name] == mock_value -def test_get_parameters_by_name_with_decrypt_override(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_with_decrypt_override(monkeypatch, mock_name, mock_value, config): # GIVEN 2 out of 3 parameters have decrypt override decrypt_param = "/api_key" decrypt_param_two = "/another/secret" @@ -1746,15 +1738,16 @@ def test_get_parameters_by_name_with_decrypt_override(monkeypatch, mock_name, mo params = {mock_name: {}, **decrypt_params} class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: # THEN params with `decrypt` override should use GetParameter` (`_get`) assert name in decrypt_params assert decrypt return decrypted_response - def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False - ) -> Tuple[Dict[str, Any], List[str]]: + def _get_parameters_by_name(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]: return {mock_name: mock_value}, [] monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) @@ -1769,12 +1762,15 @@ def _get_parameters_by_name( assert values[decrypt_param_two] == decrypted_response -def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_with_override_and_explicit_global(monkeypatch, mock_name, mock_value, config): # GIVEN a parameter overrides a default setting default_cache_period = 500 params = {mock_name: {"max_age": 0}, "no-override": {}} class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + # NOTE: By convention, we check at `_get_parameters_by_name` # as that's right before we call SSM, and when options have been merged # def _get_parameters_by_name(self, parameters: Dict[str, Dict], raise_on_error: bool = True) -> Dict[str, Any]: @@ -1793,11 +1789,14 @@ def _get_parameters_by_name( parameters.get_parameters_by_name(parameters=params, max_age=default_cache_period) -def test_get_parameters_by_name_with_max_batch(monkeypatch, mock_value): +def test_get_parameters_by_name_with_max_batch(monkeypatch, config): # GIVEN a batch of 20 parameters params = {f"param_{i}": {} for i in range(20)} class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + def _get_parameters_by_name( self, parameters: Dict[str, Dict], raise_on_error: bool = True, decrypt: bool = False ) -> Tuple[Dict[str, Any], List[str]]: @@ -1811,15 +1810,16 @@ def _get_parameters_by_name( parameters.get_parameters_by_name(parameters=params) -def test_get_parameters_by_name_cache(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_cache(monkeypatch, mock_name, mock_value, config): # GIVEN we have a parameter to fetch but is already in cache params = {mock_name: {}} cache_key = (mock_name, None) class TestProvider(SSMProvider): - def _get_parameters_by_name( - self, parameters: Dict[str, Dict], raise_on_error: bool = True, **kwargs - ) -> Dict[str, Any]: + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + + def _get_parameters_by_name(self, *args, **kwargs) -> Tuple[Dict[str, Any], List[str]]: raise RuntimeError("Should not be called if it's in cache") provider = TestProvider() @@ -1834,12 +1834,13 @@ def _get_parameters_by_name( assert provider.has_not_expired_in_cache(key=cache_key) -def test_get_parameters_by_name_empty_batch(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_empty_batch(monkeypatch, config): # GIVEN we have an empty dictionary params = {} class TestProvider(SSMProvider): - ... + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) @@ -1937,13 +1938,16 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value -def test_get_parameters_by_name_new(monkeypatch, mock_name, mock_value): +def test_get_parameters_by_name_new(monkeypatch, mock_name, mock_value, config): """ Test get_parameters_by_name() without a default provider """ params = {mock_name: {}} class TestProvider(SSMProvider): + def __init__(self, config: Config = config, **kwargs): + super().__init__(config, **kwargs) + def get_parameters_by_name(self, *args, **kwargs) -> Dict[str, str] | Dict[str, bytes] | Dict[str, dict]: return {mock_name: mock_value} From 9eaf0f4a6bcbfa644c4920b894323e9395a39d4d Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:43:45 +0100 Subject: [PATCH 25/28] fix(tests): compat remove_prefix for 3.7+ --- tests/functional/test_utilities_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index f0f7c36fc94..27f32dbf585 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -59,7 +59,7 @@ def build_get_parameters_stub(params: Dict[str, Any], invalid_parameters: List[s "Selector": f"{param}:{version}", "SourceResult": "string", "LastModifiedDate": datetime(2015, 1, 1), - "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{param.removeprefix('/')}", + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{param.lstrip('/')}", "DataType": "string", } for param, value in params.items() From 9fc46a7b9ae4b1e12f85c9050d63731cc560c464 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:52:21 +0100 Subject: [PATCH 26/28] chore(mypy): ignore assignment overload value --- aws_lambda_powertools/utilities/parameters/ssm.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index d28dbac5ee2..7eabb02cf35 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -407,9 +407,7 @@ def _transform_and_cache_get_parameters_response( # NOTE: If transform is set, we do it before caching to reduce number of operations if transform: - value = transform_value( - key=name, value=value, transform=transform, raise_on_transform_error=raise_on_error - ) # type: ignore[assignment] # transform dynamism challenge + value = transform_value(name, value, transform, raise_on_error) # type: ignore[assignment] _cache_key = (name, options["transform"]) self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) From 285c4312fc970c08fad7a6ad4d192aa5b06a16dd Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 16:55:27 +0100 Subject: [PATCH 27/28] chore(mypy): ignore call-overload and assignment overload value --- aws_lambda_powertools/utilities/parameters/ssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index 7eabb02cf35..ae4a76dac4a 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -407,7 +407,7 @@ def _transform_and_cache_get_parameters_response( # NOTE: If transform is set, we do it before caching to reduce number of operations if transform: - value = transform_value(name, value, transform, raise_on_error) # type: ignore[assignment] + value = transform_value(name, value, transform, raise_on_error) # type: ignore _cache_key = (name, options["transform"]) self.add_to_cache(key=_cache_key, value=value, max_age=options["max_age"]) From 9e84a07f966dfd0402f9224ed85da484a6a25305 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Mon, 7 Nov 2022 17:06:07 +0100 Subject: [PATCH 28/28] fix(tests): boto3 client side effect on super class --- tests/functional/test_utilities_parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 27f32dbf585..c5e65c158be 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -2347,10 +2347,10 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert value == mock_value -def test_cache_ignores_max_age_zero_or_negative(mock_value): +def test_cache_ignores_max_age_zero_or_negative(mock_value, config): # GIVEN we have two parameters that shouldn't be cached param = "/no_cache" - provider = SSMProvider() + provider = SSMProvider(config=config) cache_key = (param, None) # WHEN a provider adds them into the cache