Skip to content

refactor(event_source): centralizing helper functions for query, header and base64 #2496

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import base64
import json
from typing import Any, Iterator, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode


class ActiveMQMessage(DictWrapper):
Expand All @@ -22,13 +21,13 @@ def data(self) -> str:
@property
def decoded_data(self) -> str:
"""Decodes the data as a str"""
return base64.b64decode(self.data.encode()).decode()
return base64_decode(self.data)

@property
def json_data(self) -> Any:
"""Parses the data as json"""
if self._json_data is None:
self._json_data = json.loads(self.decoded_data)
self._json_data = self._json_deserializer(self.decoded_data)
return self._json_data

@property
Expand Down Expand Up @@ -125,7 +124,7 @@ def event_source_arn(self) -> str:
@property
def messages(self) -> Iterator[ActiveMQMessage]:
for record in self["messages"]:
yield ActiveMQMessage(record)
yield ActiveMQMessage(record, json_deserializer=self._json_deserializer)

@property
def message(self) -> ActiveMQMessage:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
BaseRequestContext,
BaseRequestContextV2,
DictWrapper,
)
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Union

from aws_lambda_powertools.utilities.data_classes.common import (
DictWrapper,
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
Expand Down Expand Up @@ -303,9 +302,9 @@ def invoking_event(
) -> AWSConfigConfigurationChanged | AWSConfigScheduledNotification | AWSConfigOversizedConfiguration:
"""The invoking payload of the event."""
if self._invoking_event is None:
self._invoking_event = self["invokingEvent"]
self._invoking_event = self._json_deserializer(self["invokingEvent"])

return get_invoke_event(json.loads(self._invoking_event))
return get_invoke_event(self._invoking_event)

@property
def raw_invoking_event(self) -> str:
Expand All @@ -316,9 +315,9 @@ def raw_invoking_event(self) -> str:
def rule_parameters(self) -> Dict:
"""The parameters of the event."""
if self._rule_parameters is None:
self._rule_parameters = self["ruleParameters"]
self._rule_parameters = self._json_deserializer(self["ruleParameters"])

return json.loads(self._rule_parameters)
return self._rule_parameters

@property
def result_token(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
import zlib
from typing import Dict, List, Optional

Expand Down Expand Up @@ -97,5 +96,6 @@ def decompress_logs_data(self) -> bytes:
def parse_logs_data(self) -> CloudWatchLogsDecodedData:
"""Decode, decompress and parse json data as CloudWatchLogsDecodedData"""
if self._json_logs_data is None:
self._json_logs_data = json.loads(self.decompress_logs_data.decode("UTF-8"))
self._json_logs_data = self._json_deserializer(self.decompress_logs_data.decode("UTF-8"))

return CloudWatchLogsDecodedData(self._json_logs_data)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import tempfile
import zipfile
from typing import Any, Dict, List, Optional
Expand All @@ -22,7 +21,7 @@ def user_parameters(self) -> Optional[str]:
def decoded_user_parameters(self) -> Optional[Dict[str, Any]]:
"""Json Decoded user parameters"""
if self._json_data is None and self.user_parameters is not None:
self._json_data = json.loads(self.user_parameters)
self._json_data = self._json_deserializer(self.user_parameters)
return self._json_data


Expand Down
33 changes: 10 additions & 23 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional

from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
get_query_string_value,
)


class DictWrapper(Mapping):
Expand Down Expand Up @@ -90,26 +94,6 @@ def raw_event(self) -> Dict[str, Any]:
return self._data


def get_header_value(
headers: Dict[str, str], name: str, default_value: Optional[str], case_sensitive: Optional[bool]
) -> Optional[str]:
"""Get header value by name"""
# If headers is NoneType, return default value
if not headers:
return default_value

if case_sensitive:
return headers.get(name, default_value)
name_lower = name.lower()

return next(
# Iterate over the dict and do a case-insensitive key comparison
(value for key, value in headers.items() if key.lower() == name_lower),
# Default value is returned if no matches was found
default_value,
)


class BaseProxyEvent(DictWrapper):
@property
def headers(self) -> Dict[str, str]:
Expand Down Expand Up @@ -166,8 +150,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
str, optional
Query string parameter value
"""
params = self.query_string_parameters
return default_value if params is None else params.get(name, default_value)
return get_query_string_value(
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
)

def get_header_value(
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
Expand All @@ -187,7 +172,9 @@ def get_header_value(
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)
return get_header_value(
headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive
)

def header_serializer(self) -> BaseHeadersSerializer:
raise NotImplementedError()
Expand Down
16 changes: 6 additions & 10 deletions aws_lambda_powertools/utilities/data_classes/kafka_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from typing import Any, Dict, Iterator, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)


class KafkaEventRecord(DictWrapper):
Expand Down Expand Up @@ -69,17 +72,10 @@ def decoded_headers(self) -> Dict[str, bytes]:

def get_header_value(
self, name: str, default_value: Optional[Any] = None, case_sensitive: bool = True
) -> Optional[bytes]:
) -> Optional[str]:
"""Get a decoded header value by name."""
if case_sensitive:
return self.decoded_headers.get(name, default_value)
name_lower = name.lower()

return next(
# Iterate over the dict and do a case-insensitive key comparison
(value for key, value in self.decoded_headers.items() if key.lower() == name_lower),
# Default value is returned if no matches was found
default_value,
return get_header_value(
headers=self.decoded_headers, name=name, default_value=default_value, case_sensitive=case_sensitive
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import base64
import json
from typing import Any, Dict, List

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode


class BasicProperties(DictWrapper):
Expand Down Expand Up @@ -83,13 +82,13 @@ def data(self) -> str:
@property
def decoded_data(self) -> str:
"""Decodes the data as a str"""
return base64.b64decode(self.data.encode()).decode()
return base64_decode(self.data)

@property
def json_data(self) -> Any:
"""Parses the data as json"""
if self._json_data is None:
self._json_data = json.loads(self.decoded_data)
self._json_data = self._json_deserializer(self.decoded_data)
return self._json_data


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional

from aws_lambda_powertools.utilities.data_classes.common import (
DictWrapper,
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
get_header_value,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import base64
from typing import Any


def base64_decode(value: str) -> str:
"""
Decodes a Base64-encoded string and returns the decoded value.
Parameters
----------
value: str
The Base64-encoded string to decode.
Returns
-------
str
The decoded string value.
"""
return base64.b64decode(value).decode("UTF-8")


def get_header_value(
headers: dict[str, Any], name: str, default_value: str | None, case_sensitive: bool | None
) -> str | None:
"""
Get the value of a header by its name.
Parameters
----------
headers: Dict[str, str]
The dictionary of headers.
name: str
The name of the header to retrieve.
default_value: str, optional
The default value to return if the header is not found. Default is None.
case_sensitive: bool, optional
Indicates whether the header name should be case-sensitive. Default is None.
Returns
-------
str, optional
The value of the header if found, otherwise the default value or None.
"""
# If headers is NoneType, return default value
if not headers:
return default_value

if case_sensitive:
return headers.get(name, default_value)
name_lower = name.lower()

return next(
# Iterate over the dict and do a case-insensitive key comparison
(value for key, value in headers.items() if key.lower() == name_lower),
# Default value is returned if no matches was found
default_value,
)


def get_query_string_value(
query_string_parameters: dict[str, str] | None, name: str, default_value: str | None = None
) -> str | None:
"""
Retrieves the value of a query string parameter specified by the given name.
Parameters
----------
name: str
The name of the query string parameter to retrieve.
default_value: str, optional
The default value to return if the parameter is not found. Defaults to None.
Returns
-------
str. optional
The value of the query string parameter if found, or the default value if not found.
"""
params = query_string_parameters
return default_value if params is None else params.get(name, default_value)
18 changes: 11 additions & 7 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import base64
from typing import Any, Dict, Optional

from aws_lambda_powertools.utilities.data_classes.common import (
DictWrapper,
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
base64_decode,
get_header_value,
get_query_string_value,
)


Expand Down Expand Up @@ -35,7 +36,7 @@ def decoded_body(self) -> str:
"""Dynamically base64 decode body as a str"""
body: str = self["body"]
if self.is_base64_encoded:
return base64.b64decode(body.encode()).decode()
return base64_decode(body)
return body

@property
Expand Down Expand Up @@ -67,8 +68,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
str, optional
Query string parameter value
"""
params = self.query_string_parameters
return default_value if params is None else params.get(name, default_value)
return get_query_string_value(
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
)

def get_header_value(
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
Expand All @@ -88,4 +90,6 @@ def get_header_value(
str, optional
Header value
"""
return get_header_value(self.headers, name, default_value, case_sensitive)
return get_header_value(
headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive
)