From b1acfb26374f7d943e9dec7dc94593385505ff3d Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 22 Jan 2024 23:06:15 +0000 Subject: [PATCH 1/9] Initial code for multivalue querystring --- .../middlewares/openapi_validation.py | 106 +++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 34011b64384..609093697c3 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -6,7 +6,14 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler import ( + ALBResolver, + APIGatewayHttpResolver, + APIGatewayRestResolver, + LambdaFunctionUrlResolver, + Response, + VPCLatticeV2Resolver, +) from aws_lambda_powertools.event_handler.api_gateway import Route from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -16,6 +23,7 @@ _regenerate_error_with_loc, get_missing_field_error, ) +from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError from aws_lambda_powertools.event_handler.openapi.params import Param @@ -54,6 +62,98 @@ def get_todos(): List[Todo]: ``` """ + def _extract_multi_query_string_with_param(self, query_string, params: Sequence[ModelField]): + """ + Extract and process multi_query_string_parameters for VPCLatticeV2Resolver and APIGatewayRestResolver. + + Parameters + ---------- + query_string: Dict + A dictionary containing the initial query string parameters. + params: Sequence[ModelField] + A sequence of ModelField objects representing parameters. + + Returns + ------- + A dictionary containing the processed multi_query_string_parameters. + + Comments + -------- + - This method is specifically designed for VPCLatticeV2Resolver and APIGatewayRestResolver. + - It processes multi_query_string_parameters based on the given params. + """ + for param in filter(is_scalar_field, params): + try: + # If the field is a scalar, it implies it's not a multi-query string. + # And we keep the first value for this field + + # We Attempt to retain only the first element if the parameter is a scalar field + query_string[param.name] = query_string[param.name][0] + except KeyError: + pass + return query_string + + def _extract_query_string(self, app: EventHandlerInstance, params: Sequence[ModelField]): + """ + Extract and process query string parameters based on the resolver type. + Payloads are different and we need to identify when it is using multiValueQueryStringParameters. + + References + https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.proxy-format + https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html#multi-value-headers + https://docs.aws.amazon.com/vpc-lattice/latest/ug/lambda-functions.html#multi-value-headers + https://docs.aws.amazon.com/lambda/latest/dg/urls-invocation.html#urls-payloads + + Parameters + ---------- + app: EventHandlerInstance. + Instance of a Resolver + params: Sequence[ModelField] + A sequence of ModelField objects representing parameters. + + Returns + ------- + A dictionary containing the processed query string parameters. + + Comments + -------- + - The initial query is obtained from app.current_event.query_string_parameters. + + - In the case of using ALBResolver, we attempt to retrieve multi_value_query_string_parameters; otherwise, + we retain the original query. + + - In the case of using LambdaFunctionUrlResolver or APIGatewayHttpResolver, multi-query strings consistently + reside in the same field, separated by commas. Consequently, we split these strings into lists. + + - When using a VPCLatticeV2Resolver, the Payload consistently sends query strings as arrays. To enhance + compatibility, we attempt to identify scalar types within the arrays and convert them to single elements. + + - In the case of using APIGatewayRestResolver, the payload includes both query string and multi-query string + fields. We apply a similar logic as used in VPCLatticeV2Resolver to handle these query strings effectively. + + - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings + and we retain original query + """ + + query = app.current_event.query_string_parameters or {} + + if isinstance(app, ALBResolver): + query = app.current_event.multi_value_query_string_parameters or query + + if isinstance(app, (LambdaFunctionUrlResolver, APIGatewayHttpResolver)): + query = {key: value.split(",") if "," in value else value for key, value in query.items()} + + if isinstance(app, VPCLatticeV2Resolver): + query = self._extract_multi_query_string_with_param(query, params) + + if isinstance(app, APIGatewayRestResolver) and app.current_event.multi_value_query_string_parameters: + query = self._extract_multi_query_string_with_param( + app.current_event.multi_value_query_string_parameters, + params, + ) + + return query + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") @@ -68,10 +168,12 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> app.context["_route_args"], ) + query_string = self._extract_query_string(app, route.dependant.query_params) + # Process query values query_values, query_errors = _request_params_to_args( route.dependant.query_params, - app.current_event.query_string_parameters or {}, + query_string, ) values.update(path_values) From 38da25d6998839751fd232f12c75a70e12bc0522 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 00:07:08 +0000 Subject: [PATCH 2/9] Adding tests and improving code --- .../middlewares/openapi_validation.py | 16 +- .../events/albMultiValueQueryStringEvent.json | 38 ++++ .../lambdaFunctionUrlEventWithHeaders.json | 51 +++++ .../events/vpcLatticeV2EventWithHeaders.json | 36 ++++ .../event_handler/test_openapi_params.py | 14 ++ .../test_openapi_validation_middleware.py | 190 +++++++++++++++++- 6 files changed, 334 insertions(+), 11 deletions(-) create mode 100644 tests/events/albMultiValueQueryStringEvent.json create mode 100644 tests/events/lambdaFunctionUrlEventWithHeaders.json create mode 100644 tests/events/vpcLatticeV2EventWithHeaders.json diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 609093697c3..be688152b86 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -119,17 +119,15 @@ def _extract_query_string(self, app: EventHandlerInstance, params: Sequence[Mode -------- - The initial query is obtained from app.current_event.query_string_parameters. - - In the case of using ALBResolver, we attempt to retrieve multi_value_query_string_parameters; otherwise, - we retain the original query. - - In the case of using LambdaFunctionUrlResolver or APIGatewayHttpResolver, multi-query strings consistently reside in the same field, separated by commas. Consequently, we split these strings into lists. - When using a VPCLatticeV2Resolver, the Payload consistently sends query strings as arrays. To enhance compatibility, we attempt to identify scalar types within the arrays and convert them to single elements. - - In the case of using APIGatewayRestResolver, the payload includes both query string and multi-query string - fields. We apply a similar logic as used in VPCLatticeV2Resolver to handle these query strings effectively. + - In the case of using APIGatewayRestResolver or ALBResolver, the payload may includes both query string and + multi-query string fields. We apply a similar logic as used in VPCLatticeV2Resolver + to handle these query strings effectively. - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings and we retain original query @@ -137,16 +135,16 @@ def _extract_query_string(self, app: EventHandlerInstance, params: Sequence[Mode query = app.current_event.query_string_parameters or {} - if isinstance(app, ALBResolver): - query = app.current_event.multi_value_query_string_parameters or query - if isinstance(app, (LambdaFunctionUrlResolver, APIGatewayHttpResolver)): query = {key: value.split(",") if "," in value else value for key, value in query.items()} if isinstance(app, VPCLatticeV2Resolver): query = self._extract_multi_query_string_with_param(query, params) - if isinstance(app, APIGatewayRestResolver) and app.current_event.multi_value_query_string_parameters: + if ( + isinstance(app, (ALBResolver, APIGatewayRestResolver)) + and app.current_event.multi_value_query_string_parameters + ): query = self._extract_multi_query_string_with_param( app.current_event.multi_value_query_string_parameters, params, diff --git a/tests/events/albMultiValueQueryStringEvent.json b/tests/events/albMultiValueQueryStringEvent.json new file mode 100644 index 00000000000..4584ba7c477 --- /dev/null +++ b/tests/events/albMultiValueQueryStringEvent.json @@ -0,0 +1,38 @@ +{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:eu-central-1:1234567890:targetgroup/alb-c-Targe-11GDXTPQ7663S/804a67588bfdc10f" + } + }, + "httpMethod": "GET", + "path": "/todos", + "multiValueQueryStringParameters": { + "parameter1": ["value1","value2"], + "parameter2": ["value"] + }, + "multiValueHeaders": { + "accept": [ + "*/*" + ], + "host": [ + "alb-c-LoadB-14POFKYCLBNSF-1815800096.eu-central-1.elb.amazonaws.com" + ], + "user-agent": [ + "curl/7.79.1" + ], + "x-amzn-trace-id": [ + "Root=1-62fa9327-21cdd4da4c6db451490a5fb7" + ], + "x-forwarded-for": [ + "123.123.123.123" + ], + "x-forwarded-port": [ + "80" + ], + "x-forwarded-proto": [ + "http" + ] + }, + "body": "", + "isBase64Encoded": false +} diff --git a/tests/events/lambdaFunctionUrlEventWithHeaders.json b/tests/events/lambdaFunctionUrlEventWithHeaders.json new file mode 100644 index 00000000000..e453690d9b3 --- /dev/null +++ b/tests/events/lambdaFunctionUrlEventWithHeaders.json @@ -0,0 +1,51 @@ +{ + "version":"2.0", + "routeKey":"$default", + "rawPath":"/", + "rawQueryString":"", + "headers":{ + "sec-fetch-mode":"navigate", + "x-amzn-tls-version":"TLSv1.2", + "sec-fetch-site":"cross-site", + "accept-language":"pt-BR,pt;q=0.9", + "x-forwarded-proto":"https", + "x-forwarded-port":"443", + "x-forwarded-for":"123.123.123.123", + "sec-fetch-user":"?1", + "accept":"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", + "x-amzn-tls-cipher-suite":"ECDHE-RSA-AES128-GCM-SHA256", + "sec-ch-ua":"\" Not A;Brand\";v=\"99\", \"Chromium\";v=\"102\", \"Google Chrome\";v=\"102\"", + "sec-ch-ua-mobile":"?0", + "x-amzn-trace-id":"Root=1-62ecd163-5f302e550dcde3b12402207d", + "sec-ch-ua-platform":"\"Linux\"", + "host":".lambda-url.us-east-1.on.aws", + "upgrade-insecure-requests":"1", + "cache-control":"max-age=0", + "accept-encoding":"gzip, deflate, br", + "sec-fetch-dest":"document", + "user-agent":"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext":{ + "accountId":"anonymous", + "apiId":"", + "domainName":".lambda-url.us-east-1.on.aws", + "domainPrefix":"", + "http":{ + "method":"GET", + "path":"/", + "protocol":"HTTP/1.1", + "sourceIp":"123.123.123.123", + "userAgent":"agent" + }, + "requestId":"id", + "routeKey":"$default", + "stage":"$default", + "time":"05/Aug/2022:08:14:39 +0000", + "timeEpoch":1659687279885 + }, + "isBase64Encoded":false +} diff --git a/tests/events/vpcLatticeV2EventWithHeaders.json b/tests/events/vpcLatticeV2EventWithHeaders.json new file mode 100644 index 00000000000..11b36ef118b --- /dev/null +++ b/tests/events/vpcLatticeV2EventWithHeaders.json @@ -0,0 +1,36 @@ +{ + "version": "2.0", + "path": "/newpath", + "method": "GET", + "headers": { + "user_agent": "curl/7.64.1", + "x-forwarded-for": "10.213.229.10", + "host": "test-lambda-service-3908sdf9u3u.dkfjd93.vpc-lattice-svcs.us-east-2.on.aws", + "accept": "*/*" + }, + "queryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "body": "{\"message\": \"Hello from Lambda!\"}", + "isBase64Encoded": false, + "requestContext": { + "serviceNetworkArn": "arn:aws:vpc-lattice:us-east-2:123456789012:servicenetwork/sn-0bf3f2882e9cc805a", + "serviceArn": "arn:aws:vpc-lattice:us-east-2:123456789012:service/svc-0a40eebed65f8d69c", + "targetGroupArn": "arn:aws:vpc-lattice:us-east-2:123456789012:targetgroup/tg-6d0ecf831eec9f09", + "identity": { + "sourceVpcArn": "arn:aws:ec2:region:123456789012:vpc/vpc-0b8276c84697e7339", + "type" : "AWS_IAM", + "principal": "arn:aws:sts::123456789012:assumed-role/example-role/057d00f8b51257ba3c853a0f248943cf", + "sessionName": "057d00f8b51257ba3c853a0f248943cf", + "x509SanDns": "example.com" + }, + "region": "us-east-2", + "timeEpoch": "1696331543569073" + } +} diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py index 0f06524ea6d..2f48f5aa534 100644 --- a/tests/functional/event_handler/test_openapi_params.py +++ b/tests/functional/event_handler/test_openapi_params.py @@ -184,6 +184,20 @@ def handler(page: Annotated[str, Query(include_in_schema=False)]): assert get.parameters is None +def test_openapi_with_list_param(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(page: Annotated[List[str], Query()]): + return page + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters[0].schema_.type == "array" + + def test_openapi_with_description(): app = APIGatewayRestResolver() diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index f558bd23ced..45f9810e6bb 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -6,12 +6,23 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response -from aws_lambda_powertools.event_handler.openapi.params import Body +from aws_lambda_powertools.event_handler import ( + ALBResolver, + APIGatewayHttpResolver, + APIGatewayRestResolver, + LambdaFunctionUrlResolver, + Response, + VPCLatticeV2Resolver, +) +from aws_lambda_powertools.event_handler.openapi.params import Body, Query from aws_lambda_powertools.shared.types import Annotated from tests.functional.utils import load_event LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +LOAD_GW_EVENT_HTTP = load_event("apiGatewayProxyV2Event.json") +LOAD_GW_EVENT_ALB = load_event("albMultiValueQueryStringEvent.json") +LOAD_GW_EVENT_LAMBDA_URL = load_event("lambdaFunctionUrlEventWithHeaders.json") +LOAD_GW_EVENT_VPC_LATTICE = load_event("vpcLatticeV2EventWithHeaders.json") def test_validate_scalars(): @@ -378,3 +389,178 @@ def handler(user: Model) -> Response[Model]: result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 assert "missing" in result["body"] + + +def test_validate_rest_api_resolver_with_multi_query_values(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT["httpMethod"] = "GET" + LOAD_GW_EVENT["path"] = "/users" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + +def test_validate_rest_api_resolver_with_multi_query_values_fail(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT["httpMethod"] = "GET" + LOAD_GW_EVENT["path"] = "/users" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer"]) + + +def test_validate_http_resolver_with_multi_query_values(): + # GIVEN an APIGatewayHttpResolver with validation enabled + app = APIGatewayHttpResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_HTTP["rawPath"] = "/users" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_HTTP, {}) + assert result["statusCode"] == 200 + + +def test_validate_http_resolver_with_multi_query_values_fail(): + # GIVEN an APIGatewayHttpResolver with validation enabled + app = APIGatewayHttpResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_HTTP["rawPath"] = "/users" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT_HTTP, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer"]) + + +def test_validate_alb_resolver_with_multi_query_values(): + # GIVEN an ALBResolver with validation enabled + app = ALBResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_ALB["path"] = "/users" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_ALB, {}) + assert result["statusCode"] == 200 + + +def test_validate_alb_resolver_with_multi_query_values_fail(): + # GIVEN an ALBResolver with validation enabled + app = ALBResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_ALB["path"] = "/users" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT_ALB, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer"]) + + +def test_validate_lambda_url_resolver_with_multi_query_values(): + # GIVEN an LambdaFunctionUrlResolver with validation enabled + app = LambdaFunctionUrlResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + assert result["statusCode"] == 200 + + +def test_validate__lambda_url_resolver_with_multi_query_values_fail(): + # GIVEN an LambdaFunctionUrlResolver with validation enabled + app = LambdaFunctionUrlResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer"]) + + +def test_validate_vpc_lattice_resolver_with_multi_query_values(): + # GIVEN an VPCLatticeV2Resolver with validation enabled + app = VPCLatticeV2Resolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + assert result["statusCode"] == 200 + + +def test_validate_vpc_lattice_resolver_with_multi_query_values_fail(): + # GIVEN an VPCLatticeV2Resolver with validation enabled + app = VPCLatticeV2Resolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(parameter1: Annotated[List[int], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer"]) From ad44f5fd626605b478aaeff90c8766f81c2506cd Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 00:10:49 +0000 Subject: [PATCH 3/9] Adding tests and improving code --- .../event_handler/middlewares/openapi_validation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index be688152b86..42d6d6d5a0e 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -130,27 +130,27 @@ def _extract_query_string(self, app: EventHandlerInstance, params: Sequence[Mode to handle these query strings effectively. - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings - and we retain original query + and we retain original query_string field """ - query = app.current_event.query_string_parameters or {} + query_string = app.current_event.query_string_parameters or {} if isinstance(app, (LambdaFunctionUrlResolver, APIGatewayHttpResolver)): - query = {key: value.split(",") if "," in value else value for key, value in query.items()} + query_string = {key: value.split(",") if "," in value else value for key, value in query_string.items()} if isinstance(app, VPCLatticeV2Resolver): - query = self._extract_multi_query_string_with_param(query, params) + query_string = self._extract_multi_query_string_with_param(query_string, params) if ( isinstance(app, (ALBResolver, APIGatewayRestResolver)) and app.current_event.multi_value_query_string_parameters ): - query = self._extract_multi_query_string_with_param( + query_string = self._extract_multi_query_string_with_param( app.current_event.multi_value_query_string_parameters, params, ) - return query + return query_string def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") From 522ea9a3e08836e88175cb274b903d1a4c748db0 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 12:59:55 +0000 Subject: [PATCH 4/9] Refactoging to avoid abstraction leaky --- .../middlewares/openapi_validation.py | 149 ++++++------------ .../utilities/data_classes/alb_event.py | 13 +- .../data_classes/api_gateway_proxy_event.py | 25 +++ .../data_classes/bedrock_agent_event.py | 8 + .../utilities/data_classes/common.py | 8 + .../utilities/data_classes/vpc_lattice.py | 16 ++ .../test_openapi_validation_middleware.py | 105 +++++++++++- 7 files changed, 217 insertions(+), 107 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 42d6d6d5a0e..0d210947223 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -6,14 +6,7 @@ from pydantic import BaseModel -from aws_lambda_powertools.event_handler import ( - ALBResolver, - APIGatewayHttpResolver, - APIGatewayRestResolver, - LambdaFunctionUrlResolver, - Response, - VPCLatticeV2Resolver, -) +from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.api_gateway import Route from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware from aws_lambda_powertools.event_handler.openapi.compat import ( @@ -62,96 +55,6 @@ def get_todos(): List[Todo]: ``` """ - def _extract_multi_query_string_with_param(self, query_string, params: Sequence[ModelField]): - """ - Extract and process multi_query_string_parameters for VPCLatticeV2Resolver and APIGatewayRestResolver. - - Parameters - ---------- - query_string: Dict - A dictionary containing the initial query string parameters. - params: Sequence[ModelField] - A sequence of ModelField objects representing parameters. - - Returns - ------- - A dictionary containing the processed multi_query_string_parameters. - - Comments - -------- - - This method is specifically designed for VPCLatticeV2Resolver and APIGatewayRestResolver. - - It processes multi_query_string_parameters based on the given params. - """ - for param in filter(is_scalar_field, params): - try: - # If the field is a scalar, it implies it's not a multi-query string. - # And we keep the first value for this field - - # We Attempt to retain only the first element if the parameter is a scalar field - query_string[param.name] = query_string[param.name][0] - except KeyError: - pass - return query_string - - def _extract_query_string(self, app: EventHandlerInstance, params: Sequence[ModelField]): - """ - Extract and process query string parameters based on the resolver type. - Payloads are different and we need to identify when it is using multiValueQueryStringParameters. - - References - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html#http-api-develop-integrations-lambda.proxy-format - https://docs.aws.amazon.com/elasticloadbalancing/latest/application/lambda-functions.html#multi-value-headers - https://docs.aws.amazon.com/vpc-lattice/latest/ug/lambda-functions.html#multi-value-headers - https://docs.aws.amazon.com/lambda/latest/dg/urls-invocation.html#urls-payloads - - Parameters - ---------- - app: EventHandlerInstance. - Instance of a Resolver - params: Sequence[ModelField] - A sequence of ModelField objects representing parameters. - - Returns - ------- - A dictionary containing the processed query string parameters. - - Comments - -------- - - The initial query is obtained from app.current_event.query_string_parameters. - - - In the case of using LambdaFunctionUrlResolver or APIGatewayHttpResolver, multi-query strings consistently - reside in the same field, separated by commas. Consequently, we split these strings into lists. - - - When using a VPCLatticeV2Resolver, the Payload consistently sends query strings as arrays. To enhance - compatibility, we attempt to identify scalar types within the arrays and convert them to single elements. - - - In the case of using APIGatewayRestResolver or ALBResolver, the payload may includes both query string and - multi-query string fields. We apply a similar logic as used in VPCLatticeV2Resolver - to handle these query strings effectively. - - - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings - and we retain original query_string field - """ - - query_string = app.current_event.query_string_parameters or {} - - if isinstance(app, (LambdaFunctionUrlResolver, APIGatewayHttpResolver)): - query_string = {key: value.split(",") if "," in value else value for key, value in query_string.items()} - - if isinstance(app, VPCLatticeV2Resolver): - query_string = self._extract_multi_query_string_with_param(query_string, params) - - if ( - isinstance(app, (ALBResolver, APIGatewayRestResolver)) - and app.current_event.multi_value_query_string_parameters - ): - query_string = self._extract_multi_query_string_with_param( - app.current_event.multi_value_query_string_parameters, - params, - ) - - return query_string - def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: logger.debug("OpenAPIValidationMiddleware handler") @@ -166,7 +69,11 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> app.context["_route_args"], ) - query_string = self._extract_query_string(app, route.dependant.query_params) + # Normalize query values before validate this + query_string = _normalize_multi_query_string_with_param( + app.current_event.resolved_query_string_parameters, + route.dependant.query_params, + ) # Process query values query_values, query_errors = _request_params_to_args( @@ -444,3 +351,47 @@ def _get_embed_body( received_body = {field.alias: received_body} return received_body, field_alias_omitted + + +def _normalize_multi_query_string_with_param(query_string, params: Sequence[ModelField]): + """ + Extract and normalize resolved_query_string_parameters + + Parameters + ---------- + query_string: Dict + A dictionary containing the initial query string parameters. + params: Sequence[ModelField] + A sequence of ModelField objects representing parameters. + + Returns + ------- + A dictionary containing the processed multi_query_string_parameters. + + Comments + -------- + - These comments are to explain the decision to create this method + + - In the case of using LambdaFunctionUrlResolver or APIGatewayHttpResolver, multi-query strings consistently + reside in the same field, separated by commas. + + - When using a VPCLatticeV2Resolver, the Payload consistently sends query strings as arrays. To enhance + compatibility, we attempt to identify scalar types within the arrays and convert them to single elements. + + - In the case of using APIGatewayRestResolver or ALBResolver, the payload may includes both query string and + multi-query string fields. We apply a similar logic as used in VPCLatticeV2Resolver + to handle these query strings effectively. + + - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings + and we retain original query + """ + for param in filter(is_scalar_field, params): + try: + # If the field is a scalar, it implies it's not a multi-query string. + # And we keep the first value for this field + + # We Attempt to retain only the first element if the parameter is a scalar field + query_string[param.name] = query_string[param.name][0] + except KeyError: + pass + return query_string diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 51a6f61f368..b9f319f2b7b 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from aws_lambda_powertools.shared.headers_serializer import ( BaseHeadersSerializer, @@ -35,6 +35,17 @@ def request_context(self) -> ALBEventRequestContext: def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + if self.multi_value_query_string_parameters: + return self.multi_value_query_string_parameters + + return self.query_string_parameters + @property def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueHeaders") diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 5c2ef12e62c..e82aba89ec7 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -118,6 +118,17 @@ def multi_value_headers(self) -> Dict[str, List[str]]: def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + if self.multi_value_query_string_parameters: + return self.multi_value_query_string_parameters + + return self.query_string_parameters + @property def request_context(self) -> APIGatewayEventRequestContext: return APIGatewayEventRequestContext(self._data) @@ -299,3 +310,17 @@ def http_method(self) -> str: def header_serializer(self): return HttpApiHeadersSerializer() + + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + if self.query_string_parameters is not None: + query_string = { + key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items() + } + return query_string + + return {} diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 9534af0e7f6..31ce4454857 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -108,3 +108,11 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: # In Bedrock Agent events, query string parameters are passed as undifferentiated parameters, # together with the other parameters. So we just return all parameters here. return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None + + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + return self.query_string_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 28229c21a62..1e715c730f4 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -103,6 +103,14 @@ def headers(self) -> Dict[str, str]: def query_string_parameters(self) -> Optional[Dict[str, str]]: return self.get("queryStringParameters") + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + return self.query_string_parameters + @property def is_base64_encoded(self) -> Optional[bool]: return self.get("isBase64Encoded") diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 00ba5136eec..3cba6440ba3 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -141,6 +141,14 @@ def query_string_parameters(self) -> Dict[str, str]: """The request query string parameters.""" return self["query_string_parameters"] + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + return self.query_string_parameters + class vpcLatticeEventV2Identity(DictWrapper): @property @@ -251,3 +259,11 @@ def request_context(self) -> vpcLatticeEventV2RequestContext: def query_string_parameters(self) -> Optional[Dict[str, str]]: """The request query string parameters.""" return self.get("queryStringParameters") + + @property + def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: + """ + This property determines the appropriate query string parameter to be used + as a trusted source for validating OpenAPI. + """ + return self.query_string_parameters diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index 45f9810e6bb..cc6056f92b3 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -391,7 +391,7 @@ def handler(user: Model) -> Response[Model]: assert "missing" in result["body"] -def test_validate_rest_api_resolver_with_multi_query_values(): +def test_validate_rest_api_resolver_with_multi_query_params(): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -408,7 +408,7 @@ def handler(parameter1: Annotated[List[str], Query()], parameter2: str): assert result["statusCode"] == 200 -def test_validate_rest_api_resolver_with_multi_query_values_fail(): +def test_validate_rest_api_resolver_with_multi_query_params_fail(): # GIVEN an APIGatewayRestResolver with validation enabled app = APIGatewayRestResolver(enable_validation=True) @@ -426,7 +426,26 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): assert any(text in result["body"] for text in ["type_error.integer"]) -def test_validate_http_resolver_with_multi_query_values(): +def test_validate_rest_api_resolver_without_query_params(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter and a list with wrong type + @app.get("/users") + def handler(): + return None + + LOAD_GW_EVENT["httpMethod"] = "GET" + LOAD_GW_EVENT["path"] = "/users" + LOAD_GW_EVENT["queryStringParameters"] = None + LOAD_GW_EVENT["multiValueQueryStringParameters"] = None + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + +def test_validate_http_resolver_with_multi_query_params(): # GIVEN an APIGatewayHttpResolver with validation enabled app = APIGatewayHttpResolver(enable_validation=True) @@ -463,6 +482,25 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): assert any(text in result["body"] for text in ["type_error.integer"]) +def test_validate_http_resolver_without_query_params(): + # GIVEN an APIGatewayHttpResolver with validation enabled + app = APIGatewayHttpResolver(enable_validation=True) + + # WHEN a handler is defined without any query params + @app.get("/users") + def handler(): + return None + + LOAD_GW_EVENT_HTTP["rawPath"] = "/users" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_HTTP["requestContext"]["http"]["path"] = "/users" + LOAD_GW_EVENT_HTTP["queryStringParameters"] = None + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_HTTP, {}) + assert result["statusCode"] == 200 + + def test_validate_alb_resolver_with_multi_query_values(): # GIVEN an ALBResolver with validation enabled app = ALBResolver(enable_validation=True) @@ -496,7 +534,24 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): assert any(text in result["body"] for text in ["type_error.integer"]) -def test_validate_lambda_url_resolver_with_multi_query_values(): +def test_validate_alb_resolver_without_query_params(): + # GIVEN an ALBResolver with validation enabled + app = ALBResolver(enable_validation=True) + + # WHEN a handler is defined without any query params + @app.get("/users") + def handler(parameter1: Annotated[List[str], Query()], parameter2: str): + print(parameter2) + + LOAD_GW_EVENT_ALB["path"] = "/users" + LOAD_GW_EVENT_HTTP["multiValueQueryStringParameters"] = None + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_ALB, {}) + assert result["statusCode"] == 200 + + +def test_validate_lambda_url_resolver_with_multi_query_params(): # GIVEN an LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) @@ -514,7 +569,7 @@ def handler(parameter1: Annotated[List[str], Query()], parameter2: str): assert result["statusCode"] == 200 -def test_validate__lambda_url_resolver_with_multi_query_values_fail(): +def test_validate_lambda_url_resolver_with_multi_query_params_fail(): # GIVEN an LambdaFunctionUrlResolver with validation enabled app = LambdaFunctionUrlResolver(enable_validation=True) @@ -533,7 +588,26 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): assert any(text in result["body"] for text in ["type_error.integer"]) -def test_validate_vpc_lattice_resolver_with_multi_query_values(): +def test_validate_lambda_url_resolver_without_query_params(): + # GIVEN an LambdaFunctionUrlResolver with validation enabled + app = LambdaFunctionUrlResolver(enable_validation=True) + + # WHEN a handler is defined without any query params + @app.get("/users") + def handler(): + return None + + LOAD_GW_EVENT_LAMBDA_URL["rawPath"] = "/users" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["method"] = "GET" + LOAD_GW_EVENT_LAMBDA_URL["requestContext"]["http"]["path"] = "/users" + LOAD_GW_EVENT_LAMBDA_URL["queryStringParameters"] = None + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) + assert result["statusCode"] == 200 + + +def test_validate_vpc_lattice_resolver_with_multi_params_values(): # GIVEN an VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) @@ -549,7 +623,7 @@ def handler(parameter1: Annotated[List[str], Query()], parameter2: str): assert result["statusCode"] == 200 -def test_validate_vpc_lattice_resolver_with_multi_query_values_fail(): +def test_validate_vpc_lattice_resolver_with_multi_query_params_fail(): # GIVEN an VPCLatticeV2Resolver with validation enabled app = VPCLatticeV2Resolver(enable_validation=True) @@ -564,3 +638,20 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) assert result["statusCode"] == 422 assert any(text in result["body"] for text in ["type_error.integer"]) + + +def test_validate_vpc_lattice_resolver_without_query_params(): + # GIVEN an VPCLatticeV2Resolver with validation enabled + app = VPCLatticeV2Resolver(enable_validation=True) + + # WHEN a handler is defined without any query params + @app.get("/users") + def handler(): + return None + + LOAD_GW_EVENT_VPC_LATTICE["path"] = "/users" + LOAD_GW_EVENT_VPC_LATTICE["queryStringParameters"] = None + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) + assert result["statusCode"] == 200 From 40d80aed872c685743278162538f8428a9f9a853 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 13:20:40 +0000 Subject: [PATCH 5/9] Making Pydanticv2 happy --- .../test_openapi_validation_middleware.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py index cc6056f92b3..ea4305257d4 100644 --- a/tests/functional/event_handler/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -423,7 +423,7 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer"]) + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_rest_api_resolver_without_query_params(): @@ -479,7 +479,7 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT_HTTP, {}) assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer"]) + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_http_resolver_without_query_params(): @@ -531,7 +531,7 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT_ALB, {}) assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer"]) + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_alb_resolver_without_query_params(): @@ -585,7 +585,7 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT_LAMBDA_URL, {}) assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer"]) + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_lambda_url_resolver_without_query_params(): @@ -637,7 +637,7 @@ def handler(parameter1: Annotated[List[int], Query()], parameter2: str): # THEN the handler should be invoked and return 422 result = app(LOAD_GW_EVENT_VPC_LATTICE, {}) assert result["statusCode"] == 422 - assert any(text in result["body"] for text in ["type_error.integer"]) + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) def test_validate_vpc_lattice_resolver_without_query_params(): From 112429315dd220fcd0c3a270d834c3ecf975ebb7 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 13:34:56 +0000 Subject: [PATCH 6/9] Adding documentation --- docs/core/event_handler/api_gateway.md | 10 ++++++ .../src/working_with_multi_query_values.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 examples/event_handler_rest/src/working_with_multi_query_values.py diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index a34a94975bc..86b97c87e4b 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -400,6 +400,16 @@ In the following example, we use a new `Query` OpenAPI type to add [one out of m 1. `completed` is still the same query string as before, except we simply state it's an string. No `Query` or `Annotated` to validate it. +=== "working_with_multi_query_values.py" + + If you need to handle multi-value query parameters, you can create a list of the desired type. + + ```python hl_lines="23" + --8<-- "examples/event_handler_rest/src/working_with_multi_query_values.py" + ``` + + 1. `example_multi_value_param` is a list containing values from the `ExampleEnum` enumeration. + #### Validating path parameters diff --git a/examples/event_handler_rest/src/working_with_multi_query_values.py b/examples/event_handler_rest/src/working_with_multi_query_values.py new file mode 100644 index 00000000000..4ae77f96849 --- /dev/null +++ b/examples/event_handler_rest/src/working_with_multi_query_values.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from enum import Enum + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Query +from aws_lambda_powertools.shared.types import Annotated +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayRestResolver(enable_validation=True) + + +class ExampleEnum(Enum): + """Example of an Enum class.""" + + ONE = "value_one" + TWO = "value_two" + THREE = "value_three" + + +@app.get("/todos") +def get( + example_multi_value_param: Annotated[ + list[ExampleEnum], # (1)! + Query( + description="This is multi value query parameter.", + ), + ], +): + """Return validated multi-value param values.""" + return example_multi_value_param + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) From 673b7e1f21743cfbdb3f718483c872e0f9f1580c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 14:41:39 +0000 Subject: [PATCH 7/9] Addressing Ruben's feedback --- .../middlewares/openapi_validation.py | 25 +++---------------- .../utilities/data_classes/alb_event.py | 4 --- .../data_classes/api_gateway_proxy_event.py | 8 ------ .../data_classes/bedrock_agent_event.py | 4 --- .../utilities/data_classes/common.py | 2 ++ .../utilities/data_classes/vpc_lattice.py | 8 ------ .../src/working_with_multi_query_values.py | 5 ++-- 7 files changed, 7 insertions(+), 49 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 0d210947223..182dc059bea 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -353,7 +353,7 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param(query_string, params: Sequence[ModelField]): +def _normalize_multi_query_string_with_param(query_string: Dict[str, Any], params: Sequence[ModelField]): """ Extract and normalize resolved_query_string_parameters @@ -367,30 +367,11 @@ def _normalize_multi_query_string_with_param(query_string, params: Sequence[Mode Returns ------- A dictionary containing the processed multi_query_string_parameters. - - Comments - -------- - - These comments are to explain the decision to create this method - - - In the case of using LambdaFunctionUrlResolver or APIGatewayHttpResolver, multi-query strings consistently - reside in the same field, separated by commas. - - - When using a VPCLatticeV2Resolver, the Payload consistently sends query strings as arrays. To enhance - compatibility, we attempt to identify scalar types within the arrays and convert them to single elements. - - - In the case of using APIGatewayRestResolver or ALBResolver, the payload may includes both query string and - multi-query string fields. We apply a similar logic as used in VPCLatticeV2Resolver - to handle these query strings effectively. - - - VPCLatticeResolver (v1) and BedrockAgentResolver doesn't support multi-query strings - and we retain original query """ for param in filter(is_scalar_field, params): try: - # If the field is a scalar, it implies it's not a multi-query string. - # And we keep the first value for this field - - # We Attempt to retain only the first element if the parameter is a scalar field + # if the target parameter is a scalar, we keep the first value of the query string + # regardless if there are more in the payload query_string[param.name] = query_string[param.name][0] except KeyError: pass diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index b9f319f2b7b..688c9567efa 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -37,10 +37,6 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index e82aba89ec7..9e013eac038 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -120,10 +120,6 @@ def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ if self.multi_value_query_string_parameters: return self.multi_value_query_string_parameters @@ -313,10 +309,6 @@ def header_serializer(self): @property def resolved_query_string_parameters(self) -> Optional[Dict[str, Any]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ if self.query_string_parameters is not None: query_string = { key: value.split(",") if "," in value else value for key, value in self.query_string_parameters.items() diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index 31ce4454857..d9b45242376 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -111,8 +111,4 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ return self.query_string_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 1e715c730f4..96fd9d3f2dc 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -108,6 +108,8 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: """ This property determines the appropriate query string parameter to be used as a trusted source for validating OpenAPI. + + This is necessary because different resolvers use different formats to encode multi query string parameters. """ return self.query_string_parameters diff --git a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py index 3cba6440ba3..633ce068f6e 100644 --- a/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py +++ b/aws_lambda_powertools/utilities/data_classes/vpc_lattice.py @@ -143,10 +143,6 @@ def query_string_parameters(self) -> Dict[str, str]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ return self.query_string_parameters @@ -262,8 +258,4 @@ def query_string_parameters(self) -> Optional[Dict[str, str]]: @property def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: - """ - This property determines the appropriate query string parameter to be used - as a trusted source for validating OpenAPI. - """ return self.query_string_parameters diff --git a/examples/event_handler_rest/src/working_with_multi_query_values.py b/examples/event_handler_rest/src/working_with_multi_query_values.py index 4ae77f96849..7f6049dad46 100644 --- a/examples/event_handler_rest/src/working_with_multi_query_values.py +++ b/examples/event_handler_rest/src/working_with_multi_query_values.py @@ -1,6 +1,5 @@ -from __future__ import annotations - from enum import Enum +from typing import List from aws_lambda_powertools.event_handler import APIGatewayRestResolver from aws_lambda_powertools.event_handler.openapi.params import Query @@ -21,7 +20,7 @@ class ExampleEnum(Enum): @app.get("/todos") def get( example_multi_value_param: Annotated[ - list[ExampleEnum], # (1)! + List[ExampleEnum], # (1)! Query( description="This is multi value query parameter.", ), From ef8f4ba231f0e603b84035a64eaaaee5746d2858 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 14:44:40 +0000 Subject: [PATCH 8/9] Addressing Ruben's feedback --- aws_lambda_powertools/utilities/data_classes/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 96fd9d3f2dc..d2cf57d4af5 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -109,7 +109,8 @@ def resolved_query_string_parameters(self) -> Optional[Dict[str, str]]: This property determines the appropriate query string parameter to be used as a trusted source for validating OpenAPI. - This is necessary because different resolvers use different formats to encode multi query string parameters. + This is necessary because different resolvers use different formats to encode + multi query string parameters. """ return self.query_string_parameters From ee83b1bfb9f20366e25df7ff04adca04be19dd5e Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 23 Jan 2024 14:56:08 +0000 Subject: [PATCH 9/9] Mypy.... --- .../middlewares/openapi_validation.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index 182dc059bea..e819947b147 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -353,7 +353,7 @@ def _get_embed_body( return received_body, field_alias_omitted -def _normalize_multi_query_string_with_param(query_string: Dict[str, Any], params: Sequence[ModelField]): +def _normalize_multi_query_string_with_param(query_string: Optional[Dict[str, str]], params: Sequence[ModelField]): """ Extract and normalize resolved_query_string_parameters @@ -368,11 +368,12 @@ def _normalize_multi_query_string_with_param(query_string: Dict[str, Any], param ------- A dictionary containing the processed multi_query_string_parameters. """ - for param in filter(is_scalar_field, params): - try: - # if the target parameter is a scalar, we keep the first value of the query string - # regardless if there are more in the payload - query_string[param.name] = query_string[param.name][0] - except KeyError: - pass + if query_string: + for param in filter(is_scalar_field, params): + try: + # if the target parameter is a scalar, we keep the first value of the query string + # regardless if there are more in the payload + query_string[param.name] = query_string[param.name][0] + except KeyError: + pass return query_string