From 1c303cb9b714532b4431136e0c0a041d0e3f8a81 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 31 Jan 2024 01:02:07 +0000 Subject: [PATCH 01/13] Adding header - Initial commit --- .../middlewares/openapi_validation.py | 41 +++++++++++- .../event_handler/openapi/dependant.py | 27 ++++---- .../event_handler/openapi/params.py | 63 ++++++++++++++++++- .../utilities/data_classes/alb_event.py | 7 +++ .../data_classes/api_gateway_proxy_event.py | 15 +++++ .../data_classes/bedrock_agent_event.py | 6 +- .../utilities/data_classes/common.py | 11 ++++ .../utilities/data_classes/vpc_lattice.py | 4 ++ .../event_handler/test_openapi_params.py | 6 +- 9 files changed, 163 insertions(+), 17 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index e819947b147..63bdaa8358e 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -81,9 +81,22 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> query_string, ) + # Normalize query values before validate this + headers = _normalize_multi_header_values_with_param( + app.current_event.resolved_headers_field, + route.dependant.header_params, + ) + + # Process header values + header_values, header_errors = _request_params_to_args( + route.dependant.header_params, + headers, + ) + values.update(path_values) values.update(query_values) - errors += path_errors + query_errors + values.update(header_values) + errors += path_errors + query_errors + header_errors # Process the request body, if it exists if route.dependant.body_params: @@ -377,3 +390,29 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st except KeyError: pass return query_string + + +def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]): + """ + Extract and normalize resolved_headers_field + + Parameters + ---------- + headers: Dict + A dictionary containing the initial header parameters. + params: Sequence[ModelField] + A sequence of ModelField objects representing parameters. + + Returns + ------- + A dictionary containing the processed headers. + """ + if headers: + for param in filter(is_scalar_field, params): + try: + # if the target parameter is a scalar, we keep the first value of the headers + # regardless if there are more in the payload + headers[param.name] = headers[param.name][0] + except KeyError: + pass + return headers diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index 418a86e083c..abcb91e90dd 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -14,12 +14,12 @@ from aws_lambda_powertools.event_handler.openapi.params import ( Body, Dependant, + Header, Param, ParamTypes, Query, _File, _Form, - _Header, analyze_param, create_response_field, get_flat_dependant, @@ -59,16 +59,21 @@ def add_param_to_fields( """ field_info = cast(Param, field.field_info) - if field_info.in_ == ParamTypes.path: - dependant.path_params.append(field) - elif field_info.in_ == ParamTypes.query: - dependant.query_params.append(field) - elif field_info.in_ == ParamTypes.header: - dependant.header_params.append(field) + + # Dictionary to map ParamTypes to their corresponding lists in dependant + param_type_map = { + ParamTypes.path: dependant.path_params, + ParamTypes.query: dependant.query_params, + ParamTypes.header: dependant.header_params, + ParamTypes.cookie: dependant.cookie_params, + } + + # Check if field_info.in_ is a valid key in param_type_map and append the field to the corresponding list + # or raise an exception if it's not a valid key. + if field_info.in_ in param_type_map: + param_type_map[field_info.in_].append(field) else: - if field_info.in_ != ParamTypes.cookie: - raise AssertionError(f"Unsupported param type: {field_info.in_}") - dependant.cookie_params.append(field) + raise AssertionError(f"Unsupported param type: {field_info.in_}") def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: @@ -265,7 +270,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: return False elif is_scalar_field(field=param_field): return False - elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field): + elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): return False else: if not isinstance(param_field.field_info, Body): diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 78426cbc7c9..ce7a1b8520e 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -486,7 +486,7 @@ def __init__( ) -class _Header(Param): +class Header(Param): """ A class used internally to represent a header parameter in a path operation. """ @@ -527,6 +527,67 @@ def __init__( json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, ): + """ + Constructs a new Query param. + + Parameters + ---------- + default: Any + The default value of the parameter + default_factory: Callable[[], Any], optional + Callable that will be called when a default value is needed for this field + annotation: Any, optional + The type annotation of the parameter + alias: str, optional + The public name of the field + alias_priority: int, optional + Priority of the alias. This affects whether an alias generator is used + validation_alias: str | AliasPath | AliasChoices | None, optional + Alias to be used for validation only + serialization_alias: str | AliasPath | AliasChoices | None, optional + Alias to be used for serialization only + convert_underscores: bool + If true convert "_" to "-" + See RFC: https://www.rfc-editor.org/rfc/rfc9110.html#name-field-name-registry + title: str, optional + The title of the parameter + description: str, optional + The description of the parameter + gt: float, optional + Only applies to numbers, required the field to be "greater than" + ge: float, optional + Only applies to numbers, required the field to be "greater than or equal" + lt: float, optional + Only applies to numbers, required the field to be "less than" + le: float, optional + Only applies to numbers, required the field to be "less than or equal" + min_length: int, optional + Only applies to strings, required the field to have a minimum length + max_length: int, optional + Only applies to strings, required the field to have a maximum length + pattern: str, optional + Only applies to strings, requires the field match against a regular expression pattern string + discriminator: str, optional + Parameter field name for discriminating the type in a tagged union + strict: bool, optional + Enables Pydantic's strict mode for the field + multiple_of: float, optional + Only applies to numbers, requires the field to be a multiple of the given value + allow_inf_nan: bool, optional + Only applies to numbers, requires the field to allow infinity and NaN values + max_digits: int, optional + Only applies to Decimals, requires the field to have a maxmium number of digits within the decimal. + decimal_places: int, optional + Only applies to Decimals, requires the field to have at most a number of decimal places + examples: List[Any], optional + A list of examples for the parameter + deprecated: bool, optional + If `True`, the parameter will be marked as deprecated + include_in_schema: bool, optional + If `False`, the parameter will be excluded from the generated OpenAPI schema + json_schema_extra: Dict[str, Any], optional + Extra values to include in the generated OpenAPI schema + """ self.convert_underscores = convert_underscores super().__init__( default=default, diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 688c9567efa..6e17146ec37 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -42,6 +42,13 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: return self.query_string_parameters + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + if self.multi_value_headers: + return self.multi_value_headers + + return self.headers + @property def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueHeaders") diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 9e013eac038..26951d61201 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -125,6 +125,13 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: return self.query_string_parameters + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + if self.multi_value_headers: + return self.multi_value_headers + + return self.headers + @property def request_context(self) -> APIGatewayEventRequestContext: return APIGatewayEventRequestContext(self._data) @@ -316,3 +323,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: return query_string return {} + + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + if self.headers is not None: + headers = {key: value.split(",") if "," in value else value for key, value in self.headers.items()} + return headers + + return {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index d9b45242376..0fa97036a3e 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper @@ -112,3 +112,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: return self.query_string_parameters + + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + return {} diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index d2cf57d4af5..3f4f47c5a2a 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -114,6 +114,17 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: """ return self.query_string_parameters + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + """ + This property determines the appropriate header to be used + as a trusted source for validating OpenAPI. + + This is necessary because different resolvers use different formats to encode + headers parameters. + """ + return self.headers + @property def is_base64_encoded(self) -> Optional[bool]: return self.get("isBase64Encoded") diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 633ce068f6e..0e61f479948 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -30,6 +30,10 @@ def headers(self) -> Dict[str, str]: """The VPC Lattice event headers.""" return self["headers"] + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + return self.headers + @property def decoded_body(self) -> str: """Dynamically base64 decode body as a str""" diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 2f48f5aa534..38b0cbed307 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -13,11 +13,11 @@ ) from aws_lambda_powertools.event_handler.openapi.params import ( Body, + Header, Param, ParamTypes, Query, _create_model_field, - _Header, ) from aws_lambda_powertools.shared.types import Annotated @@ -431,7 +431,7 @@ def handler(): def test_create_header(): - header = _Header(convert_underscores=True) + header = Header(convert_underscores=True) assert header.convert_underscores is True @@ -456,7 +456,7 @@ def test_create_model_field_with_empty_in(): # Tests that when we try to create a model field with convert_underscore, we convert the field name def test_create_model_field_convert_underscore(): - field_info = _Header(alias=None, convert_underscores=True) + field_info = Header(alias=None, convert_underscores=True) result = _create_model_field(field_info, int, "user_id", False) assert result.alias == "user-id" From 2570c722d1bae57aec561372cc0f16d90956ac4b Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 31 Jan 2024 14:10:42 +0000 Subject: [PATCH 02/13] Adding header - Fix VPC Lattice Payload --- .../utilities/data_classes/vpc_lattice.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 0e61f479948..a17b8c96b1c 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -30,10 +30,6 @@ def headers(self) -> Dict[str, str]: """The VPC Lattice event headers.""" return self["headers"] - @property - def resolved_headers_field(self) -> Optional[Dict[str, Any]]: - return self.headers - @property def decoded_body(self) -> str: """Dynamically base64 decode body as a str""" @@ -149,6 +145,14 @@ def query_string_parameters(self) -> Dict[str, str]: def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: return self.query_string_parameters + @property + def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + if self.headers is not None: + headers = {key: value.split(",") if "," in value else value for key, value in self.headers.items()} + return headers + + return {} + class vpcLatticeEventV2Identity(DictWrapper): @property @@ -263,3 +267,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: return self.query_string_parameters + + @property + def resolved_headers_field(self) -> Optional[Dict[str, str]]: + return self.headers From c74a6a4d0a748981f80f2613ab58ac5f3d224031 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 31 Jan 2024 15:46:23 +0000 Subject: [PATCH 03/13] Adding header - tests and final changes --- .../middlewares/openapi_validation.py | 9 +- .../events/albMultiValueQueryStringEvent.json | 7 + tests/events/apiGatewayProxyEvent.json | 2 +- .../lambdaFunctionUrlEventWithHeaders.json | 4 +- tests/events/vpcLatticeEvent.json | 4 +- .../events/vpcLatticeV2EventWithHeaders.json | 31 +- .../test_openapi_validation_middleware.py | 625 +++++++++++++----- 7 files changed, 492 insertions(+), 190 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 63bdaa8358e..669e30b06db 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -81,7 +81,7 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> query_string, ) - # Normalize query values before validate this + # Normalize header values before validate this headers = _normalize_multi_header_values_with_param( app.current_event.resolved_headers_field, route.dependant.header_params, @@ -410,9 +410,10 @@ def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], if headers: for param in filter(is_scalar_field, params): try: - # if the target parameter is a scalar, we keep the first value of the headers - # regardless if there are more in the payload - headers[param.name] = headers[param.name][0] + if len(headers[param.name]) == 1: + # if the target parameter is a scalar and the list contains only 1 element + # we keep the first value of the headers regardless if there are more in the payload + headers[param.name] = headers[param.name][0] except KeyError: pass return headers diff --git a/tests/events/albMultiValueQueryStringEvent.json b/tests/events/albMultiValueQueryStringEvent.json index 4584ba7c477..7fb537794b1 100644 --- a/tests/events/albMultiValueQueryStringEvent.json +++ b/tests/events/albMultiValueQueryStringEvent.json @@ -14,6 +14,13 @@ "accept": [ "*/*" ], + "Header2": [ + "value1", + "value2" + ], + "Header1": [ + "value1" + ], "host": [ "alb-c-LoadB-14POFKYCLBNSF-1815800096.eu-central-1.elb.amazonaws.com" ], diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index 3f095e28e45..da814c91100 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -78,4 +78,4 @@ "stageVariables": null, "body": "Hello from Lambda!", "isBase64Encoded": false -} \ No newline at end of file +} diff --git a/tests/events/lambdaFunctionUrlEventWithHeaders.json b/tests/events/lambdaFunctionUrlEventWithHeaders.json index e453690d9b3..a2b3d9ef147 100644 --- a/tests/events/lambdaFunctionUrlEventWithHeaders.json +++ b/tests/events/lambdaFunctionUrlEventWithHeaders.json @@ -23,7 +23,9 @@ "cache-control":"max-age=0", "accept-encoding":"gzip, deflate, br", "sec-fetch-dest":"document", - "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36" + "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36", + "Header1": "value1", + "Header2": "value1,value2" }, "queryStringParameters": { "parameter1": "value1,value2", diff --git a/tests/events/vpcLatticeEvent.json b/tests/events/vpcLatticeEvent.json index 936bfb22d1b..b00b9d3d7f3 100644 --- a/tests/events/vpcLatticeEvent.json +++ b/tests/events/vpcLatticeEvent.json @@ -5,7 +5,9 @@ "user_agent": "curl/7.64.1", "x-forwarded-for": "10.213.229.10", "host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws", - "accept": "*/*" + "accept": "*/*", + "Header1": "value1", + "Header2": "value1,value2" }, "query_string_parameters": { "order-id": "1" diff --git a/tests/events/vpcLatticeV2EventWithHeaders.json b/tests/events/vpcLatticeV2EventWithHeaders.json index 11b36ef118b..fdaf7dc7891 100644 --- a/tests/events/vpcLatticeV2EventWithHeaders.json +++ b/tests/events/vpcLatticeV2EventWithHeaders.json @@ -2,12 +2,31 @@ "version": "2.0", "path": "/newpath", "method": "GET", - "headers": { - "user_agent": "curl/7.64.1", - "x-forwarded-for": "10.213.229.10", - "host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws", - "accept": "*/*" - }, + "headers":{ + "user-agent":[ + "curl/8.3.0" + ], + "accept":[ + "*/*" + ], + "powertools":[ + "a", + "b" + ], + "x-forwarded-for":[ + "172.31.40.143" + ], + "host":[ + "lattice-svc-027b423199122da5f.7d67968.vpc-lattice-svcs.us-east-1.on.aws" + ], + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, "queryStringParameters": { "parameter1": [ "value1", diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index ea4305257d4..9c66a24eb34 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -4,6 +4,7 @@ from pathlib import PurePath from typing import List, Tuple +import pytest from pydantic import BaseModel from aws_lambda_powertools.event_handler import ( @@ -14,7 +15,8 @@ Response, VPCLatticeV2Resolver, ) -from aws_lambda_powertools.event_handler.openapi.params import Body, Query +from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query +from aws_lambda_powertools.event_handler.vpc_lattice import VPCLatticeResolver from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event @@ -23,6 +25,7 @@ LOAD_GW_EVENT_ALB = load_event("albMultiValueQueryStringEvent.json") LOAD_GW_EVENT_LAMBDA_URL = load_event("lambdaFunctionUrlEventWithHeaders.json") LOAD_GW_EVENT_VPC_LATTICE = load_event("vpcLatticeV2EventWithHeaders.json") +LOAD_GW_EVENT_VPC_LATTICE_V1 = load_event("vpcLatticeEvent.json") def test_validate_scalars(): @@ -391,267 +394,535 @@ def handler(user: Model) -> Response[Model]: assert "missing" in result["body"] -def test_validate_rest_api_resolver_with_multi_query_params(): - # GIVEN an APIGatewayRestResolver with validation enabled +########### TEST WITH QUERY PARAMS +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_query_params", 200, None), + ], +) +def test_validation_query_string_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - # WHEN a handler is defined with a default scalar parameter and a list - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) - LOAD_GW_EVENT["httpMethod"] = "GET" LOAD_GW_EVENT["path"] = "/users" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 200 + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) -def test_validate_rest_api_resolver_with_multi_query_params_fail(): - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(parameter1: Annotated[List[int], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" + # Define handler3 without params + if handler_func == "handler3_without_query_params": + LOAD_GW_EVENT["queryStringParameters"] = None + LOAD_GW_EVENT["multiValueQueryStringParameters"] = None - # THEN the handler should be invoked and return 422 + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + assert result["statusCode"] == expected_status_code + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) -def test_validate_rest_api_resolver_without_query_params(): - # GIVEN an APIGatewayRestResolver with validation enabled - app = APIGatewayRestResolver(enable_validation=True) - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(): - return None +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_query_params", 200, None), + ], +) +def test_validation_query_string_with_api_http_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a APIGatewayHttpResolver with validation enabled + app = APIGatewayHttpResolver(enable_validation=True) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" - LOAD_GW_EVENT["queryStringParameters"] = None - LOAD_GW_EVENT["multiValueQueryStringParameters"] = None + LOAD_GW_EVENT_HTTP["rawPath"] = "/users" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) - assert result["statusCode"] == 200 + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) -def test_validate_http_resolver_with_multi_query_params(): - # GIVEN an APIGatewayHttpResolver with validation enabled - app = APIGatewayHttpResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined with a default scalar parameter and a list - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + # Define handler3 without params + if handler_func == "handler3_without_query_params": + LOAD_GW_EVENT_HTTP["queryStringParameters"] = None - # THEN the handler should be invoked and return 200 + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code result = app(LOAD_GW_EVENT_HTTP, {}) - assert result["statusCode"] == 200 + assert result["statusCode"] == expected_status_code + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) -def test_validate_http_resolver_with_multi_query_values_fail(): - # GIVEN an APIGatewayHttpResolver with validation enabled - app = APIGatewayHttpResolver(enable_validation=True) - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(parameter1: Annotated[List[int], Query()], parameter2: str): - print(parameter2) +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_query_params", 200, None), + ], +) +def test_validation_query_string_with_alb_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a ALBResolver with validation enabled + app = ALBResolver(enable_validation=True) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + LOAD_GW_EVENT_ALB["path"] = "/users" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT_HTTP, {}) - assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) -def test_validate_http_resolver_without_query_params(): - # GIVEN an APIGatewayHttpResolver with validation enabled - app = APIGatewayHttpResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": + + @app.get("/users") + def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + # Define handler3 without params + if handler_func == "handler3_without_query_params": + LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None + + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT_ALB, {}) + assert result["statusCode"] == expected_status_code + + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) + + +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_query_params", 200, None), + ], +) +def test_validation_query_string_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a LambdaFunctionUrlResolver with validation enabled + app = LambdaFunctionUrlResolver(enable_validation=True) + + LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + # WHEN a handler is defined with various parameters and routes + + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + + @app.get("/users") + def handler1(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": + + @app.get("/users") + def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + # Define handler3 without params + if handler_func == "handler3_without_query_params": + LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None + + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + assert result["statusCode"] == expected_status_code + + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) + + +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_query_params", 200, None), + ], +) +def test_validation_query_string_with_vpc_lattice_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a VPCLatticeV2Resolver with validation enabled + app = VPCLatticeV2Resolver(enable_validation=True) + + LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + + # WHEN a handler is defined with various parameters and routes + + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + + @app.get("/users") + def handler1(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": + + @app.get("/users") + def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + # Define handler3 without params + if handler_func == "handler3_without_query_params": + LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None + + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + assert result["statusCode"] == expected_status_code - # WHEN a handler is defined without any query params - @app.get("/users") - def handler(): - return None + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) + + +########### TEST WITH HEADER PARAMS +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + LOAD_GW_EVENT["httpMethod"] = "GET" + LOAD_GW_EVENT["path"] = "/users" + # WHEN a handler is defined with various parameters and routes + + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": + + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT["headers"] = None + LOAD_GW_EVENT["multiValueHeaders"] = None + + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == expected_status_code + + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) + + +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_http_rest_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a APIGatewayHttpResolver with validation enabled + app = APIGatewayHttpResolver(enable_validation=True) LOAD_GW_EVENT_HTTP["rawPath"] = "/users" LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" - LOAD_GW_EVENT_HTTP["queryStringParameters"] = None + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT_HTTP, {}) - assert result["statusCode"] == 200 + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) -def test_validate_alb_resolver_with_multi_query_values(): - # GIVEN an ALBResolver with validation enabled - app = ALBResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined with a default scalar parameter and a list - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) - LOAD_GW_EVENT_ALB["path"] = "/users" + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT_HTTP["headers"] = None - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT_ALB, {}) - assert result["statusCode"] == 200 + @app.get("/users") + def handler3(): + return None + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT_HTTP, {}) + assert result["statusCode"] == expected_status_code -def test_validate_alb_resolver_with_multi_query_values_fail(): - # GIVEN an ALBResolver with validation enabled - app = ALBResolver(enable_validation=True) + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(parameter1: Annotated[List[int], Query()], parameter2: str): - print(parameter2) + +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_alb_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a ALBResolver with validation enabled + app = ALBResolver(enable_validation=True) LOAD_GW_EVENT_ALB["path"] = "/users" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT_ALB, {}) - assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) -def test_validate_alb_resolver_without_query_params(): - # GIVEN an ALBResolver with validation enabled - app = ALBResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined without any query params - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) - LOAD_GW_EVENT_ALB["path"] = "/users" - LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT_ALB["multiValueHeaders"] = None - # THEN the handler should be invoked and return 200 + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code result = app(LOAD_GW_EVENT_ALB, {}) - assert result["statusCode"] == 200 + assert result["statusCode"] == expected_status_code + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) -def test_validate_lambda_url_resolver_with_multi_query_params(): - # GIVEN an LambdaFunctionUrlResolver with validation enabled - app = LambdaFunctionUrlResolver(enable_validation=True) - # WHEN a handler is defined with a default scalar parameter and a list - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a LambdaFunctionUrlResolver with validation enabled + app = LambdaFunctionUrlResolver(enable_validation=True) LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) - assert result["statusCode"] == 200 + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) -def test_validate_lambda_url_resolver_with_multi_query_params_fail(): - # GIVEN an LambdaFunctionUrlResolver with validation enabled - app = LambdaFunctionUrlResolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(parameter1: Annotated[List[int], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT_LAMBDA_URL["headers"] = None - # THEN the handler should be invoked and return 422 + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) - assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + assert result["statusCode"] == expected_status_code + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) -def test_validate_lambda_url_resolver_without_query_params(): - # GIVEN an LambdaFunctionUrlResolver with validation enabled - app = LambdaFunctionUrlResolver(enable_validation=True) - # WHEN a handler is defined without any query params - @app.get("/users") - def handler(): - return None +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a VPCLatticeResolver with validation enabled + app = VPCLatticeResolver(enable_validation=True) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None + LOAD_GW_EVENT_VPC_LATTICE_V1["raw_path"] = "/users" + LOAD_GW_EVENT_VPC_LATTICE_V1["method"] = "GET" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) - assert result["statusCode"] == 200 + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) -def test_validate_vpc_lattice_resolver_with_multi_params_values(): - # GIVEN an VPCLatticeV2Resolver with validation enabled - app = VPCLatticeV2Resolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined with a default scalar parameter and a list - @app.get("/users") - def handler(parameter1: Annotated[List[str], Query()], parameter2: str): - print(parameter2) + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT_VPC_LATTICE_V1["headers"] = None - # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) - assert result["statusCode"] == 200 + @app.get("/users") + def handler3(): + return None + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code + result = app(LOAD_GW_EVENT_VPC_LATTICE_V1, {}) + assert result["statusCode"] == expected_status_code + + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) -def test_validate_vpc_lattice_resolver_with_multi_query_params_fail(): - # GIVEN an VPCLatticeV2Resolver with validation enabled - app = VPCLatticeV2Resolver(enable_validation=True) - # WHEN a handler is defined with a default scalar parameter and a list with wrong type - @app.get("/users") - def handler(parameter1: Annotated[List[int], Query()], parameter2: str): - print(parameter2) +@pytest.mark.parametrize( + "handler_func, expected_status_code, expected_error_text", + [ + ("handler1_with_correct_params", 200, None), + ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), + ("handler3_without_header_params", 200, None), + ], +) +def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_status_code, expected_error_text): + # GIVEN a VPCLatticeV2Resolver with validation enabled + app = VPCLatticeV2Resolver(enable_validation=True) LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + LOAD_GW_EVENT_VPC_LATTICE["method"] = "GET" + # WHEN a handler is defined with various parameters and routes - # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) - assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + # Define handler1 with correct params + if handler_func == "handler1_with_correct_params": + @app.get("/users") + def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) -def test_validate_vpc_lattice_resolver_without_query_params(): - # GIVEN an VPCLatticeV2Resolver with validation enabled - app = VPCLatticeV2Resolver(enable_validation=True) + # Define handler2 with wrong params + if handler_func == "handler2_with_wrong_params": - # WHEN a handler is defined without any query params - @app.get("/users") - def handler(): - return None + @app.get("/users") + def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): + print(Header2) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" - LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None + # Define handler3 without params + if handler_func == "handler3_without_header_params": + LOAD_GW_EVENT_VPC_LATTICE["headers"] = None - # THEN the handler should be invoked and return 200 + @app.get("/users") + def handler3(): + return None + + # THEN the handler should be invoked with the expected result + # AND the status code should match the expected_status_code result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) - assert result["statusCode"] == 200 + assert result["statusCode"] == expected_status_code + + # IF expected_error_text is provided, THEN check for its presence in the response body + if expected_error_text: + assert any(text in result["body"] for text in expected_error_text) From 3c8f0897e5311b8456f1ffce518497a3dd90cc58 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 31 Jan 2024 15:59:12 +0000 Subject: [PATCH 04/13] Making sonarqube happy --- .../events/albMultiValueQueryStringEvent.json | 4 +- tests/events/apiGatewayProxyEvent.json | 8 +-- tests/events/apiGatewayProxyV2Event.json | 4 +- .../lambdaFunctionUrlEventWithHeaders.json | 4 +- tests/events/vpcLatticeEvent.json | 4 +- .../events/vpcLatticeV2EventWithHeaders.json | 4 +- .../test_openapi_validation_middleware.py | 50 +++++++++---------- 7 files changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/events/albMultiValueQueryStringEvent.json b/tests/events/albMultiValueQueryStringEvent.json index 7fb537794b1..d5cdf18f023 100644 --- a/tests/events/albMultiValueQueryStringEvent.json +++ b/tests/events/albMultiValueQueryStringEvent.json @@ -14,11 +14,11 @@ "accept": [ "*/*" ], - "Header2": [ + "header2": [ "value1", "value2" ], - "Header1": [ + "header1": [ "value1" ], "host": [ diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index da814c91100..435b20e1ab1 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -4,15 +4,15 @@ "path": "/my/path", "httpMethod": "GET", "headers": { - "Header1": "value1", - "Header2": "value2", + "header1": "value1", + "header2": "value2", "Origin": "https://aws.amazon.com" }, "multiValueHeaders": { - "Header1": [ + "header1": [ "value1" ], - "Header2": [ + "header2": [ "value1", "value2" ] diff --git a/tests/events/apiGatewayProxyV2Event.json b/tests/events/apiGatewayProxyV2Event.json index 9de632b8e3d..2c46b6de564 100644 --- a/tests/events/apiGatewayProxyV2Event.json +++ b/tests/events/apiGatewayProxyV2Event.json @@ -8,8 +8,8 @@ "cookie2" ], "headers": { - "Header1": "value1", - "Header2": "value1,value2" + "header1": "value1", + "header2": "value1,value2" }, "queryStringParameters": { "parameter1": "value1,value2", diff --git a/tests/events/lambdaFunctionUrlEventWithHeaders.json b/tests/events/lambdaFunctionUrlEventWithHeaders.json index a2b3d9ef147..d1cc50630a8 100644 --- a/tests/events/lambdaFunctionUrlEventWithHeaders.json +++ b/tests/events/lambdaFunctionUrlEventWithHeaders.json @@ -24,8 +24,8 @@ "accept-encoding":"gzip, deflate, br", "sec-fetch-dest":"document", "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36", - "Header1": "value1", - "Header2": "value1,value2" + "header1": "value1", + "header2": "value1,value2" }, "queryStringParameters": { "parameter1": "value1,value2", diff --git a/tests/events/vpcLatticeEvent.json b/tests/events/vpcLatticeEvent.json index b00b9d3d7f3..fa9031f7dc4 100644 --- a/tests/events/vpcLatticeEvent.json +++ b/tests/events/vpcLatticeEvent.json @@ -6,8 +6,8 @@ "x-forwarded-for": "10.213.229.10", "host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws", "accept": "*/*", - "Header1": "value1", - "Header2": "value1,value2" + "header1": "value1", + "header2": "value1,value2" }, "query_string_parameters": { "order-id": "1" diff --git a/tests/events/vpcLatticeV2EventWithHeaders.json b/tests/events/vpcLatticeV2EventWithHeaders.json index fdaf7dc7891..40e5f17cfe6 100644 --- a/tests/events/vpcLatticeV2EventWithHeaders.json +++ b/tests/events/vpcLatticeV2EventWithHeaders.json @@ -19,10 +19,10 @@ "host":[ "lattice-svc-027b423199122da5f.7d67968.vpc-lattice-svcs.us-east-1.on.aws" ], - "Header1": [ + "header1": [ "value1" ], - "Header2": [ + "header2": [ "value1", "value2" ] diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 9c66a24eb34..7e3e27ae1d1 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -13,10 +13,10 @@ APIGatewayRestResolver, LambdaFunctionUrlResolver, Response, + VPCLatticeResolver, VPCLatticeV2Resolver, ) from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query -from aws_lambda_powertools.event_handler.vpc_lattice import VPCLatticeResolver from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event @@ -658,15 +658,15 @@ def test_validation_header_with_api_rest_resolver(handler_func, expected_status_ if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": @@ -708,15 +708,15 @@ def test_validation_header_with_http_rest_resolver(handler_func, expected_status if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": @@ -755,15 +755,15 @@ def test_validation_header_with_alb_resolver(handler_func, expected_status_code, if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": @@ -804,15 +804,15 @@ def test_validation_header_with_lambda_url_resolver(handler_func, expected_statu if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": @@ -852,15 +852,15 @@ def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_s if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": @@ -900,15 +900,15 @@ def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_s if handler_func == "handler1_with_correct_params": @app.get("/users") - def handler1(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler2 with wrong params if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(Header2: Annotated[List[int], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + print(header2) # Define handler3 without params if handler_func == "handler3_without_header_params": From 921a99d88289916141e0f591c76a179fd33e18a7 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 31 Jan 2024 18:10:08 +0000 Subject: [PATCH 05/13] Adding documentation --- docs/core/event_handler/api_gateway.md | 32 +++++++++++++++++ .../src/validating_headers.py | 34 +++++++++++++++++++ .../src/working_with_headers_multi_value.py | 34 +++++++++++++++++++ 3 files changed, 100 insertions(+) create mode 100644 examples/event_handler_rest/src/validating_headers.py create mode 100644 examples/event_handler_rest/src/working_with_headers_multi_value.py diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 86b97c87e4b..0f30369097c 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -424,6 +424,38 @@ For example, we could validate that `` dynamic path should be no greate 1. `Path` is a special OpenAPI type that allows us to constrain todo_id to be less than 999. +#### Validating headers + +We use the `Annotated` type to tell Event Handler that a particular parameter is a header that needs to be validated. + +In the following example, we use a new `Header` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as: + +* `correlation_id` is a header that must be present in the request +* `correlation_id`, when set, should have 16 characters +* Doesn't match? Event Handler will return a validation error response + + + +=== "validating_headers.py" + + ```python hl_lines="8 10 29" + --8<-- "examples/event_handler_rest/src/validating_headers.py" + ``` + + 1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect + 2. `Header` is a special OpenAPI type that can add constraints to a header well as document them + 3. **First time seeing the `Annotated`?**

