From a6f9e44cdba581be5ec5588fb57e8be5a42d9f70 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 12:45:05 +0100 Subject: [PATCH 1/5] fix(event-handler): handle single parameters correctly --- .../middlewares/openapi_validation.py | 14 +- .../utilities/data_classes/alb_event.py | 8 +- .../data_classes/api_gateway_proxy_event.py | 16 +- .../data_classes/bedrock_agent_event.py | 8 +- .../utilities/data_classes/common.py | 14 +- .../utilities/data_classes/vpc_lattice.py | 25 +- .../test_openapi_validation_middleware.py | 258 +++++++++++------- 7 files changed, 209 insertions(+), 134 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 25ac97ddf89..14b80c83c87 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -368,7 +368,10 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]): +def _normalize_multi_query_string_with_param( + query_string: Optional[Dict[str, List[str]]], + params: Sequence[ModelField], +): """ Extract and normalize resolved_query_string_parameters @@ -383,15 +386,18 @@ def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, st ------- A dictionary containing the processed multi_query_string_parameters. """ - if query_string: + if not query_string: + return None + else: + resolved_query_string: Dict[str, Any] = query_string for param in filter(is_scalar_field, params): try: # if the target parameter is a scalar, we keep the first value of the query string # regardless if there are more in the payload - query_string[param.alias] = query_string[param.alias][0] + resolved_query_string[param.alias] = query_string[param.alias][0] except KeyError: pass - return query_string + return resolved_query_string def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]): diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 98f37b4f415..7c5ae092d97 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -36,11 +36,15 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - return self.query_string_parameters + if self.query_string_parameters: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index c37bd22ca53..2837ff1f5fd 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -119,11 +119,15 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - return self.query_string_parameters + if self.query_string_parameters: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -319,14 +323,12 @@ def header_serializer(self): return HttpApiHeadersSerializer() @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: if self.query_string_parameters is not None: - query_string = { - key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items() - } + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} return query_string - return {} + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 0fa97036a3e..91c4f4aee76 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -110,8 +110,12 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + if self.query_string_parameters is not None: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 25fb5a4c170..f2c05bc65c8 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -104,7 +104,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: """ This property determines the appropriate query string parameter to be used as a trusted source for validating OpenAPI. @@ -112,7 +112,11 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: This is necessary because different resolvers use different formats to encode multi query string parameters. """ - return self.query_string_parameters + if self.query_string_parameters is not None: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -186,8 +190,7 @@ def get_header_value( name: str, default_value: str, case_sensitive: Optional[bool] = False, - ) -> str: - ... + ) -> str: ... @overload def get_header_value( @@ -195,8 +198,7 @@ def get_header_value( name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False, - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def get_header_value( self, diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 15144e41d7d..b8cf9690123 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, Optional, overload +from typing import Any, Dict, List, Optional, overload from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -73,8 +73,7 @@ def get_header_value( name: str, default_value: str, case_sensitive: Optional[bool] = False, - ) -> str: - ... + ) -> str: ... @overload def get_header_value( @@ -82,8 +81,7 @@ def get_header_value( name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False, - ) -> Optional[str]: - ... + ) -> Optional[str]: ... def get_header_value( self, @@ -141,8 +139,11 @@ def query_string_parameters(self) -> Dict[str, str]: return self["query_string_parameters"] @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + if self.query_string_parameters is not None: + query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} + return query_string + return None @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -264,8 +265,14 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.query_string_parameters + def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + if self.query_string_parameters is not None: + query_string = { + key: value.split(",") if not isinstance(value, list) else value + for key, value in self.query_string_parameters.items() + } + return query_string + return None @property def resolved_headers_field(self) -> Optional[Dict[str, str]]: diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index be3a13dd656..6bbc1103988 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,4 +1,5 @@ import json +from copy import deepcopy from dataclasses import dataclass from enum import Enum from pathlib import PurePath @@ -40,17 +41,18 @@ def handler(user_id: int): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) @@ -65,17 +67,18 @@ def handler(user_id: int = 123): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) @@ -90,17 +93,18 @@ def handler(user_id: int = 123, include_extra: bool = False): print(user_id) # sending a number - LOAD_GW_EVENT["path"] = "/users/123" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 # sending a string - LOAD_GW_EVENT["path"] = "/users/abc" + event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) @@ -114,11 +118,12 @@ def test_validate_return_type(): def handler() -> int: return 123 - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be 123 - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert result["body"] == "123" @@ -132,11 +137,12 @@ def test_validate_return_list(): def handler() -> List[int]: return [123, 234] - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be [123, 234] - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [123, 234] @@ -152,11 +158,12 @@ def test_validate_return_tuple(): def handler() -> Tuple: return sample_tuple - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a tuple - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [1, 2, 3] @@ -173,11 +180,12 @@ def test_validate_return_purepath(): def handler() -> str: return sample_path.as_posix() - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert result["body"] == sample_path.as_posix() @@ -194,11 +202,12 @@ class Model(Enum): def handler() -> Model: return Model.name.value - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert result["body"] == "powertools" @@ -217,11 +226,12 @@ class Model: def handler() -> Model: return Model(name="John", age=30) - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} @@ -239,11 +249,12 @@ class Model(BaseModel): def handler() -> Model: return Model(name="John", age=30) - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} @@ -261,11 +272,12 @@ class Model(BaseModel): def handler() -> Model: return {"name": "John"} # type: ignore - LOAD_GW_EVENT["path"] = "/" + event = deepcopy(LOAD_GW_EVENT) + event["path"] = "/" # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] @@ -283,13 +295,14 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["path"] = "/" + event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} @@ -308,14 +321,15 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["headers"] = {"Content-type": " application/json "} - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["headers"] = {"Content-type": " application/json "} + event["path"] = "/" + event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} @@ -333,13 +347,14 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = "{" # invalid JSON + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["path"] = "/" + event["body"] = "{" # invalid JSON # THEN the handler should be invoked and return 422 # THEN the body must have the "json_invalid" error message - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert "json_invalid" in result["body"] @@ -357,20 +372,21 @@ class Model(BaseModel): def handler(user: Annotated[Model, Body(embed=True)]) -> Model: return user - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["path"] = "/" + event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] # THEN the handler should be invoked and return 200 # THEN the body must be a dict - LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) - result = app(LOAD_GW_EVENT, {}) + event["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(event, {}) assert result["statusCode"] == 200 @@ -387,13 +403,14 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200, content_type="application/json") - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["path"] = "/" + event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a dict - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} @@ -411,13 +428,14 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200) - LOAD_GW_EVENT["httpMethod"] = "POST" - LOAD_GW_EVENT["path"] = "/" - LOAD_GW_EVENT["body"] = json.dumps({}) + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "POST" + event["path"] = "/" + event["body"] = json.dumps({}) # THEN the handler should be invoked and return 422 # THEN the body should have the word missing - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] @@ -435,8 +453,9 @@ def test_validation_query_string_with_api_rest_resolver(handler_func, expected_s # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "GET" + event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -455,8 +474,8 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT["queryStringParameters"] = None - LOAD_GW_EVENT["multiValueQueryStringParameters"] = None + event["queryStringParameters"] = None + event["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -464,7 +483,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -484,9 +503,10 @@ def test_validation_query_string_with_api_http_resolver(handler_func, expected_s # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_HTTP) + event["rawPath"] = "/users" + event["requestContext"]["http"]["method"] = "GET" + event["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -505,7 +525,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_HTTP["queryStringParameters"] = None + event["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -513,7 +533,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_HTTP, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -533,7 +553,8 @@ def test_validation_query_string_with_alb_resolver(handler_func, expected_status # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - LOAD_GW_EVENT_ALB["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_ALB) + event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -552,7 +573,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None + event["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -560,7 +581,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_ALB, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -580,9 +601,10 @@ def test_validation_query_string_with_lambda_url_resolver(handler_func, expected # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_LAMBDA_URL) + event["rawPath"] = "/users" + event["requestContext"]["http"]["method"] = "GET" + event["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -601,7 +623,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None + event["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -609,7 +631,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -629,7 +651,8 @@ def test_validation_query_string_with_vpc_lattice_resolver(handler_func, expecte # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE) + event["path"] = "/users" # WHEN a handler is defined with various parameters and routes @@ -649,7 +672,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None + event["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -657,7 +680,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -679,8 +702,9 @@ def test_validation_header_with_api_rest_resolver(handler_func, expected_status_ # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - LOAD_GW_EVENT["httpMethod"] = "GET" - LOAD_GW_EVENT["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT) + event["httpMethod"] = "GET" + event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -709,8 +733,8 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT["headers"] = None - LOAD_GW_EVENT["multiValueHeaders"] = None + event["headers"] = None + event["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -718,7 +742,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -739,9 +763,10 @@ def test_validation_header_with_http_rest_resolver(handler_func, expected_status # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - LOAD_GW_EVENT_HTTP["rawPath"] = "/users" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_HTTP) + event["rawPath"] = "/users" + event["requestContext"]["http"]["method"] = "GET" + event["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -770,7 +795,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_HTTP["headers"] = None + event["headers"] = None @app.get("/users") def handler4(): @@ -778,7 +803,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_HTTP, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -799,7 +824,8 @@ def test_validation_header_with_alb_resolver(handler_func, expected_status_code, # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - LOAD_GW_EVENT_ALB["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_ALB) + event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -828,7 +854,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_ALB["multiValueHeaders"] = None + event["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -836,7 +862,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_ALB, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -857,9 +883,10 @@ def test_validation_header_with_lambda_url_resolver(handler_func, expected_statu # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" - LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + event = deepcopy(LOAD_GW_EVENT_LAMBDA_URL) + event["rawPath"] = "/users" + event["requestContext"]["http"]["method"] = "GET" + event["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -888,7 +915,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_LAMBDA_URL["headers"] = None + event["headers"] = None @app.get("/users") def handler4(): @@ -896,7 +923,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -917,8 +944,9 @@ def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_s # GIVEN a VPCLatticeResolver with validation enabled app = VPCLatticeResolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE_V1["raw_path"] = "/users" - LOAD_GW_EVENT_VPC_LATTICE_V1["method"] = "GET" + event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE_V1) + event["raw_path"] = "/users" + event["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -947,7 +975,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_VPC_LATTICE_V1["headers"] = None + event["headers"] = None @app.get("/users") def handler4(): @@ -955,7 +983,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE_V1, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -976,8 +1004,9 @@ def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_s # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" - LOAD_GW_EVENT_VPC_LATTICE["method"] = "GET" + event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE) + event["path"] = "/users" + event["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -1006,7 +1035,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - LOAD_GW_EVENT_VPC_LATTICE["headers"] = None + event["headers"] = None @app.get("/users") def handler3(): @@ -1014,7 +1043,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + result = app(event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -1023,9 +1052,9 @@ def handler3(): def test_validation_with_alias(): - # GIVEN a Http API V2 proxy type event + # GIVEN a REST API V2 proxy type event app = APIGatewayRestResolver(enable_validation=True) - event = load_event("apiGatewayProxyEvent.json") + event = deepcopy(LOAD_GW_EVENT) class FunkyTown(BaseModel): parameter: str @@ -1040,3 +1069,24 @@ def my_path( result = app(event, {}) assert result["statusCode"] == 200 + + +def test_validation_with_http_single_param(): + # GIVEN a HTTP API V2 proxy type event + app = APIGatewayHttpResolver(enable_validation=True) + event = deepcopy(LOAD_GW_EVENT_HTTP) + + class FunkyTown(BaseModel): + parameter: str + + # WHEN a handler is defined with a single parameter + @app.post("/my/path") + def my_path( + parameter2: str, + ) -> Response[FunkyTown]: + assert parameter2 == "value" + return Response(200, content_types.APPLICATION_JSON, FunkyTown(parameter=parameter2)) + + # THEN the handler should be invoked and return 200 + result = app(event, {}) + assert result["statusCode"] == 200 From 1a3a73b5285b1b515d5f3f9cd68a6d0b95c44ae3 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 12:50:33 +0100 Subject: [PATCH 2/5] fix: improve typing --- .../event_handler/middlewares/openapi_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 14b80c83c87..944b3bb14b2 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -371,7 +371,7 @@ def _get_embed_body( def _normalize_multi_query_string_with_param( query_string: Optional[Dict[str, List[str]]], params: Sequence[ModelField], -): +) -> Optional[Dict[str, Any]]: """ Extract and normalize resolved_query_string_parameters From c2815252cc1e6f8f6763460084caca9f48bd721a Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 13:43:44 +0100 Subject: [PATCH 3/5] fix: mypy --- .../event_handler/middlewares/openapi_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 944b3bb14b2..f1975454a92 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -371,7 +371,7 @@ def _get_embed_body( def _normalize_multi_query_string_with_param( query_string: Optional[Dict[str, List[str]]], params: Sequence[ModelField], -) -> Optional[Dict[str, Any]]: +) -> Dict[str, Any]: """ Extract and normalize resolved_query_string_parameters @@ -387,7 +387,7 @@ def _normalize_multi_query_string_with_param( A dictionary containing the processed multi_query_string_parameters. """ if not query_string: - return None + return {} else: resolved_query_string: Dict[str, Any] = query_string for param in filter(is_scalar_field, params): From fe3252241b4f33f8d3b442ce4f12194bd9e262ee Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 15:46:08 +0100 Subject: [PATCH 4/5] chore: refactored multi query code --- .../middlewares/openapi_validation.py | 23 ++++++-------- .../utilities/data_classes/alb_event.py | 8 ++--- .../data_classes/api_gateway_proxy_event.py | 16 ++-------- .../data_classes/bedrock_agent_event.py | 8 ----- .../utilities/data_classes/common.py | 4 +-- .../utilities/data_classes/vpc_lattice.py | 31 +++++++------------ 6 files changed, 27 insertions(+), 63 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index f1975454a92..241a9972953 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -369,7 +369,7 @@ def _get_embed_body( def _normalize_multi_query_string_with_param( - query_string: Optional[Dict[str, List[str]]], + query_string: Dict[str, List[str]], params: Sequence[ModelField], ) -> Dict[str, Any]: """ @@ -386,18 +386,15 @@ def _normalize_multi_query_string_with_param( ------- A dictionary containing the processed multi_query_string_parameters. """ - if not query_string: - return {} - else: - resolved_query_string: Dict[str, Any] = query_string - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - resolved_query_string[param.alias] = query_string[param.alias][0] - except KeyError: - pass - return resolved_query_string + resolved_query_string: Dict[str, Any] = query_string + for param in filter(is_scalar_field, params): + try: + # if the target parameter is a scalar, we keep the first value of the query string + # regardless if there are more in the payload + resolved_query_string[param.alias] = query_string[param.alias][0] + except KeyError: + pass + return resolved_query_string def _normalize_multi_header_values_with_param(headers: Optional[Dict[str, str]], params: Sequence[ModelField]): diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 7c5ae092d97..1ec2535850b 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -36,15 +36,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - if self.query_string_parameters: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return None + return super().resolved_query_string_parameters @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 2837ff1f5fd..ff24e908d1a 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -119,15 +119,11 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters - if self.query_string_parameters: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return None + return super().resolved_query_string_parameters @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: @@ -322,14 +318,6 @@ def http_method(self) -> str: def header_serializer(self): return HttpApiHeadersSerializer() - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: - if self.query_string_parameters is not None: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return None - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 91c4f4aee76..399c435b3ec 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -109,14 +109,6 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: # together with the other parameters. So we just return all parameters here. return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: - if self.query_string_parameters is not None: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - - return None - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: return {} diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index f2c05bc65c8..067706140fd 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -104,7 +104,7 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + def resolved_query_string_parameters(self) -> Dict[str, List[str]]: """ This property determines the appropriate query string parameter to be used as a trusted source for validating OpenAPI. @@ -116,7 +116,7 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} return query_string - return None + return {} @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index b8cf9690123..f997d4b3f04 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Any, Dict, List, Optional, overload +from typing import Any, Dict, Optional, overload from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -138,13 +138,6 @@ def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters.""" return self["query_string_parameters"] - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: - if self.query_string_parameters is not None: - query_string = {key: value.split(",") for key, value in self.query_string_parameters.items()} - return query_string - return None - @property def resolved_headers_field(self) -> Optional[Dict[str, Any]]: if self.headers is not None: @@ -256,23 +249,21 @@ def path(self) -> str: @property def request_context(self) -> vpcLatticeEventV2RequestContext: - """he VPC Lattice v2 Event request context.""" + """The VPC Lattice v2 Event request context.""" return vpcLatticeEventV2RequestContext(self["requestContext"]) @property def query_string_parameters(self) -> Optional[Dict[str, str]]: - """The request query string parameters.""" - return self.get("queryStringParameters") + """The request query string parameters. - @property - def resolved_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: - if self.query_string_parameters is not None: - query_string = { - key: value.split(",") if not isinstance(value, list) else value - for key, value in self.query_string_parameters.items() - } - return query_string - return None + For VPC Lattice V2, the queryStringParameters will contain a Dict[str, List[str]] + so to keep compatibility with existing utilities, we merge all the values with a comma. + """ + params = self.get("queryStringParameters") + if params: + return {key: ",".join(value) for key, value in params.items()} + else: + return None @property def resolved_headers_field(self) -> Optional[Dict[str, str]]: From 6303bdc24f94d7762ec5e268b4d0eed12a191646 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Mon, 19 Feb 2024 16:07:37 +0100 Subject: [PATCH 5/5] chore: refactored load event --- tests/functional/event_handler/conftest.py | 32 ++ .../test_openapi_validation_middleware.py | 388 +++++++++--------- 2 files changed, 235 insertions(+), 185 deletions(-) diff --git a/tests/functional/event_handler/conftest.py b/tests/functional/event_handler/conftest.py index c7a4ac6e500..5c2bdb7729a 100644 --- a/tests/functional/event_handler/conftest.py +++ b/tests/functional/event_handler/conftest.py @@ -2,6 +2,8 @@ import pytest +from tests.functional.utils import load_event + @pytest.fixture def json_dump(): @@ -39,3 +41,33 @@ def validation_schema(): @pytest.fixture def raw_event(): return {"message": "hello hello", "username": "blah blah"} + + +@pytest.fixture +def gw_event(): + return load_event("apiGatewayProxyEvent.json") + + +@pytest.fixture +def gw_event_http(): + return load_event("apiGatewayProxyV2Event.json") + + +@pytest.fixture +def gw_event_alb(): + return load_event("albMultiValueQueryStringEvent.json") + + +@pytest.fixture +def gw_event_lambda_url(): + return load_event("lambdaFunctionUrlEventWithHeaders.json") + + +@pytest.fixture +def gw_event_vpc_lattice(): + return load_event("vpcLatticeV2EventWithHeaders.json") + + +@pytest.fixture +def gw_event_vpc_lattice_v1(): + return load_event("vpcLatticeEvent.json") diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 6bbc1103988..a9396644b98 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -1,5 +1,4 @@ import json -from copy import deepcopy from dataclasses import dataclass from enum import Enum from pathlib import PurePath @@ -16,22 +15,12 @@ Response, VPCLatticeResolver, VPCLatticeV2Resolver, - content_types, ) from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query from aws_lambda_powertools.shared.types import Annotated -from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent -from tests.functional.utils import load_event -LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") -LOAD_GW_EVENT_HTTP = load_event("apiGatewayProxyV2Event.json") -LOAD_GW_EVENT_ALB = load_event("albMultiValueQueryStringEvent.json") -LOAD_GW_EVENT_LAMBDA_URL = load_event("lambdaFunctionUrlEventWithHeaders.json") -LOAD_GW_EVENT_VPC_LATTICE = load_event("vpcLatticeV2EventWithHeaders.json") -LOAD_GW_EVENT_VPC_LATTICE_V1 = load_event("vpcLatticeEvent.json") - -def test_validate_scalars(): +def test_validate_scalars(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -41,23 +30,22 @@ def handler(user_id: int): print(user_id) # sending a number - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - event["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_scalars_with_default(): +def test_validate_scalars_with_default(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -67,23 +55,22 @@ def handler(user_id: int = 123): print(user_id) # sending a number - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - event["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_scalars_with_default_and_optional(): +def test_validate_scalars_with_default_and_optional(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -93,23 +80,22 @@ def handler(user_id: int = 123, include_extra: bool = False): print(user_id) # sending a number - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/users/123" + gw_event["path"] = "/users/123" # THEN the handler should be invoked and return 200 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 # sending a string - event["path"] = "/users/abc" + gw_event["path"] = "/users/abc" # THEN the handler should be invoked and return 422 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) -def test_validate_return_type(): +def test_validate_return_type(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -118,17 +104,16 @@ def test_validate_return_type(): def handler() -> int: return 123 - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be 123 - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == "123" -def test_validate_return_list(): +def test_validate_return_list(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -137,17 +122,16 @@ def test_validate_return_list(): def handler() -> List[int]: return [123, 234] - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be [123, 234] - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [123, 234] -def test_validate_return_tuple(): +def test_validate_return_tuple(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -158,17 +142,16 @@ def test_validate_return_tuple(): def handler() -> Tuple: return sample_tuple - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a tuple - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == [1, 2, 3] -def test_validate_return_purepath(): +def test_validate_return_purepath(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -180,17 +163,16 @@ def test_validate_return_purepath(): def handler() -> str: return sample_path.as_posix() - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == sample_path.as_posix() -def test_validate_return_enum(): +def test_validate_return_enum(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -202,17 +184,16 @@ class Model(Enum): def handler() -> Model: return Model.name.value - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a string - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert result["body"] == "powertools" -def test_validate_return_dataclass(): +def test_validate_return_dataclass(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -226,17 +207,16 @@ class Model: def handler() -> Model: return Model(name="John", age=30) - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_return_model(): +def test_validate_return_model(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -249,17 +229,16 @@ class Model(BaseModel): def handler() -> Model: return Model(name="John", age=30) - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_invalid_return_model(): +def test_validate_invalid_return_model(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -272,17 +251,16 @@ class Model(BaseModel): def handler() -> Model: return {"name": "John"} # type: ignore - event = deepcopy(LOAD_GW_EVENT) - event["path"] = "/" + gw_event["path"] = "/" # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] -def test_validate_body_param(): +def test_validate_body_param(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -295,19 +273,18 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["path"] = "/" - event["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_body_param_with_stripped_headers(): +def test_validate_body_param_with_stripped_headers(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -321,20 +298,19 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["headers"] = {"Content-type": " application/json "} - event["path"] = "/" - event["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["headers"] = {"Content-type": " application/json "} + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a JSON object - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_body_param_with_invalid_date(): +def test_validate_body_param_with_invalid_date(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -347,19 +323,18 @@ class Model(BaseModel): def handler(user: Model) -> Model: return user - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["path"] = "/" - event["body"] = "{" # invalid JSON + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = "{" # invalid JSON # THEN the handler should be invoked and return 422 # THEN the body must have the "json_invalid" error message - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "json_invalid" in result["body"] -def test_validate_embed_body_param(): +def test_validate_embed_body_param(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -372,25 +347,24 @@ class Model(BaseModel): def handler(user: Annotated[Model, Body(embed=True)]) -> Model: return user - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["path"] = "/" - event["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 422 # THEN the body must be a dict - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] # THEN the handler should be invoked and return 200 # THEN the body must be a dict - event["body"] = json.dumps({"user": {"name": "John", "age": 30}}) - result = app(event, {}) + gw_event["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(gw_event, {}) assert result["statusCode"] == 200 -def test_validate_response_return(): +def test_validate_response_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -403,19 +377,18 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200, content_type="application/json") - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["path"] = "/" - event["body"] = json.dumps({"name": "John", "age": 30}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({"name": "John", "age": 30}) # THEN the handler should be invoked and return 200 # THEN the body must be a dict - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 assert json.loads(result["body"]) == {"name": "John", "age": 30} -def test_validate_response_invalid_return(): +def test_validate_response_invalid_return(gw_event): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -428,14 +401,13 @@ class Model(BaseModel): def handler(user: Model) -> Response[Model]: return Response(body=user, status_code=200) - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "POST" - event["path"] = "/" - event["body"] = json.dumps({}) + gw_event["httpMethod"] = "POST" + gw_event["path"] = "/" + gw_event["body"] = json.dumps({}) # THEN the handler should be invoked and return 422 # THEN the body should have the word missing - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] @@ -449,13 +421,17 @@ def handler(user: Model) -> Response[Model]: ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_api_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event, +): # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "GET" - event["path"] = "/users" + gw_event["httpMethod"] = "GET" + gw_event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -474,8 +450,8 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - event["queryStringParameters"] = None - event["multiValueQueryStringParameters"] = None + gw_event["queryStringParameters"] = None + gw_event["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -483,7 +459,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -499,14 +475,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_api_http_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_api_http_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_http, +): # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_HTTP) - event["rawPath"] = "/users" - event["requestContext"]["http"]["method"] = "GET" - event["requestContext"]["http"]["path"] = "/users" + gw_event_http["rawPath"] = "/users" + gw_event_http["requestContext"]["http"]["method"] = "GET" + gw_event_http["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -525,7 +505,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - event["queryStringParameters"] = None + gw_event_http["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -533,7 +513,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_http, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -549,14 +529,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_alb_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_alb_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_alb, +): # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_ALB) - event["path"] = "/users" - # WHEN a handler is defined with various parameters and routes + gw_event_alb["path"] = "/users" + # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params if handler_func == "handler1_with_correct_params": @@ -573,7 +557,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - event["multiValueQueryStringParameters"] = None + gw_event_alb["multiValueQueryStringParameters"] = None @app.get("/users") def handler3(): @@ -581,7 +565,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_alb, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -597,14 +581,18 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_lambda_url_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_lambda_url, +): # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_LAMBDA_URL) - event["rawPath"] = "/users" - event["requestContext"]["http"]["method"] = "GET" - event["requestContext"]["http"]["path"] = "/users" + gw_event_lambda_url["rawPath"] = "/users" + gw_event_lambda_url["requestContext"]["http"]["method"] = "GET" + gw_event_lambda_url["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -623,7 +611,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - event["queryStringParameters"] = None + gw_event_lambda_url["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -631,7 +619,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_lambda_url, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -647,12 +635,16 @@ def handler3(): ("handler3_without_query_params", 200, None), ], ) -def test_validation_query_string_with_vpc_lattice_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_query_string_with_vpc_lattice_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice, +): # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE) - event["path"] = "/users" + gw_event_vpc_lattice["path"] = "/users" # WHEN a handler is defined with various parameters and routes @@ -672,7 +664,7 @@ def handler2(parameter1: Annotated[List[int], Query()], parameter2: str): # Define handler3 without params if handler_func == "handler3_without_query_params": - event["queryStringParameters"] = None + gw_event_vpc_lattice["queryStringParameters"] = None @app.get("/users") def handler3(): @@ -680,7 +672,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_vpc_lattice, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -698,13 +690,17 @@ def handler3(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_api_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_api_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event, +): # GIVEN a APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT) - event["httpMethod"] = "GET" - event["path"] = "/users" + gw_event["httpMethod"] = "GET" + gw_event["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -733,8 +729,8 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["headers"] = None - event["multiValueHeaders"] = None + gw_event["headers"] = None + gw_event["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -742,7 +738,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -759,14 +755,18 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_http_rest_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_http_rest_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_http, +): # GIVEN a APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_HTTP) - event["rawPath"] = "/users" - event["requestContext"]["http"]["method"] = "GET" - event["requestContext"]["http"]["path"] = "/users" + gw_event_http["rawPath"] = "/users" + gw_event_http["requestContext"]["http"]["method"] = "GET" + gw_event_http["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -795,7 +795,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["headers"] = None + gw_event_http["headers"] = None @app.get("/users") def handler4(): @@ -803,7 +803,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_http, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -820,12 +820,16 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_alb_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_alb_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_alb, +): # GIVEN a ALBResolver with validation enabled app = ALBResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_ALB) - event["path"] = "/users" + gw_event_alb["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -854,7 +858,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["multiValueHeaders"] = None + gw_event_alb["multiValueHeaders"] = None @app.get("/users") def handler4(): @@ -862,7 +866,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_alb, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -879,14 +883,18 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_lambda_url_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_lambda_url_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_lambda_url, +): # GIVEN a LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_LAMBDA_URL) - event["rawPath"] = "/users" - event["requestContext"]["http"]["method"] = "GET" - event["requestContext"]["http"]["path"] = "/users" + gw_event_lambda_url["rawPath"] = "/users" + gw_event_lambda_url["requestContext"]["http"]["method"] = "GET" + gw_event_lambda_url["requestContext"]["http"]["path"] = "/users" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -915,7 +923,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["headers"] = None + gw_event_lambda_url["headers"] = None @app.get("/users") def handler4(): @@ -923,7 +931,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_lambda_url, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -940,13 +948,17 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_vpc_lattice_v1_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_vpc_lattice_v1_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice_v1, +): # GIVEN a VPCLatticeResolver with validation enabled app = VPCLatticeResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE_V1) - event["raw_path"] = "/users" - event["method"] = "GET" + gw_event_vpc_lattice_v1["raw_path"] = "/users" + gw_event_vpc_lattice_v1["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -975,7 +987,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["headers"] = None + gw_event_vpc_lattice_v1["headers"] = None @app.get("/users") def handler4(): @@ -983,7 +995,7 @@ def handler4(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_vpc_lattice_v1, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -1000,13 +1012,17 @@ def handler4(): ("handler4_without_header_params", 200, None), ], ) -def test_validation_header_with_vpc_lattice_v2_resolver(handler_func, expected_status_code, expected_error_text): +def test_validation_header_with_vpc_lattice_v2_resolver( + handler_func, + expected_status_code, + expected_error_text, + gw_event_vpc_lattice, +): # GIVEN a VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_VPC_LATTICE) - event["path"] = "/users" - event["method"] = "GET" + gw_event_vpc_lattice["path"] = "/users" + gw_event_vpc_lattice["method"] = "GET" # WHEN a handler is defined with various parameters and routes # Define handler1 with correct params @@ -1035,7 +1051,7 @@ def handler3( # Define handler4 without params if handler_func == "handler4_without_header_params": - event["headers"] = None + gw_event_vpc_lattice["headers"] = None @app.get("/users") def handler3(): @@ -1043,7 +1059,7 @@ def handler3(): # THEN the handler should be invoked with the expected result # AND the status code should match the expected_status_code - result = app(event, {}) + result = app(gw_event_vpc_lattice, {}) assert result["statusCode"] == expected_status_code # IF expected_error_text is provided, THEN check for its presence in the response body @@ -1051,42 +1067,44 @@ def handler3(): assert any(text in result["body"] for text in expected_error_text) -def test_validation_with_alias(): +def test_validation_with_alias(gw_event): # GIVEN a REST API V2 proxy type event app = APIGatewayRestResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT) - class FunkyTown(BaseModel): - parameter: str + # GIVEN that it has a multiple parameters called "parameter1" + gw_event["queryStringParameters"] = { + "parameter1": "value1,value2", + } @app.get("/my/path") def my_path( parameter: Annotated[Optional[str], Query(alias="parameter1")] = None, - ) -> Response[FunkyTown]: - assert isinstance(app.current_event, APIGatewayProxyEvent) + ) -> str: assert parameter == "value1" - return Response(200, content_types.APPLICATION_JSON, FunkyTown(parameter=parameter)) + return parameter - result = app(event, {}) + result = app(gw_event, {}) assert result["statusCode"] == 200 -def test_validation_with_http_single_param(): +def test_validation_with_http_single_param(gw_event_http): # GIVEN a HTTP API V2 proxy type event app = APIGatewayHttpResolver(enable_validation=True) - event = deepcopy(LOAD_GW_EVENT_HTTP) - class FunkyTown(BaseModel): - parameter: str + # GIVEN that it has a single parameter called "parameter2" + gw_event_http["queryStringParameters"] = { + "parameter1": "value1,value2", + "parameter2": "value", + } # WHEN a handler is defined with a single parameter @app.post("/my/path") def my_path( parameter2: str, - ) -> Response[FunkyTown]: + ) -> str: assert parameter2 == "value" - return Response(200, content_types.APPLICATION_JSON, FunkyTown(parameter=parameter2)) + return parameter2 # THEN the handler should be invoked and return 200 - result = app(event, {}) + result = app(gw_event_http, {}) assert result["statusCode"] == 200