Skip to content

Commit c9a9157

Browse files
leandrodamascenarafaelgsr
authored andcommitted
fix(event_source): centralizing helper functions for query, header and base64 (aws-powertools#2496)
1 parent e5b59f5 commit c9a9157

12 files changed

+128
-62
lines changed

aws_lambda_powertools/utilities/data_classes/active_mq_event.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import base64
2-
import json
31
from typing import Any, Iterator, Optional
42

53
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode
65

76

87
class ActiveMQMessage(DictWrapper):
@@ -22,13 +21,13 @@ def data(self) -> str:
2221
@property
2322
def decoded_data(self) -> str:
2423
"""Decodes the data as a str"""
25-
return base64.b64decode(self.data.encode()).decode()
24+
return base64_decode(self.data)
2625

2726
@property
2827
def json_data(self) -> Any:
2928
"""Parses the data as json"""
3029
if self._json_data is None:
31-
self._json_data = json.loads(self.decoded_data)
30+
self._json_data = self._json_deserializer(self.decoded_data)
3231
return self._json_data
3332

3433
@property
@@ -125,7 +124,7 @@ def event_source_arn(self) -> str:
125124
@property
126125
def messages(self) -> Iterator[ActiveMQMessage]:
127126
for record in self["messages"]:
128-
yield ActiveMQMessage(record)
127+
yield ActiveMQMessage(record, json_deserializer=self._json_deserializer)
129128

130129
@property
131130
def message(self) -> ActiveMQMessage:

aws_lambda_powertools/utilities/data_classes/api_gateway_authorizer_event.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
BaseRequestContext,
77
BaseRequestContextV2,
88
DictWrapper,
9+
)
10+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
911
get_header_value,
1012
)
1113

aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List, Optional, Union
22

3-
from aws_lambda_powertools.utilities.data_classes.common import (
4-
DictWrapper,
3+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
55
get_header_value,
66
)
77

aws_lambda_powertools/utilities/data_classes/aws_config_rule_event.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
from typing import Any, Dict, List, Optional
54

65
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
@@ -303,9 +302,9 @@ def invoking_event(
303302
) -> AWSConfigConfigurationChanged | AWSConfigScheduledNotification | AWSConfigOversizedConfiguration:
304303
"""The invoking payload of the event."""
305304
if self._invoking_event is None:
306-
self._invoking_event = self["invokingEvent"]
305+
self._invoking_event = self._json_deserializer(self["invokingEvent"])
307306

308-
return get_invoke_event(json.loads(self._invoking_event))
307+
return get_invoke_event(self._invoking_event)
309308

310309
@property
311310
def raw_invoking_event(self) -> str:
@@ -316,9 +315,9 @@ def raw_invoking_event(self) -> str:
316315
def rule_parameters(self) -> Dict:
317316
"""The parameters of the event."""
318317
if self._rule_parameters is None:
319-
self._rule_parameters = self["ruleParameters"]
318+
self._rule_parameters = self._json_deserializer(self["ruleParameters"])
320319

321-
return json.loads(self._rule_parameters)
320+
return self._rule_parameters
322321

323322
@property
324323
def result_token(self) -> str:

aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import base64
2-
import json
32
import zlib
43
from typing import Dict, List, Optional
54

@@ -97,5 +96,6 @@ def decompress_logs_data(self) -> bytes:
9796
def parse_logs_data(self) -> CloudWatchLogsDecodedData:
9897
"""Decode, decompress and parse json data as CloudWatchLogsDecodedData"""
9998
if self._json_logs_data is None:
100-
self._json_logs_data = json.loads(self.decompress_logs_data.decode("UTF-8"))
99+
self._json_logs_data = self._json_deserializer(self.decompress_logs_data.decode("UTF-8"))
100+
101101
return CloudWatchLogsDecodedData(self._json_logs_data)

aws_lambda_powertools/utilities/data_classes/code_pipeline_job_event.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import tempfile
32
import zipfile
43
from typing import Any, Dict, List, Optional
@@ -22,7 +21,7 @@ def user_parameters(self) -> Optional[str]:
2221
def decoded_user_parameters(self) -> Optional[Dict[str, Any]]:
2322
"""Json Decoded user parameters"""
2423
if self._json_data is None and self.user_parameters is not None:
25-
self._json_data = json.loads(self.user_parameters)
24+
self._json_data = self._json_deserializer(self.user_parameters)
2625
return self._json_data
2726

2827

aws_lambda_powertools/utilities/data_classes/common.py

+10-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from typing import Any, Callable, Dict, Iterator, List, Optional
55

66
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
7+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
8+
get_header_value,
9+
get_query_string_value,
10+
)
711

812

913
class DictWrapper(Mapping):
@@ -90,26 +94,6 @@ def raw_event(self) -> Dict[str, Any]:
9094
return self._data
9195

9296

93-
def get_header_value(
94-
headers: Dict[str, str], name: str, default_value: Optional[str], case_sensitive: Optional[bool]
95-
) -> Optional[str]:
96-
"""Get header value by name"""
97-
# If headers is NoneType, return default value
98-
if not headers:
99-
return default_value
100-
101-
if case_sensitive:
102-
return headers.get(name, default_value)
103-
name_lower = name.lower()
104-
105-
return next(
106-
# Iterate over the dict and do a case-insensitive key comparison
107-
(value for key, value in headers.items() if key.lower() == name_lower),
108-
# Default value is returned if no matches was found
109-
default_value,
110-
)
111-
112-
11397
class BaseProxyEvent(DictWrapper):
11498
@property
11599
def headers(self) -> Dict[str, str]:
@@ -166,8 +150,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
166150
str, optional
167151
Query string parameter value
168152
"""
169-
params = self.query_string_parameters
170-
return default_value if params is None else params.get(name, default_value)
153+
return get_query_string_value(
154+
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
155+
)
171156

172157
def get_header_value(
173158
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
@@ -187,7 +172,9 @@ def get_header_value(
187172
str, optional
188173
Header value
189174
"""
190-
return get_header_value(self.headers, name, default_value, case_sensitive)
175+
return get_header_value(
176+
headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive
177+
)
191178

192179
def header_serializer(self) -> BaseHeadersSerializer:
193180
raise NotImplementedError()

aws_lambda_powertools/utilities/data_classes/kafka_event.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from typing import Any, Dict, Iterator, List, Optional
33

44
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
5+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
6+
get_header_value,
7+
)
58

69

710
class KafkaEventRecord(DictWrapper):
@@ -69,17 +72,10 @@ def decoded_headers(self) -> Dict[str, bytes]:
6972

