From e97f6ee1e3a9168716349c6feecac72190431d75 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 01:14:25 +0100 Subject: [PATCH 01/17] Initial commit OpenAPI Extensions --- .../event_handler/api_gateway.py | 35 +++ .../event_handler/bedrock_agent.py | 11 + .../event_handler/openapi/compat.py | 4 +- .../event_handler/openapi/models.py | 34 ++- .../event_handler/_pydantic/conftest.py | 17 ++ .../_pydantic/test_openapi_extensions.py | 219 ++++++++++++++++++ 6 files changed, 312 insertions(+), 8 deletions(-) create mode 100644 tests/functional/event_handler/_pydantic/test_openapi_extensions.py diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 2c829789e8c..682e3617219 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -323,6 +323,7 @@ def __init__( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Response]]] = None, ): """ @@ -383,6 +384,7 @@ def __init__( self.tags = tags or [] self.include_in_schema = include_in_schema self.security = security + self.openapi_extensions = openapi_extensions self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -534,6 +536,10 @@ def _get_openapi_path( if self.security: operation["security"] = self.security + # Add OpenAPI extensions if present + if self.openapi_extensions: + operation.update(self.openapi_extensions) + # Add the parameters to the OpenAPI operation if parameters: all_parameters = {(param["in"], param["name"]): param for param in parameters} @@ -939,6 +945,7 @@ def route( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -998,6 +1005,7 @@ def get( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -1036,6 +1044,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1053,6 +1062,7 @@ def post( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -1092,6 +1102,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1109,6 +1120,7 @@ def put( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -1148,6 +1160,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1165,6 +1178,7 @@ def delete( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1203,6 +1217,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1220,6 +1235,7 @@ def patch( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1261,6 +1277,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1278,6 +1295,7 @@ def head( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, ): """Head route decorator with HEAD `method` @@ -1318,6 +1336,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1541,6 +1560,7 @@ def get_openapi_schema( license_info: Optional["License"] = None, security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, ) -> "OpenAPI": """ Returns the OpenAPI schema as a pydantic model. @@ -1603,11 +1623,15 @@ def get_openapi_schema( info.update({field: value for field, value in optional_fields.items() if value}) + if not openapi_extensions: + openapi_extensions = {} + output: Dict[str, Any] = { "openapi": openapi_version, "info": info, "servers": self._get_openapi_servers(servers), "security": self._get_openapi_security(security, security_schemes), + **openapi_extensions, } components: Dict[str, Dict[str, Any]] = {} @@ -1726,6 +1750,7 @@ def get_openapi_json_schema( license_info: Optional["License"] = None, security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, ) -> str: """ Returns the OpenAPI schema as a JSON serializable dict @@ -1778,6 +1803,7 @@ def get_openapi_json_schema( license_info=license_info, security_schemes=security_schemes, security=security, + openapi_extensions=openapi_extensions, ), by_alias=True, exclude_none=True, @@ -1805,6 +1831,7 @@ def enable_swagger( security: Optional[List[Dict[str, List[str]]]] = None, oauth2_config: Optional["OAuth2Config"] = None, persist_authorization: bool = False, + openapi_extensions: Optional[Dict[str, Any]] = None, ): """ Returns the OpenAPI schema as a JSON serializable dict @@ -1896,6 +1923,7 @@ def swagger_handler(): license_info=license_info, security_schemes=security_schemes, security=security, + openapi_extensions=openapi_extensions, ) # The .replace(' Callable[[Callable[..., Any]], Callable[..., Any]]: + + openapi_extensions = None security = None return super(BedrockAgentResolver, self).get( @@ -117,6 +119,7 @@ def get( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -137,6 +140,7 @@ def post( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().post( @@ -152,6 +156,7 @@ def post( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -172,6 +177,7 @@ def put( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().put( @@ -187,6 +193,7 @@ def put( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -207,6 +214,7 @@ def patch( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): + openapi_extensions = None security = None return super().patch( @@ -222,6 +230,7 @@ def patch( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -242,6 +251,7 @@ def delete( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().delete( @@ -257,6 +267,7 @@ def delete( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 060886605ec..11bdd17f3b1 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -41,7 +41,7 @@ RequestErrorModel: Type[BaseModel] = create_model("Request") if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3 - from pydantic import TypeAdapter, ValidationError + from pydantic import TypeAdapter, ValidationError, model_serializer as parser_openapi_extension from pydantic._internal._typing_extra import eval_type_lenient from pydantic.fields import FieldInfo from pydantic._internal._utils import lenient_issubclass @@ -217,7 +217,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any: return model.model_dump_json(**kwargs) else: - from pydantic import BaseModel, ValidationError + from pydantic import BaseModel, ValidationError, root_validator as parser_openapi_extension from pydantic.fields import ( ModelField, Required, diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 04345ddaad7..a04a1af0c8b 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -3,7 +3,7 @@ from pydantic import AnyUrl, BaseModel, Field -from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild +from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild, parser_openapi_extension from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 from aws_lambda_powertools.shared.types import Annotated, Literal @@ -13,6 +13,28 @@ """ +class OpenapiExtensions(BaseModel): + """OpenAPI extensions, see https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#specification-extensions""" + + openapi_extensions: Optional[Dict[str, Any]] = None + + if PYDANTIC_V2: + + @parser_openapi_extension() + def serialize(self): + if self.openapi_extensions: + return self.openapi_extensions + + else: + + @parser_openapi_extension(pre=False, allow_reuse=True) + def check_json(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + + # https://swagger.io/specification/#contact-object class Contact(BaseModel): name: Optional[str] = None @@ -77,7 +99,7 @@ class Config: # https://swagger.io/specification/#server-object -class Server(BaseModel): +class Server(OpenapiExtensions): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None @@ -379,7 +401,7 @@ class Config: # https://swagger.io/specification/#operation-object -class Operation(BaseModel): +class Operation(OpenapiExtensions): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -436,7 +458,7 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(BaseModel): +class SecurityBase(OpenapiExtensions): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None @@ -534,7 +556,7 @@ class OpenIdConnect(SecurityBase): # https://swagger.io/specification/#components-object -class Components(BaseModel): +class Components(OpenapiExtensions): schemas: Optional[Dict[str, Union[Schema, Reference]]] = None responses: Optional[Dict[str, Union[Response, Reference]]] = None parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None @@ -557,7 +579,7 @@ class Config: # https://swagger.io/specification/#openapi-object -class OpenAPI(BaseModel): +class OpenAPI(OpenapiExtensions): openapi: str info: Info jsonSchemaDialect: Optional[str] = None diff --git a/tests/functional/event_handler/_pydantic/conftest.py b/tests/functional/event_handler/_pydantic/conftest.py index a099ae4cea5..1d38e2e26b1 100644 --- a/tests/functional/event_handler/_pydantic/conftest.py +++ b/tests/functional/event_handler/_pydantic/conftest.py @@ -120,3 +120,20 @@ def openapi31_schema(): @pytest.fixture def security_scheme(): return {"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header)} + + +@pytest.fixture +def openapi_extension_integration_detail(): + return { + "type": "aws", + "httpMethod": "POST", + "uri": "arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/..integration/invocations", + "responses": {"default": {"statusCode": "200"}}, + "passthroughBehavior": "when_no_match", + "contentHandling": "CONVERT_TO_TEXT", + } + + +@pytest.fixture +def openapi_extension_validator_detail(): + return "Validate body, query string parameters, and headers" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py new file mode 100644 index 00000000000..cd1744672c6 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py @@ -0,0 +1,219 @@ +import json + +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Router +from aws_lambda_powertools.event_handler.openapi.models import ( + APIKey, + APIKeyIn, + OAuth2, + OAuthFlowImplicit, + OAuthFlows, + Server, +) + + +def test_openapi_extension_root_level(): + app = APIGatewayRestResolver() + + cors_config = { + "maxAge": 0, + "allowCredentials": False, + } + + schema = json.loads( + app.get_openapi_json_schema( + openapi_extensions={"x-amazon-apigateway-cors": cors_config}, + ), + ) + + assert "x-amazon-apigateway-cors" in schema + assert schema["x-amazon-apigateway-cors"] == cors_config + + +def test_openapi_extension_server_level(): + app = APIGatewayRestResolver() + + endpoint_config = { + "disableExecuteApiEndpoint": True, + "vpcEndpointIds": ["vpce-0df8e77555fca0000"], + } + + server_config = { + "url": "https://example.org/", + "description": "Example website", + } + + schema = json.loads( + app.get_openapi_json_schema( + title="Hello API", + version="1.0.0", + servers=[ + Server( + **server_config, + openapi_extensions={ + "x-amazon-apigateway-endpoint-configuration": endpoint_config, + }, + ), + ], + ), + ) + + assert "x-amazon-apigateway-endpoint-configuration" in schema["servers"][0] + assert schema["servers"][0]["x-amazon-apigateway-endpoint-configuration"] == endpoint_config + + +def test_openapi_extension_security_scheme_level_with_api_key(): + app = APIGatewayRestResolver() + + authorizer_config = { + "authorizerUri": "arn:aws:apigateway:us-east-1:...:function:authorizer/invocations", + "authorizerResultTtlInSeconds": 300, + "type": "token", + } + + api_key_config = { + "name": "X-API-KEY", + "description": "API Key", + "in_": APIKeyIn.header, + } + + schema = json.loads( + app.get_openapi_json_schema( + security_schemes={ + "apiKey": APIKey( + **api_key_config, + openapi_extensions={ + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": authorizer_config, + }, + ), + }, + ), + ) + + assert "x-amazon-apigateway-authtype" in schema["components"]["securitySchemes"]["apiKey"] + assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authtype"] == "custom" + assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authorizer"] == authorizer_config + + +def test_openapi_extension_security_scheme_level_with_oauth2(): + app = APIGatewayRestResolver() + + authorizer_config = { + "identitySource": "$request.header.Authorization", + "jwtConfiguration": { + "audience": ["test"], + "issuer": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/", + }, + "type": "jwt", + } + + oauth2_config = { + "flows": OAuthFlows( + implicit=OAuthFlowImplicit( + authorizationUrl="https://example.com/oauth2/authorize", + ), + ), + } + + schema = json.loads( + app.get_openapi_json_schema( + security_schemes={ + "oauth2": OAuth2( + **oauth2_config, + openapi_extensions={ + "x-amazon-apigateway-authorizer": authorizer_config, + }, + ), + }, + ), + ) + + assert "x-amazon-apigateway-authorizer" in schema["components"]["securitySchemes"]["oauth2"] + assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"] == authorizer_config + + +def test_openapi_extension_operation_level(openapi_extension_integration_detail): + app = APIGatewayRestResolver() + + @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler(): + pass + + schema = json.loads(app.get_openapi_json_schema()) + + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + +def test_openapi_extension_operation_level_multiple_paths( + openapi_extension_integration_detail, + openapi_extension_validator_detail, +): + app = APIGatewayRestResolver() + + @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler_get(): + pass + + @app.post("/test", openapi_extensions={"x-amazon-apigateway-request-validator": openapi_extension_validator_detail}) + def lambda_handler_post(): + pass + + schema = json.loads(app.get_openapi_json_schema()) + + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + assert "x-amazon-apigateway-integration" not in schema["paths"]["/test"]["post"] + assert "x-amazon-apigateway-request-validator" in schema["paths"]["/test"]["post"] + assert ( + schema["paths"]["/test"]["post"]["x-amazon-apigateway-request-validator"] == openapi_extension_validator_detail + ) + + +def test_openapi_extension_operation_level_with_router(openapi_extension_integration_detail): + app = APIGatewayRestResolver() + router = Router() + + @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler(): + pass + + app.include_router(router) + + schema = json.loads(app.get_openapi_json_schema()) + + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + +def test_openapi_extension_operation_level_multiple_paths_with_router( + openapi_extension_integration_detail, + openapi_extension_validator_detail, +): + app = APIGatewayRestResolver() + router = Router() + + @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler_get(): + pass + + @router.post( + "/test", + openapi_extensions={"x-amazon-apigateway-request-validator": openapi_extension_validator_detail}, + ) + def lambda_handler_post(): + pass + + app.include_router(router) + + schema = json.loads(app.get_openapi_json_schema()) + + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + assert "x-amazon-apigateway-integration" not in schema["paths"]["/test"]["post"] + assert "x-amazon-apigateway-request-validator" in schema["paths"]["/test"]["post"] + assert ( + schema["paths"]["/test"]["post"]["x-amazon-apigateway-request-validator"] == openapi_extension_validator_detail + ) From 3fceac7e8996c243e0115a6e9503d311738e3ed1 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 01:16:55 +0100 Subject: [PATCH 02/17] Polishing the PR with best practicies - Comments --- aws_lambda_powertools/event_handler/openapi/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index a04a1af0c8b..3e33063cf61 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -22,11 +22,15 @@ class OpenapiExtensions(BaseModel): @parser_openapi_extension() def serialize(self): + # If the 'openapi_extensions' field is not None, return it if self.openapi_extensions: return self.openapi_extensions else: + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary @parser_openapi_extension(pre=False, allow_reuse=True) def check_json(cls, values): if values.get("openapi_extensions"): From 84d035a9b815b71d9f259b34b3897e219a12f85c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 01:25:56 +0100 Subject: [PATCH 03/17] Polishing the PR with best practicies - Tests --- .../_pydantic/test_openapi_extensions.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py index cd1744672c6..19e6d8c71e3 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py @@ -12,6 +12,7 @@ def test_openapi_extension_root_level(): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() cors_config = { @@ -19,17 +20,20 @@ def test_openapi_extension_root_level(): "allowCredentials": False, } + # WHEN we get the OpenAPI JSON schema with CORS extension in the Root Level schema = json.loads( app.get_openapi_json_schema( openapi_extensions={"x-amazon-apigateway-cors": cors_config}, ), ) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-cors" extension assert "x-amazon-apigateway-cors" in schema assert schema["x-amazon-apigateway-cors"] == cors_config def test_openapi_extension_server_level(): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() endpoint_config = { @@ -42,6 +46,7 @@ def test_openapi_extension_server_level(): "description": "Example website", } + # WHEN we get the OpenAPI JSON schema with a server-level openapi extension schema = json.loads( app.get_openapi_json_schema( title="Hello API", @@ -57,11 +62,13 @@ def test_openapi_extension_server_level(): ), ) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-endpoint-configuration" at the server level assert "x-amazon-apigateway-endpoint-configuration" in schema["servers"][0] assert schema["servers"][0]["x-amazon-apigateway-endpoint-configuration"] == endpoint_config def test_openapi_extension_security_scheme_level_with_api_key(): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() authorizer_config = { @@ -76,6 +83,7 @@ def test_openapi_extension_security_scheme_level_with_api_key(): "in_": APIKeyIn.header, } + # WHEN we get the OpenAPI JSON schema with a security scheme-level extension for a custom auth schema = json.loads( app.get_openapi_json_schema( security_schemes={ @@ -90,12 +98,14 @@ def test_openapi_extension_security_scheme_level_with_api_key(): ), ) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-authtype" extension at the security scheme level assert "x-amazon-apigateway-authtype" in schema["components"]["securitySchemes"]["apiKey"] assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authtype"] == "custom" assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authorizer"] == authorizer_config def test_openapi_extension_security_scheme_level_with_oauth2(): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() authorizer_config = { @@ -115,6 +125,7 @@ def test_openapi_extension_security_scheme_level_with_oauth2(): ), } + # WHEN we get the OpenAPI JSON schema with a security scheme-level extension for a custom auth schema = json.loads( app.get_openapi_json_schema( security_schemes={ @@ -128,19 +139,24 @@ def test_openapi_extension_security_scheme_level_with_oauth2(): ), ) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-authorizer" extension at the security scheme level assert "x-amazon-apigateway-authorizer" in schema["components"]["securitySchemes"]["oauth2"] assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"] == authorizer_config def test_openapi_extension_operation_level(openapi_extension_integration_detail): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() + # WHEN we define an integration extension at operation level + # AND get the schema @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) def lambda_handler(): pass schema = json.loads(app.get_openapi_json_schema()) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail @@ -149,8 +165,11 @@ def test_openapi_extension_operation_level_multiple_paths( openapi_extension_integration_detail, openapi_extension_validator_detail, ): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() + # WHEN we define multiple routes with integration extension at operation level + # AND get the schema @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) def lambda_handler_get(): pass @@ -161,6 +180,7 @@ def lambda_handler_post(): schema = json.loads(app.get_openapi_json_schema()) + # THEN each route must contain only your extension assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail @@ -172,9 +192,12 @@ def lambda_handler_post(): def test_openapi_extension_operation_level_with_router(openapi_extension_integration_detail): + # GIVEN an APIGatewayRestResolver and Router instance app = APIGatewayRestResolver() router = Router() + # WHEN we define an integration extension at operation level using Router + # AND get the schema @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) def lambda_handler(): pass @@ -183,6 +206,7 @@ def lambda_handler(): schema = json.loads(app.get_openapi_json_schema()) + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail @@ -191,9 +215,12 @@ def test_openapi_extension_operation_level_multiple_paths_with_router( openapi_extension_integration_detail, openapi_extension_validator_detail, ): + # GIVEN an APIGatewayRestResolver and Router instance app = APIGatewayRestResolver() router = Router() + # WHEN we define multiple routes using extensions at operation level using Router + # AND get the schema @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) def lambda_handler_get(): pass @@ -209,6 +236,7 @@ def lambda_handler_post(): schema = json.loads(app.get_openapi_json_schema()) + # THEN each route must contain only your extension assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail From 8a111d2438e4cef3b239adb06febc2f39c14a8f7 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 08:50:50 +0100 Subject: [PATCH 04/17] Polishing the PR with best practicies - make pydanticv2 happy --- .../event_handler/openapi/models.py | 92 ++++++++++++------- 1 file changed, 61 insertions(+), 31 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 3e33063cf61..1f710306639 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -13,32 +13,6 @@ """ -class OpenapiExtensions(BaseModel): - """OpenAPI extensions, see https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#specification-extensions""" - - openapi_extensions: Optional[Dict[str, Any]] = None - - if PYDANTIC_V2: - - @parser_openapi_extension() - def serialize(self): - # If the 'openapi_extensions' field is not None, return it - if self.openapi_extensions: - return self.openapi_extensions - - else: - - # If the 'openapi_extensions' field is present in the 'values' dictionary, - # update the 'values' dictionary with the contents of 'openapi_extensions', - # and then remove the 'openapi_extensions' field from the 'values' dictionary - @parser_openapi_extension(pre=False, allow_reuse=True) - def check_json(cls, values): - if values.get("openapi_extensions"): - values.update(values["openapi_extensions"]) - del values["openapi_extensions"] - return values - - # https://swagger.io/specification/#contact-object class Contact(BaseModel): name: Optional[str] = None @@ -103,16 +77,33 @@ class Config: # https://swagger.io/specification/#server-object -class Server(OpenapiExtensions): +class Server(BaseModel): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None + openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} + @parser_openapi_extension() + def serialize(self): + # If the 'openapi_extensions' field is not None, return it + if self.openapi_extensions: + return self.openapi_extensions + else: + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + @parser_openapi_extension(pre=False, allow_reuse=True) + def check_json(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + class Config: extra = "allow" @@ -405,7 +396,7 @@ class Config: # https://swagger.io/specification/#operation-object -class Operation(OpenapiExtensions): +class Operation(BaseModel): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -419,12 +410,23 @@ class Operation(OpenapiExtensions): deprecated: Optional[bool] = None security: Optional[List[Dict[str, List[str]]]] = None servers: Optional[List[Server]] = None + openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} else: + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + @parser_openapi_extension(pre=False, allow_reuse=True) + def check_json(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + class Config: extra = "allow" @@ -462,15 +464,32 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(OpenapiExtensions): +class SecurityBase(BaseModel): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None + openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow", "populate_by_name": True} + @parser_openapi_extension() + def serialize(self): + # If the 'openapi_extensions' field is not None, return it + if self.openapi_extensions: + return self.openapi_extensions + else: + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + @parser_openapi_extension(pre=False, allow_reuse=True) + def check_json(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + class Config: extra = "allow" allow_population_by_field_name = True @@ -560,7 +579,7 @@ class OpenIdConnect(SecurityBase): # https://swagger.io/specification/#components-object -class Components(OpenapiExtensions): +class Components(BaseModel): schemas: Optional[Dict[str, Union[Schema, Reference]]] = None responses: Optional[Dict[str, Union[Response, Reference]]] = None parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None @@ -583,7 +602,7 @@ class Config: # https://swagger.io/specification/#openapi-object -class OpenAPI(OpenapiExtensions): +class OpenAPI(BaseModel): openapi: str info: Info jsonSchemaDialect: Optional[str] = None @@ -595,12 +614,23 @@ class OpenAPI(OpenapiExtensions): security: Optional[List[Dict[str, List[str]]]] = None tags: Optional[List[Tag]] = None externalDocs: Optional[ExternalDocumentation] = None + openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} else: + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + @parser_openapi_extension(pre=False, allow_reuse=True) + def check_json(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + class Config: extra = "allow" From 270516fa642e9ed7ad7d34dac3b4bb3ad6a66c59 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 18:14:42 +0100 Subject: [PATCH 05/17] Polishing the PR with best practicies - using model_validator to be more specific --- .../event_handler/api_gateway.py | 2 +- .../event_handler/openapi/compat.py | 2 +- .../event_handler/openapi/models.py | 96 +++++++------------ .../_pydantic/test_openapi_extensions.py | 19 ++++ 4 files changed, 57 insertions(+), 62 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 682e3617219..95e4575695a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1623,7 +1623,7 @@ def get_openapi_schema( info.update({field: value for field, value in optional_fields.items() if value}) - if not openapi_extensions: + if not isinstance(openapi_extensions, Dict): openapi_extensions = {} output: Dict[str, Any] = { diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 11bdd17f3b1..df0f8aabab7 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -41,7 +41,7 @@ RequestErrorModel: Type[BaseModel] = create_model("Request") if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3 - from pydantic import TypeAdapter, ValidationError, model_serializer as parser_openapi_extension + from pydantic import TypeAdapter, ValidationError, model_validator as parser_openapi_extension from pydantic._internal._typing_extra import eval_type_lenient from pydantic.fields import FieldInfo from pydantic._internal._utils import lenient_issubclass diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 1f710306639..7fbc9516e6c 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -13,6 +13,38 @@ """ +class OpenapiExtensions(BaseModel): + openapi_extensions: Optional[Dict[str, Any]] = None + + # This rule is valid for Pydantic v1 and v2 + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + @parser_openapi_extension(mode="before") + def serialize_openapi_extension(self): + if isinstance(self, dict) and self.get("openapi_extensions"): + self.update(self.get("openapi_extensions")) + self.pop("openapi_extensions", None) + + return self + + else: + + @parser_openapi_extension(pre=False, allow_reuse=True) + def serialize_openapi_extension(cls, values): + if values.get("openapi_extensions"): + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + return values + + class Config: + extra = "allow" + + # https://swagger.io/specification/#contact-object class Contact(BaseModel): name: Optional[str] = None @@ -77,33 +109,16 @@ class Config: # https://swagger.io/specification/#server-object -class Server(BaseModel): +class Server(OpenapiExtensions): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None - openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} - @parser_openapi_extension() - def serialize(self): - # If the 'openapi_extensions' field is not None, return it - if self.openapi_extensions: - return self.openapi_extensions - else: - # If the 'openapi_extensions' field is present in the 'values' dictionary, - # update the 'values' dictionary with the contents of 'openapi_extensions', - # and then remove the 'openapi_extensions' field from the 'values' dictionary - @parser_openapi_extension(pre=False, allow_reuse=True) - def check_json(cls, values): - if values.get("openapi_extensions"): - values.update(values["openapi_extensions"]) - del values["openapi_extensions"] - return values - class Config: extra = "allow" @@ -396,7 +411,7 @@ class Config: # https://swagger.io/specification/#operation-object -class Operation(BaseModel): +class Operation(OpenapiExtensions): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -410,23 +425,12 @@ class Operation(BaseModel): deprecated: Optional[bool] = None security: Optional[List[Dict[str, List[str]]]] = None servers: Optional[List[Server]] = None - openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} else: - # If the 'openapi_extensions' field is present in the 'values' dictionary, - # update the 'values' dictionary with the contents of 'openapi_extensions', - # and then remove the 'openapi_extensions' field from the 'values' dictionary - @parser_openapi_extension(pre=False, allow_reuse=True) - def check_json(cls, values): - if values.get("openapi_extensions"): - values.update(values["openapi_extensions"]) - del values["openapi_extensions"] - return values - class Config: extra = "allow" @@ -464,32 +468,15 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(BaseModel): +class SecurityBase(OpenapiExtensions): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None - openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow", "populate_by_name": True} - @parser_openapi_extension() - def serialize(self): - # If the 'openapi_extensions' field is not None, return it - if self.openapi_extensions: - return self.openapi_extensions - else: - # If the 'openapi_extensions' field is present in the 'values' dictionary, - # update the 'values' dictionary with the contents of 'openapi_extensions', - # and then remove the 'openapi_extensions' field from the 'values' dictionary - @parser_openapi_extension(pre=False, allow_reuse=True) - def check_json(cls, values): - if values.get("openapi_extensions"): - values.update(values["openapi_extensions"]) - del values["openapi_extensions"] - return values - class Config: extra = "allow" allow_population_by_field_name = True @@ -602,7 +589,7 @@ class Config: # https://swagger.io/specification/#openapi-object -class OpenAPI(BaseModel): +class OpenAPI(OpenapiExtensions): openapi: str info: Info jsonSchemaDialect: Optional[str] = None @@ -614,23 +601,12 @@ class OpenAPI(BaseModel): security: Optional[List[Dict[str, List[str]]]] = None tags: Optional[List[Tag]] = None externalDocs: Optional[ExternalDocumentation] = None - openapi_extensions: Optional[Dict[str, Any]] = None if PYDANTIC_V2: model_config = {"extra": "allow"} else: - # If the 'openapi_extensions' field is present in the 'values' dictionary, - # update the 'values' dictionary with the contents of 'openapi_extensions', - # and then remove the 'openapi_extensions' field from the 'values' dictionary - @parser_openapi_extension(pre=False, allow_reuse=True) - def check_json(cls, values): - if values.get("openapi_extensions"): - values.update(values["openapi_extensions"]) - del values["openapi_extensions"] - return values - class Config: extra = "allow" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py index 19e6d8c71e3..2f0552ffc4c 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py @@ -65,6 +65,8 @@ def test_openapi_extension_server_level(): # THEN the OpenAPI schema must contain the "x-amazon-apigateway-endpoint-configuration" at the server level assert "x-amazon-apigateway-endpoint-configuration" in schema["servers"][0] assert schema["servers"][0]["x-amazon-apigateway-endpoint-configuration"] == endpoint_config + assert schema["servers"][0]["url"] == server_config["url"] + assert schema["servers"][0]["description"] == server_config["description"] def test_openapi_extension_security_scheme_level_with_api_key(): @@ -102,6 +104,9 @@ def test_openapi_extension_security_scheme_level_with_api_key(): assert "x-amazon-apigateway-authtype" in schema["components"]["securitySchemes"]["apiKey"] assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authtype"] == "custom" assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authorizer"] == authorizer_config + assert schema["components"]["securitySchemes"]["apiKey"]["name"] == api_key_config["name"] + assert schema["components"]["securitySchemes"]["apiKey"]["description"] == api_key_config["description"] + assert schema["components"]["securitySchemes"]["apiKey"]["in"] == "header" def test_openapi_extension_security_scheme_level_with_oauth2(): @@ -142,6 +147,19 @@ def test_openapi_extension_security_scheme_level_with_oauth2(): # THEN the OpenAPI schema must contain the "x-amazon-apigateway-authorizer" extension at the security scheme level assert "x-amazon-apigateway-authorizer" in schema["components"]["securitySchemes"]["oauth2"] assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"] == authorizer_config + assert ( + schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["identitySource"] + == "$request.header.Authorization" + ) + assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][ + "audience" + ] == ["test"] + assert ( + schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][ + "issuer" + ] + == "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/" + ) def test_openapi_extension_operation_level(openapi_extension_integration_detail): @@ -159,6 +177,7 @@ def lambda_handler(): # THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + assert schema["paths"]["/test"]["get"]["operationId"] == "lambda_handler_test_get" def test_openapi_extension_operation_level_multiple_paths( From 771c44ebc542dc1e18f49abc896a16bbb1c2441d Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 18:26:28 +0100 Subject: [PATCH 06/17] Temporary mypy disabling --- aws_lambda_powertools/event_handler/openapi/models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 7fbc9516e6c..87c14848d7f 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from enum import Enum from typing import Any, Dict, List, Optional, Set, Union @@ -22,10 +24,11 @@ class OpenapiExtensions(BaseModel): # and then remove the 'openapi_extensions' field from the 'values' dictionary if PYDANTIC_V2: + model_config = {"extra": "allow"} @parser_openapi_extension(mode="before") - def serialize_openapi_extension(self): + def serialize_openapi_extension_v2(self): if isinstance(self, dict) and self.get("openapi_extensions"): self.update(self.get("openapi_extensions")) self.pop("openapi_extensions", None) @@ -35,7 +38,7 @@ def serialize_openapi_extension(self): else: @parser_openapi_extension(pre=False, allow_reuse=True) - def serialize_openapi_extension(cls, values): + def serialize_openapi_extension_v1(cls, values): if values.get("openapi_extensions"): values.update(values["openapi_extensions"]) del values["openapi_extensions"] From 8195bda88c23aa7eef80d24300f540f3c99f5c61 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 19:30:34 +0100 Subject: [PATCH 07/17] Make mypy happy? --- aws_lambda_powertools/event_handler/openapi/models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 87c14848d7f..b5a0a26fd5d 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - from enum import Enum from typing import Any, Dict, List, Optional, Set, Union From 657fed3a8b8691a8493dda198e6360995dba08fb Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 19:37:10 +0100 Subject: [PATCH 08/17] Make mypy happy? --- aws_lambda_powertools/event_handler/openapi/models.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index b5a0a26fd5d..48c4c2e8194 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -14,6 +14,13 @@ class OpenapiExtensions(BaseModel): + """ + This class serves as a Pydantic proxy model to add OpenAPI extensions. + + OpenAPI extensions are arbitrary fields, so we remove openapi_extensions when dumping + and add only the provided value in the schema. + """ + openapi_extensions: Optional[Dict[str, Any]] = None # This rule is valid for Pydantic v1 and v2 @@ -474,7 +481,7 @@ class SecurityBase(OpenapiExtensions): description: Optional[str] = None if PYDANTIC_V2: - model_config = {"extra": "allow", "populate_by_name": True} + model_config = {"extra": "allow", "populate_by_name": True} # type: ignore else: From 06c8e235ca6a7079b7fd64540fa7484e2adcd412 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Mon, 8 Jul 2024 22:18:11 +0100 Subject: [PATCH 09/17] Polishing the PR with best practicies - adding e2e tests --- .../event_handler/handlers/openapi_handler.py | 19 +++++++++++++ tests/e2e/event_handler/infrastructure.py | 9 ++++--- tests/e2e/event_handler/test_openapi.py | 27 +++++++++++++++++++ 3 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 tests/e2e/event_handler/handlers/openapi_handler.py create mode 100644 tests/e2e/event_handler/test_openapi.py diff --git a/tests/e2e/event_handler/handlers/openapi_handler.py b/tests/e2e/event_handler/handlers/openapi_handler.py new file mode 100644 index 00000000000..13cfb69f016 --- /dev/null +++ b/tests/e2e/event_handler/handlers/openapi_handler.py @@ -0,0 +1,19 @@ +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, +) + +app = APIGatewayRestResolver(enable_validation=True) + + +@app.get("/openapi_schema") +def openapi_schema(): + return app.get_openapi_json_schema( + title="Powertools e2e API", + version="1.0.0", + description="This is a sample Powertools e2e API", + openapi_extensions={"x-amazon-apigateway-gateway-responses": {"DEFAULT_4XX"}}, + ) + + +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/tests/e2e/event_handler/infrastructure.py b/tests/e2e/event_handler/infrastructure.py index 5fd78896a34..b607e32caf8 100644 --- a/tests/e2e/event_handler/infrastructure.py +++ b/tests/e2e/event_handler/infrastructure.py @@ -18,7 +18,7 @@ def create_resources(self): functions = self.create_lambda_functions() self._create_alb(function=[functions["AlbHandler"], functions["AlbHandlerWithBodyNone"]]) - self._create_api_gateway_rest(function=functions["ApiGatewayRestHandler"]) + self._create_api_gateway_rest(function=[functions["ApiGatewayRestHandler"], functions["OpenapiHandler"]]) self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"]) self._create_lambda_function_url(function=functions["LambdaFunctionUrlHandler"]) @@ -76,7 +76,7 @@ def _create_api_gateway_http(self, function: Function): CfnOutput(self.stack, "APIGatewayHTTPUrl", value=(apigw.url or "")) - def _create_api_gateway_rest(self, function: Function): + def _create_api_gateway_rest(self, function: List[Function]): apigw = apigwv1.RestApi( self.stack, "APIGatewayRest", @@ -87,7 +87,10 @@ def _create_api_gateway_rest(self, function: Function): ) todos = apigw.root.add_resource("todos") - todos.add_method("POST", apigwv1.LambdaIntegration(function, proxy=True)) + todos.add_method("POST", apigwv1.LambdaIntegration(function[0], proxy=True)) + + openapi_schema = apigw.root.add_resource("openapi_schema") + openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[1], proxy=True)) CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url) diff --git a/tests/e2e/event_handler/test_openapi.py b/tests/e2e/event_handler/test_openapi.py new file mode 100644 index 00000000000..d69c3b142b2 --- /dev/null +++ b/tests/e2e/event_handler/test_openapi.py @@ -0,0 +1,27 @@ +import pytest +from requests import Request + +from tests.e2e.utils import data_fetcher + + +@pytest.fixture +def apigw_rest_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayRestUrl", "") + + +@pytest.mark.xdist_group(name="event_handler") +def test_get_openapi_schema(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}openapi_schema" + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="GET", + url=url, + ), + ) + + assert "Powertools e2e API" in response.text + assert "x-amazon-apigateway-gateway-responses" in response.text + assert response.status_code == 200 From aa0e0d39ef7368f8e09ad90889eb01acc927a20c Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 9 Jul 2024 19:16:30 +0100 Subject: [PATCH 10/17] Adding docstring --- aws_lambda_powertools/event_handler/api_gateway.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 95e4575695a..8e84a74ef60 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -361,6 +361,8 @@ def __init__( Whether or not to include this route in the OpenAPI schema security: List[Dict[str, List[str]]], optional The OpenAPI security for this route + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -1591,6 +1593,8 @@ def get_openapi_schema( A declaration of the security schemes available to be used in the specification. security: List[Dict[str, List[str]]], optional A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. Returns ------- @@ -1781,6 +1785,8 @@ def get_openapi_json_schema( A declaration of the security schemes available to be used in the specification. security: List[Dict[str, List[str]]], optional A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. Returns ------- @@ -1874,6 +1880,8 @@ def enable_swagger( The OAuth2 configuration for the Swagger UI. persist_authorization: bool, optional Whether to persist authorization data on browser close/refresh. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. """ from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server From c0a098b772e4fe64a5d7eecc2b26f5003bb58917 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Tue, 9 Jul 2024 19:45:39 +0100 Subject: [PATCH 11/17] Adding documentation --- docs/core/event_handler/api_gateway.md | 16 +++++++++ .../src/working_with_openapi_extensions.py | 33 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 examples/event_handler_rest/src/working_with_openapi_extensions.py diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index b3c046e243c..d5fcbd7619e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -1115,6 +1115,22 @@ OpenAPI 3 lets you describe APIs protected using the following security schemes: --8<-- "examples/event_handler_rest/src/swagger_with_oauth2.py" ``` +#### OpenAPI extensions + +For a better experience when working with Lambda and Amazon API Gateway, customers can define extensions using the `openapi_extensions` parameter. We support defining OpenAPI extensions at the following levels of the OpenAPI JSON Schema: Root, Servers, Operation, and Security Schemes. + +???+ warning + We do not support the `x-amazon-apigateway-any-method` and `x-amazon-apigateway-integrations` extensions. + +```python hl_lines="9 15 25 28" title="Adding OpenAPI extensions" +--8<-- "examples/event_handler_rest/src/working_with_openapi_extensions.py" +``` + +1. Server level +2. Operation level +3. Security scheme level +4. Root level + ### Custom serializer You can instruct event handler to use a custom serializer to best suit your needs, for example take into account Enums when serializing. diff --git a/examples/event_handler_rest/src/working_with_openapi_extensions.py b/examples/event_handler_rest/src/working_with_openapi_extensions.py new file mode 100644 index 00000000000..03489c6f7b8 --- /dev/null +++ b/examples/event_handler_rest/src/working_with_openapi_extensions.py @@ -0,0 +1,33 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn, Server + +app = APIGatewayRestResolver(enable_validation=True) + +servers = Server( + url="http://example.com", + description="Example server", + openapi_extensions={"x-amazon-apigateway-endpoint-configuration": {"vpcEndpoint": "myendpointid"}}, # (1)! +) + + +@app.get( + "/hello", + openapi_extensions={"x-amazon-apigateway-integration": {"type": "aws", "uri": "my_lambda_arn"}}, # (2)! +) +def hello(): + return app.get_openapi_json_schema( + servers=[servers], + security_schemes={ + "apikey": APIKey( + name="X-API-KEY", + description="API KeY", + in_=APIKeyIn.header, + openapi_extensions={"x-amazon-apigateway-authorizer": "custom"}, # (3)! + ), + }, + openapi_extensions={"x-amazon-apigateway-gateway-responses": {"DEFAULT_4XX"}}, # (4)! + ) + + +def lambda_handler(event, context): + return app.resolve(event, context) From 83ecc370b58ab7d71ad49dc6ae96d3aea4d4e04e Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 16:32:36 +0100 Subject: [PATCH 12/17] Addressing Simon's feedback --- aws_lambda_powertools/event_handler/openapi/models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 48c4c2e8194..806bc46595f 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -13,7 +13,7 @@ """ -class OpenapiExtensions(BaseModel): +class OpenAPIExtensions(BaseModel): """ This class serves as a Pydantic proxy model to add OpenAPI extensions. @@ -117,7 +117,7 @@ class Config: # https://swagger.io/specification/#server-object -class Server(OpenapiExtensions): +class Server(OpenAPIExtensions): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None @@ -419,7 +419,7 @@ class Config: # https://swagger.io/specification/#operation-object -class Operation(OpenapiExtensions): +class Operation(OpenAPIExtensions): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -476,7 +476,7 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(OpenapiExtensions): +class SecurityBase(OpenAPIExtensions): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None @@ -597,7 +597,7 @@ class Config: # https://swagger.io/specification/#openapi-object -class OpenAPI(OpenapiExtensions): +class OpenAPI(OpenAPIExtensions): openapi: str info: Info jsonSchemaDialect: Optional[str] = None From 56497f94dec9c00fdb3fe3cf719dd6a428727aa1 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 17:04:44 +0100 Subject: [PATCH 13/17] Addressing Simon's feedback --- tests/unit/event_handler/__init__.py | 0 .../unit/event_handler/_pydantic/__init__.py | 0 .../unit/event_handler/_pydantic/conftest.py | 18 ++++++++++ .../test_openapi_models_pydantic_v1.py | 35 +++++++++++++++++++ .../test_openapi_models_pydantic_v2.py | 35 +++++++++++++++++++ 5 files changed, 88 insertions(+) create mode 100644 tests/unit/event_handler/__init__.py create mode 100644 tests/unit/event_handler/_pydantic/__init__.py create mode 100644 tests/unit/event_handler/_pydantic/conftest.py create mode 100644 tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py create mode 100644 tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py diff --git a/tests/unit/event_handler/__init__.py b/tests/unit/event_handler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/event_handler/_pydantic/__init__.py b/tests/unit/event_handler/_pydantic/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/event_handler/_pydantic/conftest.py b/tests/unit/event_handler/_pydantic/conftest.py new file mode 100644 index 00000000000..d50d4e483ef --- /dev/null +++ b/tests/unit/event_handler/_pydantic/conftest.py @@ -0,0 +1,18 @@ +import pytest +from pydantic import __version__ + + +@pytest.fixture(scope="session") +def pydanticv1_only(): + + version = __version__.split(".") + if version[0] != "1": + pytest.skip("pydanticv1 test only") + + +@pytest.fixture(scope="session") +def pydanticv2_only(): + + version = __version__.split(".") + if version[0] != "2": + pytest.skip("pydanticv2 test only") diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py new file mode 100644 index 00000000000..730213697c7 --- /dev/null +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py @@ -0,0 +1,35 @@ +import pytest + +from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_dict(): + # GIVEN we create an OpenAPIExtensions object with a dict + extensions = OpenAPIExtensions(openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}) + + # THEN we get a dict with the extension + assert extensions.dict(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_proxy_models(): + + # GIVEN we create an models using OpenAPIExtensions as a "Proxy" Model + class MyModelFoo(OpenAPIExtensions): + foo: str + + class MyModelBar(OpenAPIExtensions): + bar: str + foo: MyModelFoo + + value_to_serialize = MyModelBar( + bar="bar", + foo=MyModelFoo(foo="foo"), + openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}, + ) + + value_to_return = value_to_serialize.dict(exclude_none=True) + + # THEN we get a dict with the value serialized + assert value_to_return == {"bar": "bar", "foo": {"foo": "foo"}, "x-amazon-apigateway": {"foo": "bar"}} diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py new file mode 100644 index 00000000000..7058c49699c --- /dev/null +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py @@ -0,0 +1,35 @@ +import pytest + +from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_dict(): + # GIVEN we create an OpenAPIExtensions object with a dict + extensions = OpenAPIExtensions(openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}) + + # THEN we get a dict with the extension + assert extensions.model_dump(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_proxy_models(): + + # GIVEN we create an models using OpenAPIExtensions as a "Proxy" Model + class MyModelFoo(OpenAPIExtensions): + foo: str + + class MyModelBar(OpenAPIExtensions): + bar: str + foo: MyModelFoo + + value_to_serialize = MyModelBar( + bar="bar", + foo=MyModelFoo(foo="foo"), + openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}, + ) + + value_to_return = value_to_serialize.model_dump(exclude_none=True) + + # THEN we get a dict with the value serialized + assert value_to_return == {"bar": "bar", "foo": {"foo": "foo"}, "x-amazon-apigateway": {"foo": "bar"}} From c9178775f4f3b925fefbb4df2bd409d79a715c0d Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 17:05:41 +0100 Subject: [PATCH 14/17] Addressing Simon's feedback --- noxfile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/noxfile.py b/noxfile.py index 7023f45a2b7..4e53bcb816a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -163,6 +163,7 @@ def test_with_pydantic_required_package(session: nox.Session, pydantic: str): f"{PREFIX_TESTS_FUNCTIONAL}/event_handler/_pydantic/", f"{PREFIX_TESTS_FUNCTIONAL}/batch/_pydantic/", f"{PREFIX_TESTS_UNIT}/parser/_pydantic/", + f"{PREFIX_TESTS_UNIT}/event_handler/_pydantic/", ], ) From b8390b034571046b125b1c5b73932a568d20d640 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 22:40:25 +0100 Subject: [PATCH 15/17] Adding more tests --- .../event_handler/openapi/models.py | 23 ++++++++++++++++--- .../test_openapi_models_pydantic_v1.py | 9 ++++++++ .../test_openapi_models_pydantic_v2.py | 9 ++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 806bc46595f..1f32752f102 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -4,6 +4,7 @@ from pydantic import AnyUrl, BaseModel, Field from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild, parser_openapi_extension +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 from aws_lambda_powertools.shared.types import Annotated, Literal @@ -25,6 +26,7 @@ class OpenAPIExtensions(BaseModel): # This rule is valid for Pydantic v1 and v2 # If the 'openapi_extensions' field is present in the 'values' dictionary, + # And if the extension starts with x- # update the 'values' dictionary with the contents of 'openapi_extensions', # and then remove the 'openapi_extensions' field from the 'values' dictionary @@ -34,8 +36,15 @@ class OpenAPIExtensions(BaseModel): @parser_openapi_extension(mode="before") def serialize_openapi_extension_v2(self): - if isinstance(self, dict) and self.get("openapi_extensions"): - self.update(self.get("openapi_extensions")) + openapi_extension_value = self.get("openapi_extensions") + + if isinstance(self, dict) and openapi_extension_value: + + for extension_key in openapi_extension_value: + if not str(extension_key).startswith("x-"): + raise SchemaValidationError("An OpenAPI extension key must start with x-") + + self.update(openapi_extension_value) self.pop("openapi_extensions", None) return self @@ -44,9 +53,17 @@ def serialize_openapi_extension_v2(self): @parser_openapi_extension(pre=False, allow_reuse=True) def serialize_openapi_extension_v1(cls, values): - if values.get("openapi_extensions"): + openapi_extension_value = values.get("openapi_extensions") + + if openapi_extension_value: + + for extension_key in openapi_extension_value: + if not str(extension_key).startswith("x-"): + raise SchemaValidationError("An OpenAPI extension key must start with x-") + values.update(values["openapi_extensions"]) del values["openapi_extensions"] + return values class Config: diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py index 730213697c7..49ee4920f3d 100644 --- a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py @@ -1,5 +1,6 @@ import pytest +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions @@ -12,6 +13,14 @@ def test_openapi_extensions_with_dict(): assert extensions.dict(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_invalid_key(): + # GIVEN we create an OpenAPIExtensions object with an invalid value + with pytest.raises(SchemaValidationError): + # THEN must raise an exception + OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}}) + + @pytest.mark.usefixtures("pydanticv1_only") def test_openapi_extensions_with_proxy_models(): diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py index 7058c49699c..5191b4f7520 100644 --- a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py @@ -1,5 +1,6 @@ import pytest +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions @@ -12,6 +13,14 @@ def test_openapi_extensions_with_dict(): assert extensions.model_dump(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_invalid_key(): + # GIVEN we create an OpenAPIExtensions object with an invalid value + with pytest.raises(SchemaValidationError): + # THEN must raise an exception + OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}}) + + @pytest.mark.usefixtures("pydanticv2_only") def test_openapi_extensions_with_proxy_models(): From 319b3d47c66a4a1d4ae954fccde8afab572f48b2 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 22:46:01 +0100 Subject: [PATCH 16/17] Adding more tests --- aws_lambda_powertools/event_handler/openapi/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 1f32752f102..530960a7961 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -36,9 +36,9 @@ class OpenAPIExtensions(BaseModel): @parser_openapi_extension(mode="before") def serialize_openapi_extension_v2(self): - openapi_extension_value = self.get("openapi_extensions") + if isinstance(self, dict) and self.get("openapi_extensions"): - if isinstance(self, dict) and openapi_extension_value: + openapi_extension_value = self.get("openapi_extensions") for extension_key in openapi_extension_value: if not str(extension_key).startswith("x-"): From 8d3f9c30a5e7cb1d2df74432cbc33a345a4d3d42 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 10 Jul 2024 22:49:02 +0100 Subject: [PATCH 17/17] Adding more tests --- aws_lambda_powertools/event_handler/openapi/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 530960a7961..cac3266d254 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -26,7 +26,7 @@ class OpenAPIExtensions(BaseModel): # This rule is valid for Pydantic v1 and v2 # If the 'openapi_extensions' field is present in the 'values' dictionary, - # And if the extension starts with x- + # And if the extension starts with x- (must respect the RFC) # update the 'values' dictionary with the contents of 'openapi_extensions', # and then remove the 'openapi_extensions' field from the 'values' dictionary