Skip to content

fix(event-handler): multi-value query string and validation of scalar parameters #3795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,10 @@ def _get_embed_body(
return received_body, field_alias_omitted


def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]):
def _normalize_multi_query_string_with_param(
query_string: Dict[str, List[str]],
params: Sequence[ModelField],
) -> Dict[str, Any]:
"""
Extract and normalize resolved_query_string_parameters

Expand All @@ -383,15 +386,15 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
-------
A dictionary containing the processed multi_query_string_parameters.
"""
if query_string:
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
return query_string
resolved_query_string: Dict[str, Any] = query_string
for param in filter(is_scalar_field, params):
try:
# if the target parameter is a scalar, we keep the first value of the query string
# regardless if there are more in the payload
resolved_query_string[param.alias] = query_string[param.alias][0]
except KeyError:
pass
return resolved_query_string


def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/alb_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters
return super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
return self.get("multiValueQueryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
if self.multi_value_query_string_parameters:
return self.multi_value_query_string_parameters

return self.query_string_parameters
return super().resolved_query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -318,16 +318,6 @@ def http_method(self) -> str:
def header_serializer(self):
return HttpApiHeadersSerializer()

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
if self.query_string_parameters is not None:
query_string = {
key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items()
}
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
# together with the other parameters. So we just return all parameters here.
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
return {}
14 changes: 8 additions & 6 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,19 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.get("queryStringParameters")

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
def resolved_query_string_parameters(self) -> Dict[str, List[str]]:
"""
This property determines the appropriate query string parameter to be used
as a trusted source for validating OpenAPI.

This is necessary because different resolvers use different formats to encode
multi query string parameters.
"""
return self.query_string_parameters
if self.query_string_parameters is not None:
query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()}
return query_string

return {}

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -186,17 +190,15 @@ def get_header_value(
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def get_header_value(
self,
Expand Down
26 changes: 12 additions & 14 deletions aws_lambda_powertools/utilities/data_classes/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,15 @@ def get_header_value(
name: str,
default_value: str,
case_sensitive: Optional[bool] = False,
) -> str:
...
) -> str: ...

@overload
def get_header_value(
self,
name: str,
default_value: Optional[str] = None,
case_sensitive: Optional[bool] = False,
) -> Optional[str]:
...
) -> Optional[str]: ...

def get_header_value(
self,
Expand Down Expand Up @@ -140,10 +138,6 @@ def query_string_parameters(self) -> Dict[str, str]:
"""The request query string parameters."""
return self["query_string_parameters"]

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters

@property
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
if self.headers is not None:
Expand Down Expand Up @@ -255,17 +249,21 @@ def path(self) -> str:

@property
def request_context(self) -> vpcLatticeEventV2RequestContext:
"""he VPC Lattice v2 Event request context."""
"""The VPC Lattice v2 Event request context."""
return vpcLatticeEventV2RequestContext(self["requestContext"])

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
"""The request query string parameters."""
return self.get("queryStringParameters")
"""The request query string parameters.

@property
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
return self.query_string_parameters
For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]]
so to keep compatibility with existing utilities, we merge all the values with a comma.
"""
params = self.get("queryStringParameters")
if params:
return {key: ",".join(value) for key, value in params.items()}
else:
return None

@property
def resolved_headers_field(self) -> Optional[Dict[str, str]]:
Expand Down
Loading