This special type uses the first argument as the actual type, and subsequent arguments are metadata.

At runtime, static checkers will also see the first argument, but anyone receiving them could inspect them to fetch their metadata. + +=== "working_with_headers_multi_value.py" + + If you need to handle multi-value for specific headers, you can create a list of the desired type. + + ```python hl_lines="23" + --8<-- "examples/event_handler_rest/src/working_with_headers_multi_value.py" + ``` + + 1. `cloudfront_viewer_country` is a list that must contain values from the `CountriesAllowed` enumeration. + ### Accessing request details Event Handler integrates with [Event Source Data Classes utilities](../../utilities/data_classes.md){target="_blank"}, and it exposes their respective resolver request details and convenient methods under `app.current_event`. diff --git a/examples/event_handler_rest/src/validating_headers.py b/examples/event_handler_rest/src/validating_headers.py new file mode 100644 index 00000000000..956fd58b14d --- /dev/null +++ b/examples/event_handler_rest/src/validating_headers.py @@ -0,0 +1,34 @@ +from enum import Enum +from typing import List + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Header +from aws_lambda_powertools.shared.types import Annotated +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayRestResolver(enable_validation=True) + + +class CountriesAllowed(Enum): + """Example of an Enum class.""" + + US = "US" + PT = "PT" + BR = "BR" + + +@app.get("/hello") +def get( + cloudfront_viewer_country: Annotated[ + List[CountriesAllowed], # (1)! + Header( + description="This is multi value header parameter.", + ), + ], +): + """Return validated multi-value header values.""" + return cloudfront_viewer_country + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/working_with_headers_multi_value.py b/examples/event_handler_rest/src/working_with_headers_multi_value.py new file mode 100644 index 00000000000..69f31b60762 --- /dev/null +++ b/examples/event_handler_rest/src/working_with_headers_multi_value.py @@ -0,0 +1,34 @@ +from enum import Enum +from typing import List + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Header +from aws_lambda_powertools.shared.types import Annotated +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayRestResolver(enable_validation=True) + + +class CountriesAllowed(Enum): + """Example of an Enum class.""" + + US = "US" + PT = "PT" + BR = "BR" + + +@app.get("/todos") +def get( + example_headers_multi_value: Annotated[ + List[CountriesAllowed], # (1)! + Header( + description="This is multi value header parameter.", + ), + ], +): + """Return validated multi-value header values.""" + return example_headers_multi_value + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) From 2f1099061bd4e0e1b68bfe9070fdd342bf3fa396 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 02:44:25 +0000 Subject: [PATCH 06/13] Rafactoring to be complaint with RFC --- .../middlewares/openapi_validation.py | 18 ++++++++++++++---- .../utilities/data_classes/alb_event.py | 8 ++++++-- .../data_classes/api_gateway_proxy_event.py | 10 +++++++--- .../utilities/data_classes/common.py | 4 ++++ .../utilities/data_classes/vpc_lattice.py | 7 +++++-- 5 files changed, 36 insertions(+), 11 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 4d8aedfe58e..0a76230647b 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -19,7 +19,7 @@ from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError -from aws_lambda_powertools.event_handler.openapi.params import Param +from aws_lambda_powertools.event_handler.openapi.params import Param, ParamTypes from aws_lambda_powertools.event_handler.openapi.types import IncEx from aws_lambda_powertools.event_handler.types import EventHandlerInstance @@ -256,13 +256,23 @@ def _request_params_to_args( errors = [] for field in required_params: - value = received_params.get(field.alias) - field_info = field.field_info + + # To ensure early failure, we check if it's not an instance of Param. if not isinstance(field_info, Param): raise AssertionError(f"Expected Param field_info, got {field_info}") - loc = (field_info.in_.value, field.alias) + field_alias = field.alias + + if field_info.in_ == ParamTypes.header and field_alias: + # Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name + # This ensures that customers can access headers with any casing, as per the RFC guidelines. + # Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2 + field_alias = field.alias.lower() + + value = received_params.get(field_alias) + + loc = (field_info.in_.value, field_alias) # If we don't have a value, see if it's required or has a default if value is None: diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 6e17146ec37..98f37b4f415 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -44,10 +44,14 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + headers: Dict[str, Any] = {} + if self.multi_value_headers: - return self.multi_value_headers + headers = self.multi_value_headers + else: + headers = self.headers - return self.headers + return {key.lower(): value for key, value in headers.items()} @property def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 26951d61201..c37bd22ca53 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -127,10 +127,14 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: + headers: Dict[str, Any] = {} + if self.multi_value_headers: - return self.multi_value_headers + headers = self.multi_value_headers + else: + headers = self.headers - return self.headers + return {key.lower(): value for key, value in headers.items()} @property def request_context(self) -> APIGatewayEventRequestContext: @@ -327,7 +331,7 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: - headers = {key: value.split(",") if "," in value else value for key, value in self.headers.items()} + headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} return headers return {} diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 3f4f47c5a2a..0560159ecc5 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -122,6 +122,10 @@ def resolved_headers_field(self) -> Optional[Dict[str, Any]]: This is necessary because different resolvers use different formats to encode headers parameters. + + Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the header name + This ensures that customers can access headers with any casing, as per the RFC guidelines. + Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2 """ return self.headers diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index a17b8c96b1c..f12c53d841a 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -148,7 +148,7 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: - headers = {key: value.split(",") if "," in value else value for key, value in self.headers.items()} + headers = {key.lower(): value.split(",") if "," in value else value for key, value in self.headers.items()} return headers return {} @@ -270,4 +270,7 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_headers_field(self) -> Optional[Dict[str, str]]: - return self.headers + if self.headers is not None: + return {key.lower(): value for key, value in self.headers.items()} + + return {} From 5b2bf85557c4b2748877913e56e16ac49b5b49e3 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 08:54:40 +0000 Subject: [PATCH 07/13] Adding tests --- tests/events/apiGatewayProxyEvent.json | 8 ++++---- .../test_openapi_validation_middleware.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index 435b20e1ab1..da814c91100 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -4,15 +4,15 @@ "path": "/my/path", "httpMethod": "GET", "headers": { - "header1": "value1", - "header2": "value2", + "Header1": "value1", + "Header2": "value2", "Origin": "https://aws.amazon.com" }, "multiValueHeaders": { - "header1": [ + "Header1": [ "value1" ], - "header2": [ + "Header2": [ "value1", "value2" ] diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 6b19fd10bd0..39158de9b90 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -669,7 +669,8 @@ def handler3(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): @@ -694,13 +695,20 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT["headers"] = None LOAD_GW_EVENT["multiValueHeaders"] = None @app.get("/users") - def handler3(): + def handler4(): return None # THEN the handler should be invoked with the expected result From a5a0c1d3f8072ef92579a3a8da70ffc49ed49117 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 12:52:06 +0000 Subject: [PATCH 08/13] Adding test with Uppercase variables --- .../middlewares/openapi_validation.py | 5 +- .../test_openapi_validation_middleware.py | 88 ++++++++++++++----- 2 files changed, 67 insertions(+), 26 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 0a76230647b..b3cada8426e 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -419,11 +419,12 @@ def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], """ if headers: for param in filter(is_scalar_field, params): + param_name = param.name.lower() try: - if len(headers[param.name]) == 1: + if len(headers[param_name]) == 1: # if the target parameter is a scalar and the list contains only 1 element # we keep the first value of the headers regardless if there are more in the payload - headers[param.name] = headers[param.name][0] + headers[param_name] = headers[param_name][0] except KeyError: pass return headers diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 39158de9b90..b22137ac0d4 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -692,7 +692,7 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) # Define handler3 with uppercase parameters @@ -726,7 +726,8 @@ def handler4(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_http_rest_resolver(handler_func, expected_status_code, expected_error_text): @@ -749,15 +750,22 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT_HTTP["headers"] = None @app.get("/users") - def handler3(): + def handler4(): return None # THEN the handler should be invoked with the expected result @@ -775,7 +783,8 @@ def handler3(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_alb_resolver(handler_func, expected_status_code, expected_error_text): @@ -796,15 +805,22 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT_ALB["multiValueHeaders"] = None @app.get("/users") - def handler3(): + def handler4(): return None # THEN the handler should be invoked with the expected result @@ -822,7 +838,8 @@ def handler3(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): @@ -845,15 +862,22 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT_LAMBDA_URL["headers"] = None @app.get("/users") - def handler3(): + def handler4(): return None # THEN the handler should be invoked with the expected result @@ -871,7 +895,8 @@ def handler3(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_status_code, expected_error_text): @@ -893,15 +918,22 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He if handler_func == "handler2_with_wrong_params": @app.get("/users") - def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): + def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT_VPC_LATTICE_V1["headers"] = None @app.get("/users") - def handler3(): + def handler4(): return None # THEN the handler should be invoked with the expected result @@ -919,7 +951,8 @@ def handler3(): [ ("handler1_with_correct_params", 200, None), ("handler2_with_wrong_params", 422, "['type_error.integer', 'int_parsing']"), - ("handler3_without_header_params", 200, None), + ("handler3_with_uppercase_params", 200, None), + ("handler4_without_header_params", 200, None), ], ) def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_status_code, expected_error_text): @@ -944,8 +977,15 @@ def handler1(header2: Annotated[List[str], Header()], header1: Annotated[str, He def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, Header()]): print(header2) - # Define handler3 without params - if handler_func == "handler3_without_header_params": + # Define handler3 with uppercase parameters + if handler_func == "handler3_with_uppercase_params": + + @app.get("/users") + def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): + print(Header2) + + # Define handler4 without params + if handler_func == "handler4_without_header_params": LOAD_GW_EVENT_VPC_LATTICE["headers"] = None @app.get("/users") From 0ba69f4d07554383d5da325551dd24b1dbb061cd Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 12:56:05 +0000 Subject: [PATCH 09/13] Revert event changes --- tests/events/apiGatewayProxyV2Event.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/events/apiGatewayProxyV2Event.json b/tests/events/apiGatewayProxyV2Event.json index 2c46b6de564..9de632b8e3d 100644 --- a/tests/events/apiGatewayProxyV2Event.json +++ b/tests/events/apiGatewayProxyV2Event.json @@ -8,8 +8,8 @@ "cookie2" ], "headers": { - "header1": "value1", - "header2": "value1,value2" + "Header1": "value1", + "Header2": "value1,value2" }, "queryStringParameters": { "parameter1": "value1,value2", From 5917ddfdd6ebc9c0977d8a7d4d50921f98f4ce83 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 13:07:23 +0000 Subject: [PATCH 10/13] Adding HTTP RFC --- docs/core/event_handler/api_gateway.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 0f30369097c..08a73b5c20a 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -428,6 +428,8 @@ For example, we could validate that `` dynamic path should be no greate We use the `Annotated` type to tell Event Handler that a particular parameter is a header that needs to be validated. +!!! info "We adhere to HTTP RFC standards, which means we treat HTTP headers as case-insensitive." + In the following example, we use a new `Header` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as: * `correlation_id` is a header that must be present in the request From 74dfe4baa94e82dbaae9f446ee1c428888686271 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 14:49:48 +0000 Subject: [PATCH 11/13] Adding getter/setter to clean the code --- .../middlewares/openapi_validation.py | 19 +++------ .../event_handler/openapi/params.py | 18 +++++++- .../events/vpcLatticeV2EventWithHeaders.json | 4 +- .../test_openapi_validation_middleware.py | 42 +++++++++++++------ 4 files changed, 54 insertions(+), 29 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index b3cada8426e..54c48189282 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -19,7 +19,7 @@ from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError -from aws_lambda_powertools.event_handler.openapi.params import Param, ParamTypes +from aws_lambda_powertools.event_handler.openapi.params import Param from aws_lambda_powertools.event_handler.openapi.types import IncEx from aws_lambda_powertools.event_handler.types import EventHandlerInstance @@ -262,17 +262,9 @@ def _request_params_to_args( if not isinstance(field_info, Param): raise AssertionError(f"Expected Param field_info, got {field_info}") - field_alias = field.alias + value = received_params.get(field.alias) - if field_info.in_ == ParamTypes.header and field_alias: - # Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name - # This ensures that customers can access headers with any casing, as per the RFC guidelines. - # Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2 - field_alias = field.alias.lower() - - value = received_params.get(field_alias) - - loc = (field_info.in_.value, field_alias) + loc = (field_info.in_.value, field.alias) # If we don't have a value, see if it's required or has a default if value is None: @@ -419,12 +411,11 @@ def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], """ if headers: for param in filter(is_scalar_field, params): - param_name = param.name.lower() try: - if len(headers[param_name]) == 1: + if len(headers[param.alias]) == 1: # if the target parameter is a scalar and the list contains only 1 element # we keep the first value of the headers regardless if there are more in the payload - headers[param_name] = headers[param_name][0] + headers[param.alias] = headers[param.alias][0] except KeyError: pass return headers diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index ce7a1b8520e..7520c2b39a7 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -589,11 +589,13 @@ def __init__( Extra values to include in the generated OpenAPI schema """ self.convert_underscores = convert_underscores + self._alias = alias + super().__init__( default=default, default_factory=default_factory, annotation=annotation, - alias=alias, + alias=self._alias, alias_priority=alias_priority, validation_alias=validation_alias, serialization_alias=serialization_alias, @@ -619,6 +621,20 @@ def __init__( **extra, ) + @property + def alias(self): + return self._alias + + @alias.setter + def alias(self, value: Optional[str] = None): + if value is not None: + # Headers are case-insensitive according to RFC 7540 (HTTP/2), so we lower the parameter name + # This ensures that customers can access headers with any casing, as per the RFC guidelines. + # Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2 + self._alias = value.lower() + else: + self._alias = None + class Body(FieldInfo): """ diff --git a/tests/events/vpcLatticeV2EventWithHeaders.json b/tests/events/vpcLatticeV2EventWithHeaders.json index 40e5f17cfe6..fdaf7dc7891 100644 --- a/tests/events/vpcLatticeV2EventWithHeaders.json +++ b/tests/events/vpcLatticeV2EventWithHeaders.json @@ -19,10 +19,10 @@ "host":[ "lattice-svc-027b423199122da5f.7d67968.vpc-lattice-svcs.us-east-1.on.aws" ], - "header1": [ + "Header1": [ "value1" ], - "header2": [ + "Header2": [ "value1", "value2" ] diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index b22137ac0d4..07e2a34ac42 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -699,8 +699,11 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": @@ -757,8 +760,11 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": @@ -812,8 +818,11 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": @@ -869,8 +878,11 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": @@ -925,8 +937,11 @@ def handler2(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": @@ -981,8 +996,11 @@ def handler1(header2: Annotated[List[int], Header()], header1: Annotated[str, He if handler_func == "handler3_with_uppercase_params": @app.get("/users") - def handler3(Header2: Annotated[List[str], Header()], Header1: Annotated[str, Header()]): - print(Header2) + def handler3( + header2: Annotated[List[str], Header(name="Header2")], + header1: Annotated[str, Header(name="Header1")], + ): + print(header2) # Define handler4 without params if handler_func == "handler4_without_header_params": From 0f60963ea6923685123c5b56b0c9def5de1180eb Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 14:57:24 +0000 Subject: [PATCH 12/13] Adding getter/setter to clean the code --- aws_lambda_powertools/event_handler/openapi/params.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 7520c2b39a7..d5665a48d30 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -632,8 +632,6 @@ def alias(self, value: Optional[str] = None): # This ensures that customers can access headers with any casing, as per the RFC guidelines. # Reference: https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2 self._alias = value.lower() - else: - self._alias = None class Body(FieldInfo): From 25a183de2bf31387b2acd0bb5b238aafcb75f0e0 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Thu, 1 Feb 2024 16:30:09 +0000 Subject: [PATCH 13/13] Addressing Ruben's feedback --- docs/core/event_handler/api_gateway.md | 22 ++++----- .../src/validating_headers.py | 45 ++++++++++--------- .../src/working_with_headers_multi_value.py | 6 +-- 3 files changed, 39 insertions(+), 34 deletions(-) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 08a73b5c20a..32631ac867e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -368,13 +368,13 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou !!! info "We will automatically validate and inject incoming query strings via type annotation." -We use the `Annotated` type to tell Event Handler that a particular parameter is not only an optional string, but also a query string with constraints. +We use the `Annotated` type to tell the Event Handler that a particular parameter is not only an optional string, but also a query string with constraints. In the following example, we use a new `Query` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as: * `completed` is a query string with a `None` as its default value * `completed`, when set, should have at minimum 4 characters -* Doesn't match? Event Handler will return a validation error response +* No match? Event Handler will return a validation error response @@ -386,7 +386,7 @@ In the following example, we use a new `Query` OpenAPI type to add [one out of m 1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect 2. `Query` is a special OpenAPI type that can add constraints to a query string as well as document them - 3. **First time seeing the `Annotated`?**

