From f6eb2b148526446bd1a00a31bbabe248e2ccf464 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Sun, 18 Jun 2023 16:17:52 +0100 Subject: [PATCH 1/2] tech-debt: centralizing functions --- .../utilities/data_classes/active_mq_event.py | 9 +-- .../api_gateway_authorizer_event.py | 2 + .../data_classes/appsync_resolver_event.py | 4 +- .../data_classes/aws_config_rule_event.py | 9 +-- .../data_classes/cloud_watch_logs_event.py | 4 +- .../data_classes/code_pipeline_job_event.py | 3 +- .../utilities/data_classes/common.py | 27 ++----- .../utilities/data_classes/kafka_event.py | 16 ++-- .../utilities/data_classes/rabbit_mq_event.py | 7 +- .../utilities/data_classes/s3_object_event.py | 4 +- .../data_classes/shared_functions.py | 81 +++++++++++++++++++ .../utilities/data_classes/vpc_lattice.py | 12 +-- 12 files changed, 117 insertions(+), 61 deletions(-) create mode 100644 aws_lambda_powertools/utilities/data_classes/shared_functions.py diff --git a/aws_lambda_powertools/utilities/data_classes/active_mq_event.py b/aws_lambda_powertools/utilities/data_classes/active_mq_event.py index 94929a79572..94addef92fa 100644 --- a/aws_lambda_powertools/utilities/data_classes/active_mq_event.py +++ b/aws_lambda_powertools/utilities/data_classes/active_mq_event.py @@ -1,8 +1,7 @@ -import base64 -import json from typing import Any, Iterator, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode class ActiveMQMessage(DictWrapper): @@ -22,13 +21,13 @@ def data(self) -> str: @property def decoded_data(self) -> str: """Decodes the data as a str""" - return base64.b64decode(self.data.encode()).decode() + return base64_decode(self.data) @property def json_data(self) -> Any: """Parses the data as json""" if self._json_data is None: - self._json_data = json.loads(self.decoded_data) + self._json_data = self._json_deserializer(self.decoded_data) return self._json_data @property @@ -125,7 +124,7 @@ def event_source_arn(self) -> str: @property def messages(self) -> Iterator[ActiveMQMessage]: for record in self["messages"]: - yield ActiveMQMessage(record) + yield ActiveMQMessage(record, json_deserializer=self._json_deserializer) @property def message(self) -> ActiveMQMessage: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py index 431d678e9b6..a8897cce5b0 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py @@ -6,6 +6,8 @@ BaseRequestContext, BaseRequestContextV2, DictWrapper, +) +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( get_header_value, ) diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py index 30cd497e514..fc54f334cab 100644 --- a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Union -from aws_lambda_powertools.utilities.data_classes.common import ( - DictWrapper, +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( get_header_value, ) diff --git a/aws_lambda_powertools/utilities/data_classes/aws_config_rule_event.py b/aws_lambda_powertools/utilities/data_classes/aws_config_rule_event.py index 2bfa2df61c5..f8d4f991cc0 100644 --- a/aws_lambda_powertools/utilities/data_classes/aws_config_rule_event.py +++ b/aws_lambda_powertools/utilities/data_classes/aws_config_rule_event.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from typing import Any, Dict, List, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper @@ -303,9 +302,9 @@ def invoking_event( ) -> AWSConfigConfigurationChanged | AWSConfigScheduledNotification | AWSConfigOversizedConfiguration: """The invoking payload of the event.""" if self._invoking_event is None: - self._invoking_event = self["invokingEvent"] + self._invoking_event = self._json_deserializer(self["invokingEvent"]) - return get_invoke_event(json.loads(self._invoking_event)) + return get_invoke_event(self._invoking_event) @property def raw_invoking_event(self) -> str: @@ -316,9 +315,9 @@ def raw_invoking_event(self) -> str: def rule_parameters(self) -> Dict: """The parameters of the event.""" if self._rule_parameters is None: - self._rule_parameters = self["ruleParameters"] + self._rule_parameters = self._json_deserializer(self["ruleParameters"]) - return json.loads(self._rule_parameters) + return self._rule_parameters @property def result_token(self) -> str: diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py index 978f6956fc2..b12e941a062 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py @@ -1,5 +1,4 @@ import base64 -import json import zlib from typing import Dict, List, Optional @@ -97,5 +96,6 @@ def decompress_logs_data(self) -> bytes: def parse_logs_data(self) -> CloudWatchLogsDecodedData: """Decode, decompress and parse json data as CloudWatchLogsDecodedData""" if self._json_logs_data is None: - self._json_logs_data = json.loads(self.decompress_logs_data.decode("UTF-8")) + self._json_logs_data = self._json_deserializer(self.decompress_logs_data.decode("UTF-8")) + return CloudWatchLogsDecodedData(self._json_logs_data) diff --git a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py index 96c209e0eca..1813d6016b5 100644 --- a/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py +++ b/aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py @@ -1,4 +1,3 @@ -import json import tempfile import zipfile from typing import Any, Dict, List, Optional @@ -22,7 +21,7 @@ def user_parameters(self) -> Optional[str]: def decoded_user_parameters(self) -> Optional[Dict[str, Any]]: """Json Decoded user parameters""" if self._json_data is None and self.user_parameters is not None: - self._json_data = json.loads(self.user_parameters) + self._json_data = self._json_deserializer(self.user_parameters) return self._json_data diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 8e2cea7a5cf..1b64f21b56d 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -4,6 +4,10 @@ from typing import Any, Callable, Dict, Iterator, List, Optional from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( + get_header_value, + get_query_string_value, +) class DictWrapper(Mapping): @@ -90,26 +94,6 @@ def raw_event(self) -> Dict[str, Any]: return self._data -def get_header_value( - headers: Dict[str, str], name: str, default_value: Optional[str], case_sensitive: Optional[bool] -) -> Optional[str]: - """Get header value by name""" - # If headers is NoneType, return default value - if not headers: - return default_value - - if case_sensitive: - return headers.get(name, default_value) - name_lower = name.lower() - - return next( - # Iterate over the dict and do a case-insensitive key comparison - (value for key, value in headers.items() if key.lower() == name_lower), - # Default value is returned if no matches was found - default_value, - ) - - class BaseProxyEvent(DictWrapper): @property def headers(self) -> Dict[str, str]: @@ -166,8 +150,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) str, optional Query string parameter value """ - params = self.query_string_parameters - return default_value if params is None else params.get(name, default_value) + return get_query_string_value(self.query_string_parameters, name, default_value) def get_header_value( self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index 4773d9e50de..c2f852602e8 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -2,6 +2,9 @@ from typing import Any, Dict, Iterator, List, Optional from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( + get_header_value, +) class KafkaEventRecord(DictWrapper): @@ -69,18 +72,9 @@ def decoded_headers(self) -> Dict[str, bytes]: def get_header_value( self, name: str, default_value: Optional[Any] = None, case_sensitive: bool = True - ) -> Optional[bytes]: + ) -> Optional[str]: """Get a decoded header value by name.""" - if case_sensitive: - return self.decoded_headers.get(name, default_value) - name_lower = name.lower() - - return next( - # Iterate over the dict and do a case-insensitive key comparison - (value for key, value in self.decoded_headers.items() if key.lower() == name_lower), - # Default value is returned if no matches was found - default_value, - ) + return get_header_value(self.decoded_headers, name, default_value, case_sensitive) class KafkaEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/rabbit_mq_event.py b/aws_lambda_powertools/utilities/data_classes/rabbit_mq_event.py index 0822a58da18..ab792f3b893 100644 --- a/aws_lambda_powertools/utilities/data_classes/rabbit_mq_event.py +++ b/aws_lambda_powertools/utilities/data_classes/rabbit_mq_event.py @@ -1,8 +1,7 @@ -import base64 -import json from typing import Any, Dict, List from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode class BasicProperties(DictWrapper): @@ -83,13 +82,13 @@ def data(self) -> str: @property def decoded_data(self) -> str: """Decodes the data as a str""" - return base64.b64decode(self.data.encode()).decode() + return base64_decode(self.data) @property def json_data(self) -> Any: """Parses the data as json""" if self._json_data is None: - self._json_data = json.loads(self.decoded_data) + self._json_data = self._json_deserializer(self.decoded_data) return self._json_data diff --git a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py index 45985120698..8cff4cfd59a 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_object_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_object_event.py @@ -1,7 +1,7 @@ from typing import Dict, Optional -from aws_lambda_powertools.utilities.data_classes.common import ( - DictWrapper, +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( get_header_value, ) diff --git a/aws_lambda_powertools/utilities/data_classes/shared_functions.py b/aws_lambda_powertools/utilities/data_classes/shared_functions.py new file mode 100644 index 00000000000..fe5108b2b20 --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/shared_functions.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import base64 +from typing import Any + + +def base64_decode(value: str) -> str: + """ + Decodes a Base64-encoded string and returns the decoded value. + + Parameters + ---------- + value: str + The Base64-encoded string to decode. + + Returns + ------- + str + The decoded string value. + """ + return base64.b64decode(value).decode("UTF-8") + + +def get_header_value( + headers: dict[str, Any], name: str, default_value: str | None, case_sensitive: bool | None +) -> str | None: + """ + Get the value of a header by its name. + + Parameters + ---------- + headers: Dict[str, str] + The dictionary of headers. + name: str + The name of the header to retrieve. + default_value: str, optional + The default value to return if the header is not found. Default is None. + case_sensitive: bool, optional + Indicates whether the header name should be case-sensitive. Default is None. + + Returns + ------- + str, optional + The value of the header if found, otherwise the default value or None. + """ + # If headers is NoneType, return default value + if not headers: + return default_value + + if case_sensitive: + return headers.get(name, default_value) + name_lower = name.lower() + + return next( + # Iterate over the dict and do a case-insensitive key comparison + (value for key, value in headers.items() if key.lower() == name_lower), + # Default value is returned if no matches was found + default_value, + ) + + +def get_query_string_value( + query_string_parameters: dict[str, str] | None, name: str, default_value: str | None = None +) -> str | None: + """ + Retrieves the value of a query string parameter specified by the given name. + + Parameters + ---------- + name: str + The name of the query string parameter to retrieve. + default_value: str, optional + The default value to return if the parameter is not found. Defaults to None. + + Returns + ------- + str. optional + The value of the query string parameter if found, or the default value if not found. + """ + params = query_string_parameters + return default_value if params is None else params.get(name, default_value) diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 4e503daf4ab..6a40cbe4392 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,9 +1,10 @@ -import base64 from typing import Any, Dict, Optional -from aws_lambda_powertools.utilities.data_classes.common import ( - DictWrapper, +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.shared_functions import ( + base64_decode, get_header_value, + get_query_string_value, ) @@ -35,7 +36,7 @@ def decoded_body(self) -> str: """Dynamically base64 decode body as a str""" body: str = self["body"] if self.is_base64_encoded: - return base64.b64decode(body.encode()).decode() + return base64_decode(body) return body @property @@ -67,8 +68,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) str, optional Query string parameter value """ - params = self.query_string_parameters - return default_value if params is None else params.get(name, default_value) + return get_query_string_value(self.query_string_parameters, name, default_value) def get_header_value( self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False From 66d530d57130f4d25c3471a0bfbb3cea12dedd27 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 19 Jun 2023 20:44:48 +0100 Subject: [PATCH 2/2] addressing Heitor's feedback --- aws_lambda_powertools/utilities/data_classes/common.py | 8 ++++++-- .../utilities/data_classes/kafka_event.py | 4 +++- .../utilities/data_classes/vpc_lattice.py | 8 ++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 1b64f21b56d..a862c7da454 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -150,7 +150,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) str, optional Query string parameter value """ - return get_query_string_value(self.query_string_parameters, name, default_value) + return get_query_string_value( + query_string_parameters=self.query_string_parameters, name=name, default_value=default_value + ) def get_header_value( self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False @@ -170,7 +172,9 @@ def get_header_value( str, optional Header value """ - return get_header_value(self.headers, name, default_value, case_sensitive) + return get_header_value( + headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive + ) def header_serializer(self) -> BaseHeadersSerializer: raise NotImplementedError() diff --git a/aws_lambda_powertools/utilities/data_classes/kafka_event.py b/aws_lambda_powertools/utilities/data_classes/kafka_event.py index c2f852602e8..6172ededc05 100644 --- a/aws_lambda_powertools/utilities/data_classes/kafka_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kafka_event.py @@ -74,7 +74,9 @@ def get_header_value( self, name: str, default_value: Optional[Any] = None, case_sensitive: bool = True ) -> Optional[str]: """Get a decoded header value by name.""" - return get_header_value(self.decoded_headers, name, default_value, case_sensitive) + return get_header_value( + headers=self.decoded_headers, name=name, default_value=default_value, case_sensitive=case_sensitive + ) class KafkaEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 6a40cbe4392..e5e126d2702 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -68,7 +68,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) str, optional Query string parameter value """ - return get_query_string_value(self.query_string_parameters, name, default_value) + return get_query_string_value( + query_string_parameters=self.query_string_parameters, name=name, default_value=default_value + ) def get_header_value( self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False @@ -88,4 +90,6 @@ def get_header_value( str, optional Header value """ - return get_header_value(self.headers, name, default_value, case_sensitive) + return get_header_value( + headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive + )