Skip to content

Commit 36905b5

Browse files
authored
fix(event-handler): multi-value query string and validation of scalar parameters (#3795)
1 parent 770f023 commit 36905b5

File tree

8 files changed

+287
-198
lines changed

8 files changed

+287
-198
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,10 @@ def _get_embed_body(
368368
return received_body, field_alias_omitted
369369

370370

371-
def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]):
371+
def _normalize_multi_query_string_with_param(
372+
query_string: Dict[str, List[str]],
373+
params: Sequence[ModelField],
374+
) -> Dict[str, Any]:
372375
"""
373376
Extract and normalize resolved_query_string_parameters
374377
@@ -383,15 +386,15 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
383386
-------
384387
A dictionary containing the processed multi_query_string_parameters.
385388
"""
386-
if query_string:
387-
for param in filter(is_scalar_field, params):
388-
try:
389-
# if the target parameter is a scalar, we keep the first value of the query string
390-
# regardless if there are more in the payload
391-
query_string[param.alias] = query_string[param.alias][0]
392-
except KeyError:
393-
pass
394-
return query_string
389+
resolved_query_string: Dict[str, Any] = query_string
390+
for param in filter(is_scalar_field, params):
391+
try:
392+
# if the target parameter is a scalar, we keep the first value of the query string
393+
# regardless if there are more in the payload
394+
resolved_query_string[param.alias] = query_string[param.alias][0]
395+
except KeyError:
396+
pass
397+
return resolved_query_string
395398

396399

397400
def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):

aws_lambda_powertools/utilities/data_classes/alb_event.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
3636
return self.get("multiValueQueryStringParameters")
3737

3838
@property
39-
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
39+
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
4040
if self.multi_value_query_string_parameters:
4141
return self.multi_value_query_string_parameters
4242

43-
return self.query_string_parameters
43+
return super().resolved_query_string_parameters
4444

4545
@property
4646
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:

aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
119119
return self.get("multiValueQueryStringParameters")
120120

121121
@property
122-
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
122+
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
123123
if self.multi_value_query_string_parameters:
124124
return self.multi_value_query_string_parameters
125125

126-
return self.query_string_parameters
126+
return super().resolved_query_string_parameters
127127

128128
@property
129129
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
@@ -318,16 +318,6 @@ def http_method(self) -> str:
318318
def header_serializer(self):
319319
return HttpApiHeadersSerializer()
320320

321-
@property
322-
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
323-
if self.query_string_parameters is not None:
324-
query_string = {
325-
key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items()
326-
}
327-
return query_string
328-
329-
return {}
330-
331321
@property
332322
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
333323
if self.headers is not None:

aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

-4
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,6 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
109109
# together with the other parameters. So we just return all parameters here.
110110
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None
111111

112-
@property
113-
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
114-
return self.query_string_parameters
115-
116112
@property
117113
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
118114
return {}

aws_lambda_powertools/utilities/data_classes/common.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,19 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
104104
return self.get("queryStringParameters")
105105

106106
@property
107-
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
107+
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
108108
"""
109109
This property determines the appropriate query string parameter to be used
110110
as a trusted source for validating OpenAPI.
111111
112112
This is necessary because different resolvers use different formats to encode
113113
multi query string parameters.
114114
"""
115-
return self.query_string_parameters
115+
if self.query_string_parameters is not None:
116+
query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()}
117+
return query_string
118+
119+
return {}
116120

117121
@property
118122
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
@@ -186,17 +190,15 @@ def get_header_value(
186190
name: str,
187191
default_value: str,
188192
case_sensitive: Optional[bool] = False,
189-
) -> str:
190-
...
193+
) -> str: ...
191194

192195
@overload
193196
def get_header_value(
194197
self,
195198
name: str,
196199
default_value: Optional[str] = None,
197200
case_sensitive: Optional[bool] = False,
198-
) -> Optional[str]:
199-
...
201+
) -> Optional[str]: ...
200202

201203
def get_header_value(
202204
self,

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,15 @@ def get_header_value(
7373
name: str,
7474
default_value: str,
7575
case_sensitive: Optional[bool] = False,
76-
) -> str:
77-
...
76+
) -> str: ...
7877

7978
@overload
8079
def get_header_value(
8180
self,
8281
name: str,
8382
default_value: Optional[str] = None,
8483
case_sensitive: Optional[bool] = False,
85-
) -> Optional[str]:
86-
...
84+
) -> Optional[str]: ...
8785

8886
def get_header_value(
8987
self,
@@ -140,10 +138,6 @@ def query_string_parameters(self) -> Dict[str, str]:
140138
"""The request query string parameters."""
141139
return self["query_string_parameters"]
142140

143-
@property
144-
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
145-
return self.query_string_parameters
146-
147141
@property
148142
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
149143
if self.headers is not None:
@@ -255,17 +249,21 @@ def path(self) -> str:
255249

256250
@property
257251
def request_context(self) -> vpcLatticeEventV2RequestContext:
258-
"""he VPC Lattice v2 Event request context."""
252+
"""The VPC Lattice v2 Event request context."""
259253
return vpcLatticeEventV2RequestContext(self["requestContext"])
260254

261255
@property
262256
def query_string_parameters(self) -> Optional[Dict[str, str]]:
263-
"""The request query string parameters."""
264-
return self.get("queryStringParameters")
257+
"""The request query string parameters.
265258
266-
@property
267-
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
268-
return self.query_string_parameters
259+
For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]]
260+
so to keep compatibility with existing utilities, we merge all the values with a comma.
261+
"""
262+
params = self.get("queryStringParameters")
263+
if params:
264+
return {key: ",".join(value) for key, value in params.items()}
265+
else:
266+
return None
269267

270268
@property
271269
def resolved_headers_field(self) -> Optional[Dict[str, str]]:

tests/functional/event_handler/conftest.py

+32
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pytest
44

5+
from tests.functional.utils import load_event
6+
57

68
@pytest.fixture
79
def json_dump():
@@ -39,3 +41,33 @@ def validation_schema():
3941
@pytest.fixture
4042
def raw_event():
4143
return {"message": "hello hello", "username": "blah blah"}
44+
45+
46+
@pytest.fixture
47+
def gw_event():
48+
return load_event("apiGatewayProxyEvent.json")
49+
50+
51+
@pytest.fixture
52+
def gw_event_http():
53+
return load_event("apiGatewayProxyV2Event.json")
54+
55+
56+
@pytest.fixture
57+
def gw_event_alb():
58+
return load_event("albMultiValueQueryStringEvent.json")
59+
60+
61+
@pytest.fixture
62+
def gw_event_lambda_url():
63+
return load_event("lambdaFunctionUrlEventWithHeaders.json")
64+
65+
66+
@pytest.fixture
67+
def gw_event_vpc_lattice():
68+
return load_event("vpcLatticeV2EventWithHeaders.json")
69+
70+
71+
@pytest.fixture
72+
def gw_event_vpc_lattice_v1():
73+
return load_event("vpcLatticeEvent.json")

0 commit comments

Comments
 (0)