7073
def get_header_value(
7174
self, name: str, default_value: Optional[Any] = None, case_sensitive: bool = True
72-
) -> Optional[bytes]:
75+
) -> Optional[str]:
7376
"""Get a decoded header value by name."""
74-
if case_sensitive:
75-
return self.decoded_headers.get(name, default_value)
76-
name_lower = name.lower()
77-
78-
return next(
79-
# Iterate over the dict and do a case-insensitive key comparison
80-
(value for key, value in self.decoded_headers.items() if key.lower() == name_lower),
81-
# Default value is returned if no matches was found
82-
default_value,
77+
return get_header_value(
78+
headers=self.decoded_headers, name=name, default_value=default_value, case_sensitive=case_sensitive
8379
)
8480

8581

aws_lambda_powertools/utilities/data_classes/rabbit_mq_event.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import base64
2-
import json
31
from typing import Any, Dict, List
42

53
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
from aws_lambda_powertools.utilities.data_classes.shared_functions import base64_decode
65

76

87
class BasicProperties(DictWrapper):
@@ -83,13 +82,13 @@ def data(self) -> str:
8382
@property
8483
def decoded_data(self) -> str:
8584
"""Decodes the data as a str"""
86-
return base64.b64decode(self.data.encode()).decode()
85+
return base64_decode(self.data)
8786

8887
@property
8988
def json_data(self) -> Any:
9089
"""Parses the data as json"""
9190
if self._json_data is None:
92-
self._json_data = json.loads(self.decoded_data)
91+
self._json_data = self._json_deserializer(self.decoded_data)
9392
return self._json_data
9493

9594

aws_lambda_powertools/utilities/data_classes/s3_object_event.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, Optional
22

3-
from aws_lambda_powertools.utilities.data_classes.common import (
4-
DictWrapper,
3+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
55
get_header_value,
66
)
77

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from typing import Any
5+
6+
7+
def base64_decode(value: str) -> str:
8+
"""
9+
Decodes a Base64-encoded string and returns the decoded value.
10+
11+
Parameters
12+
----------
13+
value: str
14+
The Base64-encoded string to decode.
15+
16+
Returns
17+
-------
18+
str
19+
The decoded string value.
20+
"""
21+
return base64.b64decode(value).decode("UTF-8")
22+
23+
24+
def get_header_value(
25+
headers: dict[str, Any], name: str, default_value: str | None, case_sensitive: bool | None
26+
) -> str | None:
27+
"""
28+
Get the value of a header by its name.
29+
30+
Parameters
31+
----------
32+
headers: Dict[str, str]
33+
The dictionary of headers.
34+
name: str
35+
The name of the header to retrieve.
36+
default_value: str, optional
37+
The default value to return if the header is not found. Default is None.
38+
case_sensitive: bool, optional
39+
Indicates whether the header name should be case-sensitive. Default is None.
40+
41+
Returns
42+
-------
43+
str, optional
44+
The value of the header if found, otherwise the default value or None.
45+
"""
46+
# If headers is NoneType, return default value
47+
if not headers:
48+
return default_value
49+
50+
if case_sensitive:
51+
return headers.get(name, default_value)
52+
name_lower = name.lower()
53+
54+
return next(
55+
# Iterate over the dict and do a case-insensitive key comparison
56+
(value for key, value in headers.items() if key.lower() == name_lower),
57+
# Default value is returned if no matches was found
58+
default_value,
59+
)
60+
61+
62+
def get_query_string_value(
63+
query_string_parameters: dict[str, str] | None, name: str, default_value: str | None = None
64+
) -> str | None:
65+
"""
66+
Retrieves the value of a query string parameter specified by the given name.
67+
68+
Parameters
69+
----------
70+
name: str
71+
The name of the query string parameter to retrieve.
72+
default_value: str, optional
73+
The default value to return if the parameter is not found. Defaults to None.
74+
75+
Returns
76+
-------
77+
str. optional
78+
The value of the query string parameter if found, or the default value if not found.
79+
"""
80+
params = query_string_parameters
81+
return default_value if params is None else params.get(name, default_value)

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import base64
21
from typing import Any, Dict, Optional
32

4-
from aws_lambda_powertools.utilities.data_classes.common import (
5-
DictWrapper,
3+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
5+
base64_decode,
66
get_header_value,
7+
get_query_string_value,
78
)
89

910

@@ -35,7 +36,7 @@ def decoded_body(self) -> str:
3536
"""Dynamically base64 decode body as a str"""
3637
body: str = self["body"]
3738
if self.is_base64_encoded:
38-
return base64.b64decode(body.encode()).decode()
39+
return base64_decode(body)
3940
return body
4041

4142
@property
@@ -67,8 +68,9 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
6768
str, optional
6869
Query string parameter value
6970
"""
70-
params = self.query_string_parameters
71-
return default_value if params is None else params.get(name, default_value)
71+
return get_query_string_value(
72+
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
73+
)
7274

7375
def get_header_value(
7476
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
@@ -88,4 +90,6 @@ def get_header_value(
8890
str, optional
8991
Header value
9092
"""
91-
return get_header_value(self.headers, name, default_value, case_sensitive)
93+
return get_header_value(
94+
headers=self.headers, name=name, default_value=default_value, case_sensitive=case_sensitive
95+
)

0 commit comments

Comments
 (0)