This special type uses the first argument as the actual type, and subsequent arguments are metadata.

At runtime, static checkers will also see the first argument, but anyone receiving them could inspect them to fetch their metadata. + 3. **First time seeing `Annotated`?**

This special type uses the first argument as the actual type, and subsequent arguments as metadata.

At runtime, static checkers will also see the first argument, but any receiver can inspect it to get the metadata. === "skip_validating_query_strings.py" @@ -426,31 +426,31 @@ For example, we could validate that `` dynamic path should be no greate #### Validating headers -We use the `Annotated` type to tell Event Handler that a particular parameter is a header that needs to be validated. +We use the `Annotated` type to tell the Event Handler that a particular parameter is a header that needs to be validated. -!!! info "We adhere to HTTP RFC standards, which means we treat HTTP headers as case-insensitive." +!!! info "We adhere to [HTTP RFC standards](https://www.rfc-editor.org/rfc/rfc7540#section-8.1.2){target="_blank" rel="nofollow"}, which means we treat HTTP headers as case-insensitive." In the following example, we use a new `Header` OpenAPI type to add [one out of many possible constraints](#customizing-openapi-parameters), which should read as: * `correlation_id` is a header that must be present in the request -* `correlation_id`, when set, should have 16 characters -* Doesn't match? Event Handler will return a validation error response +* `correlation_id` should have 16 characters +* No match? Event Handler will return a validation error response === "validating_headers.py" - ```python hl_lines="8 10 29" + ```python hl_lines="8 10 27" --8<-- "examples/event_handler_rest/src/validating_headers.py" ``` 1. If you're not using Python 3.9 or higher, you can install and use [`typing_extensions`](https://pypi.org/project/typing-extensions/){target="_blank" rel="nofollow"} to the same effect - 2. `Header` is a special OpenAPI type that can add constraints to a header well as document them - 3. **First time seeing the `Annotated`?**

