Skip to content

Commit 1c303cb

Browse files
Adding header - Initial commit
1 parent 0519fa3 commit 1c303cb

File tree

9 files changed

+163
-17
lines changed

9 files changed

+163
-17
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
8181
query_string,
8282
)
8383

84+
# Normalize query values before validate this
85+
headers = _normalize_multi_header_values_with_param(
86+
app.current_event.resolved_headers_field,
87+
route.dependant.header_params,
88+
)
89+
90+
# Process header values
91+
header_values, header_errors = _request_params_to_args(
92+
route.dependant.header_params,
93+
headers,
94+
)
95+
8496
values.update(path_values)
8597
values.update(query_values)
86-
errors += path_errors + query_errors
98+
values.update(header_values)
99+
errors += path_errors + query_errors + header_errors
87100

88101
# Process the request body, if it exists
89102
if route.dependant.body_params:
@@ -377,3 +390,29 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st
377390
except KeyError:
378391
pass
379392
return query_string
393+
394+
395+
def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]):
396+
"""
397+
Extract and normalize resolved_headers_field
398+
399+
Parameters
400+
----------
401+
headers: Dict
402+
A dictionary containing the initial header parameters.
403+
params: Sequence[ModelField]
404+
A sequence of ModelField objects representing parameters.
405+
406+
Returns
407+
-------
408+
A dictionary containing the processed headers.
409+
"""
410+
if headers:
411+
for param in filter(is_scalar_field, params):
412+
try:
413+
# if the target parameter is a scalar, we keep the first value of the headers
414+
# regardless if there are more in the payload
415+
headers[param.name] = headers[param.name][0]
416+
except KeyError:
417+
pass
418+
return headers

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from aws_lambda_powertools.event_handler.openapi.params import (
1515
Body,
1616
Dependant,
17+
Header,
1718
Param,
1819
ParamTypes,
1920
Query,
2021
_File,
2122
_Form,
22-
_Header,
2323
analyze_param,
2424
create_response_field,
2525
get_flat_dependant,
@@ -59,16 +59,21 @@ def add_param_to_fields(
5959
6060
"""
6161
field_info = cast(Param, field.field_info)
62-
if field_info.in_ == ParamTypes.path:
63-
dependant.path_params.append(field)
64-
elif field_info.in_ == ParamTypes.query:
65-
dependant.query_params.append(field)
66-
elif field_info.in_ == ParamTypes.header:
67-
dependant.header_params.append(field)
62+
63+
# Dictionary to map ParamTypes to their corresponding lists in dependant
64+
param_type_map = {
65+
ParamTypes.path: dependant.path_params,
66+
ParamTypes.query: dependant.query_params,
67+
ParamTypes.header: dependant.header_params,
68+
ParamTypes.cookie: dependant.cookie_params,
69+
}
70+
71+
# Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list
72+
# or raise an exception if it's not a valid key.
73+
if field_info.in_ in param_type_map:
74+
param_type_map[field_info.in_].append(field)
6875
else:
69-
if field_info.in_ != ParamTypes.cookie:
70-
raise AssertionError(f"Unsupported param type: {field_info.in_}")
71-
dependant.cookie_params.append(field)
76+
raise AssertionError(f"Unsupported param type: {field_info.in_}")
7277

7378

7479
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
@@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
265270
return False
266271
elif is_scalar_field(field=param_field):
267272
return False
268-
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
273+
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
269274
return False
270275
else:
271276
if not isinstance(param_field.field_info, Body):

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def __init__(
486486
)
487487

488488

489-
class _Header(Param):
489+
class Header(Param):
490490
"""
491491
A class used internally to represent a header parameter in a path operation.
492492
"""
@@ -527,6 +527,67 @@ def __init__(
527527
json_schema_extra: Union[Dict[str, Any], None] = None,
528528
**extra: Any,
529529
):
530+
"""
531+
Constructs a new Query param.
532+
533+
Parameters
534+
----------
535+
default: Any
536+
The default value of the parameter
537+
default_factory: Callable[[], Any], optional
538+
Callable that will be called when a default value is needed for this field
539+
annotation: Any, optional
540+
The type annotation of the parameter
541+
alias: str, optional
542+
The public name of the field
543+
alias_priority: int, optional
544+
Priority of the alias. This affects whether an alias generator is used
545+
validation_alias: str | AliasPath | AliasChoices | None, optional
546+
Alias to be used for validation only
547+
serialization_alias: str | AliasPath | AliasChoices | None, optional
548+
Alias to be used for serialization only
549+
convert_underscores: bool
550+
If true convert "_" to "-"
551+
See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry
552+
title: str, optional
553+
The title of the parameter
554+
description: str, optional
555+
The description of the parameter
556+
gt: float, optional
557+
Only applies to numbers, required the field to be "greater than"
558+
ge: float, optional
559+
Only applies to numbers, required the field to be "greater than or equal"
560+
lt: float, optional
561+
Only applies to numbers, required the field to be "less than"
562+
le: float, optional
563+
Only applies to numbers, required the field to be "less than or equal"
564+
min_length: int, optional
565+
Only applies to strings, required the field to have a minimum length
566+
max_length: int, optional
567+
Only applies to strings, required the field to have a maximum length
568+
pattern: str, optional
569+
Only applies to strings, requires the field match against a regular expression pattern string
570+
discriminator: str, optional
571+
Parameter field name for discriminating the type in a tagged union
572+
strict: bool, optional
573+
Enables Pydantic's strict mode for the field
574+
multiple_of: float, optional
575+
Only applies to numbers, requires the field to be a multiple of the given value
576+
allow_inf_nan: bool, optional
577+
Only applies to numbers, requires the field to allow infinity and NaN values
578+
max_digits: int, optional
579+
Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal.
580+
decimal_places: int, optional
581+
Only applies to Decimals, requires the field to have at most a number of decimal places
582+
examples: List[Any], optional
583+
A list of examples for the parameter
584+
deprecated: bool, optional
585+
If `True`, the parameter will be marked as deprecated
586+
include_in_schema: bool, optional
587+
If `False`, the parameter will be excluded from the generated OpenAPI schema
588+
json_schema_extra: Dict[str, Any], optional
589+
Extra values to include in the generated OpenAPI schema
590+
"""
530591
self.convert_underscores = convert_underscores
531592
super().__init__(
532593
default=default,

aws_lambda_powertools/utilities/data_classes/alb_event.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
4242

4343
return self.query_string_parameters
4444

45+
@property
46+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
47+
if self.multi_value_headers:
48+
return self.multi_value_headers
49+
50+
return self.headers
51+
4552
@property
4653
def multi_value_headers(self) -> Optional[Dict[str, List[str]]]:
4754
return self.get("multiValueHeaders")

aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
125125

126126
return self.query_string_parameters
127127

128+
@property
129+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
130+
if self.multi_value_headers:
131+
return self.multi_value_headers
132+
133+
return self.headers
134+
128135
@property
129136
def request_context(self) -> APIGatewayEventRequestContext:
130137
return APIGatewayEventRequestContext(self._data)
@@ -316,3 +323,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]:
316323
return query_string
317324

318325
return {}
326+
327+
@property
328+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
329+
if self.headers is not None:
330+
headers = {key: value.split(",") if "," in value else value for key, value in self.headers.items()}
331+
return headers
332+
333+
return {}

aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
44

@@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]:
112112
@property
113113
def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
114114
return self.query_string_parameters
115+
116+
@property
117+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
118+
return {}

aws_lambda_powertools/utilities/data_classes/common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]:
114114
"""
115115
return self.query_string_parameters
116116

117+
@property
118+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
119+
"""
120+
This property determines the appropriate header to be used
121+
as a trusted source for validating OpenAPI.
122+
123+
This is necessary because different resolvers use different formats to encode
124+
headers parameters.
125+
"""
126+
return self.headers
127+
117128
@property
118129
def is_base64_encoded(self) -> Optional[bool]:
119130
return self.get("isBase64Encoded")

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def headers(self) -> Dict[str, str]:
3030
"""The VPC Lattice event headers."""
3131
return self["headers"]
3232

33+
@property
34+
def resolved_headers_field(self) -> Optional[Dict[str, Any]]:
35+
return self.headers
36+
3337
@property
3438
def decoded_body(self) -> str:
3539
"""Dynamically base64 decode body as a str"""

tests/functional/event_handler/test_openapi_params.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
)
1414
from aws_lambda_powertools.event_handler.openapi.params import (
1515
Body,
16+
Header,
1617
Param,
1718
ParamTypes,
1819
Query,
1920
_create_model_field,
20-
_Header,
2121
)
2222
from aws_lambda_powertools.shared.types import Annotated
2323

@@ -431,7 +431,7 @@ def handler():
431431

432432

433433
def test_create_header():
434-
header = _Header(convert_underscores=True)
434+
header = Header(convert_underscores=True)
435435
assert header.convert_underscores is True
436436

437437

@@ -456,7 +456,7 @@ def test_create_model_field_with_empty_in():
456456

457457
# Tests that when we try to create a model field with convert_underscore, we convert the field name
458458
def test_create_model_field_convert_underscore():
459-
field_info = _Header(alias=None, convert_underscores=True)
459+
field_info = Header(alias=None, convert_underscores=True)
460460

461461
result = _create_model_field(field_info, int, "user_id", False)
462462
assert result.alias == "user-id"

0 commit comments

Comments
 (0)