diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 2c829789e8c..8e84a74ef60 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -323,6 +323,7 @@ def __init__( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Response]]] = None, ): """ @@ -360,6 +361,8 @@ def __init__( Whether or not to include this route in the OpenAPI schema security: List[Dict[str, List[str]]], optional The OpenAPI security for this route + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -383,6 +386,7 @@ def __init__( self.tags = tags or [] self.include_in_schema = include_in_schema self.security = security + self.openapi_extensions = openapi_extensions self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -534,6 +538,10 @@ def _get_openapi_path( if self.security: operation["security"] = self.security + # Add OpenAPI extensions if present + if self.openapi_extensions: + operation.update(self.openapi_extensions) + # Add the parameters to the OpenAPI operation if parameters: all_parameters = {(param["in"], param["name"]): param for param in parameters} @@ -939,6 +947,7 @@ def route( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -998,6 +1007,7 @@ def get( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -1036,6 +1046,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1053,6 +1064,7 @@ def post( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -1092,6 +1104,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1109,6 +1122,7 @@ def put( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -1148,6 +1162,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1165,6 +1180,7 @@ def delete( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1203,6 +1219,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1220,6 +1237,7 @@ def patch( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1261,6 +1279,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1278,6 +1297,7 @@ def head( operation_id: Optional[str] = None, include_in_schema: bool = True, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, middlewares: Optional[List[Callable]] = None, ): """Head route decorator with HEAD `method` @@ -1318,6 +1338,7 @@ def lambda_handler(event, context): operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -1541,6 +1562,7 @@ def get_openapi_schema( license_info: Optional["License"] = None, security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, ) -> "OpenAPI": """ Returns the OpenAPI schema as a pydantic model. @@ -1571,6 +1593,8 @@ def get_openapi_schema( A declaration of the security schemes available to be used in the specification. security: List[Dict[str, List[str]]], optional A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. Returns ------- @@ -1603,11 +1627,15 @@ def get_openapi_schema( info.update({field: value for field, value in optional_fields.items() if value}) + if not isinstance(openapi_extensions, Dict): + openapi_extensions = {} + output: Dict[str, Any] = { "openapi": openapi_version, "info": info, "servers": self._get_openapi_servers(servers), "security": self._get_openapi_security(security, security_schemes), + **openapi_extensions, } components: Dict[str, Dict[str, Any]] = {} @@ -1726,6 +1754,7 @@ def get_openapi_json_schema( license_info: Optional["License"] = None, security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, security: Optional[List[Dict[str, List[str]]]] = None, + openapi_extensions: Optional[Dict[str, Any]] = None, ) -> str: """ Returns the OpenAPI schema as a JSON serializable dict @@ -1756,6 +1785,8 @@ def get_openapi_json_schema( A declaration of the security schemes available to be used in the specification. security: List[Dict[str, List[str]]], optional A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. Returns ------- @@ -1778,6 +1809,7 @@ def get_openapi_json_schema( license_info=license_info, security_schemes=security_schemes, security=security, + openapi_extensions=openapi_extensions, ), by_alias=True, exclude_none=True, @@ -1805,6 +1837,7 @@ def enable_swagger( security: Optional[List[Dict[str, List[str]]]] = None, oauth2_config: Optional["OAuth2Config"] = None, persist_authorization: bool = False, + openapi_extensions: Optional[Dict[str, Any]] = None, ): """ Returns the OpenAPI schema as a JSON serializable dict @@ -1847,6 +1880,8 @@ def enable_swagger( The OAuth2 configuration for the Swagger UI. persist_authorization: bool, optional Whether to persist authorization data on browser close/refresh. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. """ from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server @@ -1896,6 +1931,7 @@ def swagger_handler(): license_info=license_info, security_schemes=security_schemes, security=security, + openapi_extensions=openapi_extensions, ) # The .replace(' Callable[[Callable[..., Any]], Callable[..., Any]]: + + openapi_extensions = None security = None return super(BedrockAgentResolver, self).get( @@ -117,6 +119,7 @@ def get( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -137,6 +140,7 @@ def post( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().post( @@ -152,6 +156,7 @@ def post( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -172,6 +177,7 @@ def put( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().put( @@ -187,6 +193,7 @@ def put( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -207,6 +214,7 @@ def patch( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): + openapi_extensions = None security = None return super().patch( @@ -222,6 +230,7 @@ def patch( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) @@ -242,6 +251,7 @@ def delete( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + openapi_extensions = None security = None return super().delete( @@ -257,6 +267,7 @@ def delete( # type: ignore[override] operation_id, include_in_schema, security, + openapi_extensions, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py index 060886605ec..df0f8aabab7 100644 --- a/aws_lambda_powertools/event_handler/openapi/compat.py +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -41,7 +41,7 @@ RequestErrorModel: Type[BaseModel] = create_model("Request") if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3 - from pydantic import TypeAdapter, ValidationError + from pydantic import TypeAdapter, ValidationError, model_validator as parser_openapi_extension from pydantic._internal._typing_extra import eval_type_lenient from pydantic.fields import FieldInfo from pydantic._internal._utils import lenient_issubclass @@ -217,7 +217,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any: return model.model_dump_json(**kwargs) else: - from pydantic import BaseModel, ValidationError + from pydantic import BaseModel, ValidationError, root_validator as parser_openapi_extension from pydantic.fields import ( ModelField, Required, diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 04345ddaad7..cac3266d254 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -3,7 +3,8 @@ from pydantic import AnyUrl, BaseModel, Field -from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild +from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild, parser_openapi_extension +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 from aws_lambda_powertools.shared.types import Annotated, Literal @@ -13,6 +14,62 @@ """ +class OpenAPIExtensions(BaseModel): + """ + This class serves as a Pydantic proxy model to add OpenAPI extensions. + + OpenAPI extensions are arbitrary fields, so we remove openapi_extensions when dumping + and add only the provided value in the schema. + """ + + openapi_extensions: Optional[Dict[str, Any]] = None + + # This rule is valid for Pydantic v1 and v2 + # If the 'openapi_extensions' field is present in the 'values' dictionary, + # And if the extension starts with x- (must respect the RFC) + # update the 'values' dictionary with the contents of 'openapi_extensions', + # and then remove the 'openapi_extensions' field from the 'values' dictionary + + if PYDANTIC_V2: + + model_config = {"extra": "allow"} + + @parser_openapi_extension(mode="before") + def serialize_openapi_extension_v2(self): + if isinstance(self, dict) and self.get("openapi_extensions"): + + openapi_extension_value = self.get("openapi_extensions") + + for extension_key in openapi_extension_value: + if not str(extension_key).startswith("x-"): + raise SchemaValidationError("An OpenAPI extension key must start with x-") + + self.update(openapi_extension_value) + self.pop("openapi_extensions", None) + + return self + + else: + + @parser_openapi_extension(pre=False, allow_reuse=True) + def serialize_openapi_extension_v1(cls, values): + openapi_extension_value = values.get("openapi_extensions") + + if openapi_extension_value: + + for extension_key in openapi_extension_value: + if not str(extension_key).startswith("x-"): + raise SchemaValidationError("An OpenAPI extension key must start with x-") + + values.update(values["openapi_extensions"]) + del values["openapi_extensions"] + + return values + + class Config: + extra = "allow" + + # https://swagger.io/specification/#contact-object class Contact(BaseModel): name: Optional[str] = None @@ -77,7 +134,7 @@ class Config: # https://swagger.io/specification/#server-object -class Server(BaseModel): +class Server(OpenAPIExtensions): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None @@ -379,7 +436,7 @@ class Config: # https://swagger.io/specification/#operation-object -class Operation(BaseModel): +class Operation(OpenAPIExtensions): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -436,12 +493,12 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(BaseModel): +class SecurityBase(OpenAPIExtensions): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None if PYDANTIC_V2: - model_config = {"extra": "allow", "populate_by_name": True} + model_config = {"extra": "allow", "populate_by_name": True} # type: ignore else: @@ -557,7 +614,7 @@ class Config: # https://swagger.io/specification/#openapi-object -class OpenAPI(BaseModel): +class OpenAPI(OpenAPIExtensions): openapi: str info: Info jsonSchemaDialect: Optional[str] = None diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index b3c046e243c..d5fcbd7619e 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -1115,6 +1115,22 @@ OpenAPI 3 lets you describe APIs protected using the following security schemes: --8<-- "examples/event_handler_rest/src/swagger_with_oauth2.py" ``` +#### OpenAPI extensions + +For a better experience when working with Lambda and Amazon API Gateway, customers can define extensions using the `openapi_extensions` parameter. We support defining OpenAPI extensions at the following levels of the OpenAPI JSON Schema: Root, Servers, Operation, and Security Schemes. + +???+ warning + We do not support the `x-amazon-apigateway-any-method` and `x-amazon-apigateway-integrations` extensions. + +```python hl_lines="9 15 25 28" title="Adding OpenAPI extensions" +--8<-- "examples/event_handler_rest/src/working_with_openapi_extensions.py" +``` + +1. Server level +2. Operation level +3. Security scheme level +4. Root level + ### Custom serializer You can instruct event handler to use a custom serializer to best suit your needs, for example take into account Enums when serializing. diff --git a/examples/event_handler_rest/src/working_with_openapi_extensions.py b/examples/event_handler_rest/src/working_with_openapi_extensions.py new file mode 100644 index 00000000000..03489c6f7b8 --- /dev/null +++ b/examples/event_handler_rest/src/working_with_openapi_extensions.py @@ -0,0 +1,33 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn, Server + +app = APIGatewayRestResolver(enable_validation=True) + +servers = Server( + url="http://example.com", + description="Example server", + openapi_extensions={"x-amazon-apigateway-endpoint-configuration": {"vpcEndpoint": "myendpointid"}}, # (1)! +) + + +@app.get( + "/hello", + openapi_extensions={"x-amazon-apigateway-integration": {"type": "aws", "uri": "my_lambda_arn"}}, # (2)! +) +def hello(): + return app.get_openapi_json_schema( + servers=[servers], + security_schemes={ + "apikey": APIKey( + name="X-API-KEY", + description="API KeY", + in_=APIKeyIn.header, + openapi_extensions={"x-amazon-apigateway-authorizer": "custom"}, # (3)! + ), + }, + openapi_extensions={"x-amazon-apigateway-gateway-responses": {"DEFAULT_4XX"}}, # (4)! + ) + + +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/noxfile.py b/noxfile.py index 7023f45a2b7..4e53bcb816a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -163,6 +163,7 @@ def test_with_pydantic_required_package(session: nox.Session, pydantic: str): f"{PREFIX_TESTS_FUNCTIONAL}/event_handler/_pydantic/", f"{PREFIX_TESTS_FUNCTIONAL}/batch/_pydantic/", f"{PREFIX_TESTS_UNIT}/parser/_pydantic/", + f"{PREFIX_TESTS_UNIT}/event_handler/_pydantic/", ], ) diff --git a/tests/e2e/event_handler/handlers/openapi_handler.py b/tests/e2e/event_handler/handlers/openapi_handler.py new file mode 100644 index 00000000000..13cfb69f016 --- /dev/null +++ b/tests/e2e/event_handler/handlers/openapi_handler.py @@ -0,0 +1,19 @@ +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, +) + +app = APIGatewayRestResolver(enable_validation=True) + + +@app.get("/openapi_schema") +def openapi_schema(): + return app.get_openapi_json_schema( + title="Powertools e2e API", + version="1.0.0", + description="This is a sample Powertools e2e API", + openapi_extensions={"x-amazon-apigateway-gateway-responses": {"DEFAULT_4XX"}}, + ) + + +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/tests/e2e/event_handler/infrastructure.py b/tests/e2e/event_handler/infrastructure.py index 5fd78896a34..b607e32caf8 100644 --- a/tests/e2e/event_handler/infrastructure.py +++ b/tests/e2e/event_handler/infrastructure.py @@ -18,7 +18,7 @@ def create_resources(self): functions = self.create_lambda_functions() self._create_alb(function=[functions["AlbHandler"], functions["AlbHandlerWithBodyNone"]]) - self._create_api_gateway_rest(function=functions["ApiGatewayRestHandler"]) + self._create_api_gateway_rest(function=[functions["ApiGatewayRestHandler"], functions["OpenapiHandler"]]) self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"]) self._create_lambda_function_url(function=functions["LambdaFunctionUrlHandler"]) @@ -76,7 +76,7 @@ def _create_api_gateway_http(self, function: Function): CfnOutput(self.stack, "APIGatewayHTTPUrl", value=(apigw.url or "")) - def _create_api_gateway_rest(self, function: Function): + def _create_api_gateway_rest(self, function: List[Function]): apigw = apigwv1.RestApi( self.stack, "APIGatewayRest", @@ -87,7 +87,10 @@ def _create_api_gateway_rest(self, function: Function): ) todos = apigw.root.add_resource("todos") - todos.add_method("POST", apigwv1.LambdaIntegration(function, proxy=True)) + todos.add_method("POST", apigwv1.LambdaIntegration(function[0], proxy=True)) + + openapi_schema = apigw.root.add_resource("openapi_schema") + openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[1], proxy=True)) CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url) diff --git a/tests/e2e/event_handler/test_openapi.py b/tests/e2e/event_handler/test_openapi.py new file mode 100644 index 00000000000..d69c3b142b2 --- /dev/null +++ b/tests/e2e/event_handler/test_openapi.py @@ -0,0 +1,27 @@ +import pytest +from requests import Request + +from tests.e2e.utils import data_fetcher + + +@pytest.fixture +def apigw_rest_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayRestUrl", "") + + +@pytest.mark.xdist_group(name="event_handler") +def test_get_openapi_schema(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}openapi_schema" + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="GET", + url=url, + ), + ) + + assert "Powertools e2e API" in response.text + assert "x-amazon-apigateway-gateway-responses" in response.text + assert response.status_code == 200 diff --git a/tests/functional/event_handler/_pydantic/conftest.py b/tests/functional/event_handler/_pydantic/conftest.py index a099ae4cea5..1d38e2e26b1 100644 --- a/tests/functional/event_handler/_pydantic/conftest.py +++ b/tests/functional/event_handler/_pydantic/conftest.py @@ -120,3 +120,20 @@ def openapi31_schema(): @pytest.fixture def security_scheme(): return {"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header)} + + +@pytest.fixture +def openapi_extension_integration_detail(): + return { + "type": "aws", + "httpMethod": "POST", + "uri": "arn:aws:apigateway:us-east-1:lambda:path/2015-03-31/functions/..integration/invocations", + "responses": {"default": {"statusCode": "200"}}, + "passthroughBehavior": "when_no_match", + "contentHandling": "CONVERT_TO_TEXT", + } + + +@pytest.fixture +def openapi_extension_validator_detail(): + return "Validate body, query string parameters, and headers" diff --git a/tests/functional/event_handler/_pydantic/test_openapi_extensions.py b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py new file mode 100644 index 00000000000..2f0552ffc4c --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_openapi_extensions.py @@ -0,0 +1,266 @@ +import json + +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Router +from aws_lambda_powertools.event_handler.openapi.models import ( + APIKey, + APIKeyIn, + OAuth2, + OAuthFlowImplicit, + OAuthFlows, + Server, +) + + +def test_openapi_extension_root_level(): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + cors_config = { + "maxAge": 0, + "allowCredentials": False, + } + + # WHEN we get the OpenAPI JSON schema with CORS extension in the Root Level + schema = json.loads( + app.get_openapi_json_schema( + openapi_extensions={"x-amazon-apigateway-cors": cors_config}, + ), + ) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-cors" extension + assert "x-amazon-apigateway-cors" in schema + assert schema["x-amazon-apigateway-cors"] == cors_config + + +def test_openapi_extension_server_level(): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + endpoint_config = { + "disableExecuteApiEndpoint": True, + "vpcEndpointIds": ["vpce-0df8e77555fca0000"], + } + + server_config = { + "url": "https://example.org/", + "description": "Example website", + } + + # WHEN we get the OpenAPI JSON schema with a server-level openapi extension + schema = json.loads( + app.get_openapi_json_schema( + title="Hello API", + version="1.0.0", + servers=[ + Server( + **server_config, + openapi_extensions={ + "x-amazon-apigateway-endpoint-configuration": endpoint_config, + }, + ), + ], + ), + ) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-endpoint-configuration" at the server level + assert "x-amazon-apigateway-endpoint-configuration" in schema["servers"][0] + assert schema["servers"][0]["x-amazon-apigateway-endpoint-configuration"] == endpoint_config + assert schema["servers"][0]["url"] == server_config["url"] + assert schema["servers"][0]["description"] == server_config["description"] + + +def test_openapi_extension_security_scheme_level_with_api_key(): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + authorizer_config = { + "authorizerUri": "arn:aws:apigateway:us-east-1:...:function:authorizer/invocations", + "authorizerResultTtlInSeconds": 300, + "type": "token", + } + + api_key_config = { + "name": "X-API-KEY", + "description": "API Key", + "in_": APIKeyIn.header, + } + + # WHEN we get the OpenAPI JSON schema with a security scheme-level extension for a custom auth + schema = json.loads( + app.get_openapi_json_schema( + security_schemes={ + "apiKey": APIKey( + **api_key_config, + openapi_extensions={ + "x-amazon-apigateway-authtype": "custom", + "x-amazon-apigateway-authorizer": authorizer_config, + }, + ), + }, + ), + ) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-authtype" extension at the security scheme level + assert "x-amazon-apigateway-authtype" in schema["components"]["securitySchemes"]["apiKey"] + assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authtype"] == "custom" + assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authorizer"] == authorizer_config + assert schema["components"]["securitySchemes"]["apiKey"]["name"] == api_key_config["name"] + assert schema["components"]["securitySchemes"]["apiKey"]["description"] == api_key_config["description"] + assert schema["components"]["securitySchemes"]["apiKey"]["in"] == "header" + + +def test_openapi_extension_security_scheme_level_with_oauth2(): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + authorizer_config = { + "identitySource": "$request.header.Authorization", + "jwtConfiguration": { + "audience": ["test"], + "issuer": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/", + }, + "type": "jwt", + } + + oauth2_config = { + "flows": OAuthFlows( + implicit=OAuthFlowImplicit( + authorizationUrl="https://example.com/oauth2/authorize", + ), + ), + } + + # WHEN we get the OpenAPI JSON schema with a security scheme-level extension for a custom auth + schema = json.loads( + app.get_openapi_json_schema( + security_schemes={ + "oauth2": OAuth2( + **oauth2_config, + openapi_extensions={ + "x-amazon-apigateway-authorizer": authorizer_config, + }, + ), + }, + ), + ) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-authorizer" extension at the security scheme level + assert "x-amazon-apigateway-authorizer" in schema["components"]["securitySchemes"]["oauth2"] + assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"] == authorizer_config + assert ( + schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["identitySource"] + == "$request.header.Authorization" + ) + assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][ + "audience" + ] == ["test"] + assert ( + schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][ + "issuer" + ] + == "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/" + ) + + +def test_openapi_extension_operation_level(openapi_extension_integration_detail): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + # WHEN we define an integration extension at operation level + # AND get the schema + @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler(): + pass + + schema = json.loads(app.get_openapi_json_schema()) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + assert schema["paths"]["/test"]["get"]["operationId"] == "lambda_handler_test_get" + + +def test_openapi_extension_operation_level_multiple_paths( + openapi_extension_integration_detail, + openapi_extension_validator_detail, +): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + # WHEN we define multiple routes with integration extension at operation level + # AND get the schema + @app.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler_get(): + pass + + @app.post("/test", openapi_extensions={"x-amazon-apigateway-request-validator": openapi_extension_validator_detail}) + def lambda_handler_post(): + pass + + schema = json.loads(app.get_openapi_json_schema()) + + # THEN each route must contain only your extension + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + assert "x-amazon-apigateway-integration" not in schema["paths"]["/test"]["post"] + assert "x-amazon-apigateway-request-validator" in schema["paths"]["/test"]["post"] + assert ( + schema["paths"]["/test"]["post"]["x-amazon-apigateway-request-validator"] == openapi_extension_validator_detail + ) + + +def test_openapi_extension_operation_level_with_router(openapi_extension_integration_detail): + # GIVEN an APIGatewayRestResolver and Router instance + app = APIGatewayRestResolver() + router = Router() + + # WHEN we define an integration extension at operation level using Router + # AND get the schema + @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler(): + pass + + app.include_router(router) + + schema = json.loads(app.get_openapi_json_schema()) + + # THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + +def test_openapi_extension_operation_level_multiple_paths_with_router( + openapi_extension_integration_detail, + openapi_extension_validator_detail, +): + # GIVEN an APIGatewayRestResolver and Router instance + app = APIGatewayRestResolver() + router = Router() + + # WHEN we define multiple routes using extensions at operation level using Router + # AND get the schema + @router.get("/test", openapi_extensions={"x-amazon-apigateway-integration": openapi_extension_integration_detail}) + def lambda_handler_get(): + pass + + @router.post( + "/test", + openapi_extensions={"x-amazon-apigateway-request-validator": openapi_extension_validator_detail}, + ) + def lambda_handler_post(): + pass + + app.include_router(router) + + schema = json.loads(app.get_openapi_json_schema()) + + # THEN each route must contain only your extension + assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"] + assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail + + assert "x-amazon-apigateway-integration" not in schema["paths"]["/test"]["post"] + assert "x-amazon-apigateway-request-validator" in schema["paths"]["/test"]["post"] + assert ( + schema["paths"]["/test"]["post"]["x-amazon-apigateway-request-validator"] == openapi_extension_validator_detail + ) diff --git a/tests/unit/event_handler/__init__.py b/tests/unit/event_handler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/event_handler/_pydantic/__init__.py b/tests/unit/event_handler/_pydantic/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/unit/event_handler/_pydantic/conftest.py b/tests/unit/event_handler/_pydantic/conftest.py new file mode 100644 index 00000000000..d50d4e483ef --- /dev/null +++ b/tests/unit/event_handler/_pydantic/conftest.py @@ -0,0 +1,18 @@ +import pytest +from pydantic import __version__ + + +@pytest.fixture(scope="session") +def pydanticv1_only(): + + version = __version__.split(".") + if version[0] != "1": + pytest.skip("pydanticv1 test only") + + +@pytest.fixture(scope="session") +def pydanticv2_only(): + + version = __version__.split(".") + if version[0] != "2": + pytest.skip("pydanticv2 test only") diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py new file mode 100644 index 00000000000..49ee4920f3d --- /dev/null +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py @@ -0,0 +1,44 @@ +import pytest + +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError +from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_dict(): + # GIVEN we create an OpenAPIExtensions object with a dict + extensions = OpenAPIExtensions(openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}) + + # THEN we get a dict with the extension + assert extensions.dict(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_invalid_key(): + # GIVEN we create an OpenAPIExtensions object with an invalid value + with pytest.raises(SchemaValidationError): + # THEN must raise an exception + OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}}) + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_extensions_with_proxy_models(): + + # GIVEN we create an models using OpenAPIExtensions as a "Proxy" Model + class MyModelFoo(OpenAPIExtensions): + foo: str + + class MyModelBar(OpenAPIExtensions): + bar: str + foo: MyModelFoo + + value_to_serialize = MyModelBar( + bar="bar", + foo=MyModelFoo(foo="foo"), + openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}, + ) + + value_to_return = value_to_serialize.dict(exclude_none=True) + + # THEN we get a dict with the value serialized + assert value_to_return == {"bar": "bar", "foo": {"foo": "foo"}, "x-amazon-apigateway": {"foo": "bar"}} diff --git a/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py new file mode 100644 index 00000000000..5191b4f7520 --- /dev/null +++ b/tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py @@ -0,0 +1,44 @@ +import pytest + +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError +from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_dict(): + # GIVEN we create an OpenAPIExtensions object with a dict + extensions = OpenAPIExtensions(openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}) + + # THEN we get a dict with the extension + assert extensions.model_dump(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}} + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_invalid_key(): + # GIVEN we create an OpenAPIExtensions object with an invalid value + with pytest.raises(SchemaValidationError): + # THEN must raise an exception + OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}}) + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_extensions_with_proxy_models(): + + # GIVEN we create an models using OpenAPIExtensions as a "Proxy" Model + class MyModelFoo(OpenAPIExtensions): + foo: str + + class MyModelBar(OpenAPIExtensions): + bar: str + foo: MyModelFoo + + value_to_serialize = MyModelBar( + bar="bar", + foo=MyModelFoo(foo="foo"), + openapi_extensions={"x-amazon-apigateway": {"foo": "bar"}}, + ) + + value_to_return = value_to_serialize.model_dump(exclude_none=True) + + # THEN we get a dict with the value serialized + assert value_to_return == {"bar": "bar", "foo": {"foo": "foo"}, "x-amazon-apigateway": {"foo": "bar"}}