This special type uses the first argument as the actual type, and subsequent arguments are metadata.

At runtime, static checkers will also see the first argument, but anyone receiving them could inspect them to fetch their metadata. + 2. `Header` is a special OpenAPI type that can add constraints and documentation to a header + 3. **First time seeing `Annotated`?**

This special type uses the first argument as the actual type, and subsequent arguments as metadata.

At runtime, static checkers will also see the first argument, but any receiver can inspect it to get the metadata. === "working_with_headers_multi_value.py" - If you need to handle multi-value for specific headers, you can create a list of the desired type. + You can handle multi-value headers by declaring it as a list of the desired type. ```python hl_lines="23" --8<-- "examples/event_handler_rest/src/working_with_headers_multi_value.py" diff --git a/examples/event_handler_rest/src/validating_headers.py b/examples/event_handler_rest/src/validating_headers.py index 956fd58b14d..e830a49c38c 100644 --- a/examples/event_handler_rest/src/validating_headers.py +++ b/examples/event_handler_rest/src/validating_headers.py @@ -1,34 +1,39 @@ -from enum import Enum -from typing import List +from typing import List, Optional +import requests +from pydantic import BaseModel, Field + +from aws_lambda_powertools import Logger, Tracer from aws_lambda_powertools.event_handler import APIGatewayRestResolver -from aws_lambda_powertools.event_handler.openapi.params import Header -from aws_lambda_powertools.shared.types import Annotated +from aws_lambda_powertools.event_handler.openapi.params import Header # (2)! +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.shared.types import Annotated # (1)! from aws_lambda_powertools.utilities.typing import LambdaContext +tracer = Tracer() +logger = Logger() app = APIGatewayRestResolver(enable_validation=True) -class CountriesAllowed(Enum): - """Example of an Enum class.""" +class Todo(BaseModel): + userId: int + id_: Optional[int] = Field(alias="id", default=None) + title: str + completed: bool + - US = "US" - PT = "PT" - BR = "BR" +@app.get("/todos") +@tracer.capture_method +def get_todos(correlation_id: Annotated[str, Header(min_length=16, max_length=16)]) -> List[Todo]: # (3)! + url = "https://jsonplaceholder.typicode.com/todos" + todo = requests.get(url, headers={"correlation_id": correlation_id}) + todo.raise_for_status() -@app.get("/hello") -def get( - cloudfront_viewer_country: Annotated[ - List[CountriesAllowed], # (1)! - Header( - description="This is multi value header parameter.", - ), - ], -): - """Return validated multi-value header values.""" - return cloudfront_viewer_country + return todo.json() +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP) +@tracer.capture_lambda_handler def lambda_handler(event: dict, context: LambdaContext) -> dict: return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/working_with_headers_multi_value.py b/examples/event_handler_rest/src/working_with_headers_multi_value.py index 69f31b60762..956fd58b14d 100644 --- a/examples/event_handler_rest/src/working_with_headers_multi_value.py +++ b/examples/event_handler_rest/src/working_with_headers_multi_value.py @@ -17,9 +17,9 @@ class CountriesAllowed(Enum): BR = "BR" -@app.get("/todos") +@app.get("/hello") def get( - example_headers_multi_value: Annotated[ + cloudfront_viewer_country: Annotated[ List[CountriesAllowed], # (1)! Header( description="This is multi value header parameter.", @@ -27,7 +27,7 @@ def get( ], ): """Return validated multi-value header values.""" - return example_headers_multi_value + return cloudfront_viewer_country def lambda_handler(event: dict, context: LambdaContext) -> dict: