diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1383b74ada0..37e2265ea8a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -9,6 +9,7 @@ from enum import Enum from functools import partial from http import HTTPStatus +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -28,6 +29,8 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError +from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION +from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, @@ -69,7 +72,6 @@ License, OpenAPI, Server, - Tag, ) from aws_lambda_powertools.event_handler.openapi.params import Dependant from aws_lambda_powertools.event_handler.openapi.types import ( @@ -236,6 +238,15 @@ def __init__( if content_type: self.headers.setdefault("Content-Type", content_type) + def is_json(self) -> bool: + """ + Returns True if the response is JSON, based on the Content-Type. + """ + content_type = self.headers.get("Content-Type", "") + if isinstance(content_type, list): + content_type = content_type[0] + return content_type.startswith("application/json") + class Route: """Internally used Route Configuration""" @@ -253,8 +264,9 @@ def __init__( description: Optional[str], responses: Optional[Dict[int, Dict[str, Any]]], response_description: Optional[str], - tags: Optional[List["Tag"]], + tags: Optional[List[str]], operation_id: Optional[str], + include_in_schema: bool, middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -284,10 +296,12 @@ def __init__( The OpenAPI responses for this route response_description: Optional[str] The OpenAPI response description for this route - tags: Optional[List[Tag]] + tags: Optional[List[str]] The list of OpenAPI tags to be used for this route operation_id: Optional[str] The OpenAPI operationId for this route + include_in_schema: bool + Whether or not to include this route in the OpenAPI schema middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -304,6 +318,7 @@ def __init__( self.responses = responses self.response_description = response_description self.tags = tags or [] + self.include_in_schema = include_in_schema self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -483,7 +498,6 @@ def _get_openapi_path( # Add the response schema to the OpenAPI 200 response json_response.update( self._openapi_operation_return( - operation_id=self.operation_id, param=dependant.return_param, model_name_map=model_name_map, field_mapping=field_mapping, @@ -530,7 +544,7 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any] # Ensure tags is added to the operation if self.tags: - operation["tags"] = self.tags + operation["tags"] = [{"name": tag for tag in self.tags}] # Ensure summary is added to the operation operation["summary"] = self._openapi_operation_summary() @@ -643,7 +657,6 @@ def _openapi_operation_parameters( @staticmethod def _openapi_operation_return( *, - operation_id: str, param: Optional["ModelField"], model_name_map: Dict["TypeModelOrEnum", str], field_mapping: Dict[ @@ -667,7 +680,7 @@ def _openapi_operation_return( field_mapping=field_mapping, ) - return {"name": f"Return {operation_id}", "schema": return_schema} + return {"schema": return_schema} def _generate_operation_id(self) -> str: operation_id = self.func.__name__ + self.path @@ -790,8 +803,9 @@ def route( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -847,8 +861,9 @@ def get( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -885,6 +900,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -898,8 +914,9 @@ def post( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -937,6 +954,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -950,8 +968,9 @@ def put( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -989,6 +1008,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1002,8 +1022,9 @@ def delete( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1040,6 +1061,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1053,8 +1075,9 @@ def patch( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1094,6 +1117,7 @@ def lambda_handler(event, context): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1312,11 +1336,11 @@ def get_openapi_schema( self, *, title: str = "Powertools API", - version: str = "1.0.0", - openapi_version: str = "3.1.0", + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, @@ -1337,7 +1361,7 @@ def get_openapi_schema( A short summary of what the application does. description: str, optional A verbose explanation of the application behavior. - tags: List[Tag], optional + tags: List[str], optional A list of tags used by the specification with additional metadata. servers: List[Server], optional An array of Server Objects, which provide connectivity information to a target server. @@ -1345,7 +1369,7 @@ def get_openapi_schema( A URL to the Terms of Service for the API. MUST be in the format of a URL. contact: Contact, optional The contact information for the exposed API. - license_info: + license_info: License, optional The license information for the exposed API. Returns @@ -1403,6 +1427,9 @@ def get_openapi_schema( # Add routes to the OpenAPI schema for route in all_routes: + if not route.include_in_schema: + continue + result = route._get_openapi_path( dependant=route.dependant, operation_ids=operation_ids, @@ -1421,7 +1448,7 @@ def get_openapi_schema( if components: output["components"] = components if tags: - output["tags"] = tags + output["tags"] = [{"name": tag} for tag in tags] output["paths"] = {k: PathItem(**v) for k, v in paths.items()} @@ -1431,11 +1458,11 @@ def get_openapi_json_schema( self, *, title: str = "Powertools API", - version: str = "1.0.0", - openapi_version: str = "3.1.0", + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: Optional[str] = None, description: Optional[str] = None, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, servers: Optional[List["Server"]] = None, terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, @@ -1456,7 +1483,7 @@ def get_openapi_json_schema( A short summary of what the application does. description: str, optional A verbose explanation of the application behavior. - tags: List[Tag], optional + tags: List[str], optional A list of tags used by the specification with additional metadata. servers: List[Server], optional An array of Server Objects, which provide connectivity information to a target server. @@ -1464,7 +1491,7 @@ def get_openapi_json_schema( A URL to the Terms of Service for the API. MUST be in the format of a URL. contact: Contact, optional The contact information for the exposed API. - license_info: + license_info: License, optional The license information for the exposed API. Returns @@ -1492,6 +1519,111 @@ def get_openapi_json_schema( indent=2, ) + def enable_swagger( + self, + *, + path: str = "/swagger", + title: str = "Powertools for AWS Lambda (Python) API", + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + servers: Optional[List["Server"]] = None, + terms_of_service: Optional[str] = None, + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, + swagger_base_url: Optional[str] = None, + middlewares: Optional[List[Callable[..., Response]]] = None, + ): + """ + Returns the OpenAPI schema as a JSON serializable dict + + Parameters + ---------- + path: str, default = "/swagger" + The path to the swagger UI. + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[str], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: License, optional + The license information for the exposed API. + swagger_base_url: str, optional + The base url for the swagger UI. If not provided, we will serve a recent version of the Swagger UI. + middlewares: List[Callable[..., Response]], optional + List of middlewares to be used for the swagger route. + """ + from aws_lambda_powertools.event_handler.openapi.models import Server + + if not swagger_base_url: + + @self.get("/swagger.js", include_in_schema=False) + def swagger_js(): + body = Path.open(Path(__file__).parent / "openapi" / "swagger_ui" / "swagger-ui-bundle.min.js").read() + return Response( + status_code=200, + content_type="text/javascript", + body=body, + ) + + @self.get("/swagger.css", include_in_schema=False) + def swagger_css(): + body = Path.open(Path(__file__).parent / "openapi" / "swagger_ui" / "swagger-ui.min.css").read() + return Response( + status_code=200, + content_type="text/css", + body=body, + ) + + @self.get(path, middlewares=middlewares, include_in_schema=False) + def swagger_handler(): + base_path = self._get_base_path() + + if swagger_base_url: + swagger_js = f"{swagger_base_url}/swagger-ui-bundle.min.js" + swagger_css = f"{swagger_base_url}/swagger-ui.min.css" + else: + swagger_js = f"{base_path}/swagger.js" + swagger_css = f"{base_path}/swagger.css" + + openapi_servers = servers or [Server(url=(base_path or "/"))] + + spec = self.get_openapi_json_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=openapi_servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + ) + + body = generate_swagger_html(spec, swagger_js, swagger_css) + + return Response( + status_code=200, + content_type="text/html", + body=body, + ) + def route( self, rule: str, @@ -1503,8 +1635,9 @@ def route( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -1530,6 +1663,7 @@ def register_resolver(func: Callable): response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -1606,6 +1740,9 @@ def _create_route_key(self, item: str, rule: str): ) self._route_keys.append(route_key) + def _get_base_path(self) -> str: + raise NotImplementedError() + @staticmethod def _has_debug(debug: Optional[bool] = None) -> bool: # It might have been explicitly switched off (debug=False) @@ -1940,8 +2077,9 @@ def route( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): @@ -1960,6 +2098,7 @@ def register_route(func: Callable): response_description, tags, operation_id, + include_in_schema, ) # Collate Middleware for routes @@ -2000,6 +2139,19 @@ def __init__( enable_validation, ) + def _get_base_path(self) -> str: + # 3 different scenarios: + # + # 1. SAM local: even though a stage variable is sent to the Lambda function, it's not used in the path + # 2. API Gateway REST API: stage variable is used in the path + # 3. API Gateway REST Custom Domain: stage variable is not used in the path + # + # To solve the 3 scenarios, we try to match the beginning of the path with the stage variable + stage = self.current_event.request_context.stage + if stage and stage != "$default" and self.current_event.request_context.path.startswith(f"/{stage}"): + return f"/{stage}" + return "" + # override route to ignore trailing "/" in routes for REST API def route( self, @@ -2012,8 +2164,9 @@ def route( description: Optional[str] = None, responses: Optional[Dict[int, Dict[str, Any]]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - tags: Optional[List["Tag"]] = None, + tags: Optional[List[str]] = None, operation_id: Optional[str] = None, + include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. @@ -2029,6 +2182,7 @@ def route( response_description, tags, operation_id, + include_in_schema, middlewares, ) @@ -2059,6 +2213,19 @@ def __init__( enable_validation, ) + def _get_base_path(self) -> str: + # 3 different scenarios: + # + # 1. SAM local: even though a stage variable is sent to the Lambda function, it's not used in the path + # 2. API Gateway HTTP API: stage variable is used in the path + # 3. API Gateway HTTP Custom Domain: stage variable is not used in the path + # + # To solve the 3 scenarios, we try to match the beginning of the path with the stage variable + stage = self.current_event.request_context.stage + if stage and stage != "$default" and self.current_event.request_context.http.path.startswith(f"/{stage}"): + return f"/{stage}" + return "" + class ALBResolver(ApiGatewayResolver): current_event: ALBEvent @@ -2073,3 +2240,7 @@ def __init__( ): """Amazon Application Load Balancer (ALB) resolver""" super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation) + + def _get_base_path(self) -> str: + # ALB doesn't have a stage variable, so we just return an empty string + return "" diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index bacdc8549c7..b69c8fc8087 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -62,3 +62,9 @@ def __init__( strip_prefixes, enable_validation, ) + + def _get_base_path(self) -> str: + stage = self.current_event.request_context.stage + if stage and stage != "$default" and self.current_event.request_context.http.method.startswith(f"/{stage}"): + return f"/{stage}" + return "" diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index ea7b303bfa5..c162eeb4ce1 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -94,13 +94,12 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> else: # Re-write the route_args with the validated values, and call the next middleware app.context["_route_args"] = values - response = next_middleware(app) - # Process the response body if it exists - raw_response = jsonable_encoder(response.body) + # Call the handler by calling the next middleware + response = next_middleware(app) - # Validate and serialize the response - return self._serialize_response(field=route.dependant.return_param, response_content=raw_response) + # Process the response + return self._handle_response(route=route, response=response) except RequestValidationError as e: return Response( status_code=422, @@ -108,6 +107,18 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> body=json.dumps({"detail": e.errors()}), ) + def _handle_response(self, *, route: Route, response: Response): + # Process the response body if it exists + if response.body: + # Validate and serialize the response, if it's JSON + if response.is_json(): + response.body = json.dumps( + self._serialize_response(field=route.dependant.return_param, response_content=response.body), + sort_keys=True, + ) + + return response + def _serialize_response( self, *, diff --git a/aws_lambda_powertools/event_handler/openapi/constants.py b/aws_lambda_powertools/event_handler/openapi/constants.py new file mode 100644 index 00000000000..f5d72d47f7e --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/constants.py @@ -0,0 +1,2 @@ +DEFAULT_API_VERSION = "1.0.0" +DEFAULT_OPENAPI_VERSION = "3.1.0" diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index 80818315f18..bbbc160f1e6 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -363,9 +363,24 @@ class Config: extra = "allow" +# https://swagger.io/specification/#tag-object +class Tag(BaseModel): + name: str + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + # https://swagger.io/specification/#operation-object class Operation(BaseModel): - tags: Optional[List[str]] = None + tags: Optional[List[Tag]] = None summary: Optional[str] = None description: Optional[str] = None externalDocs: Optional[ExternalDocumentation] = None @@ -540,21 +555,6 @@ class Config: extra = "allow" -# https://swagger.io/specification/#tag-object -class Tag(BaseModel): - name: str - description: Optional[str] = None - externalDocs: Optional[ExternalDocumentation] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" - - # https://swagger.io/specification/#openapi-object class OpenAPI(BaseModel): openapi: str diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py new file mode 100644 index 00000000000..fdb38599f30 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py @@ -0,0 +1,52 @@ +def generate_swagger_html(spec: str, js_url: str, css_url: str) -> str: + """ + Generate Swagger UI HTML page + + Parameters + ---------- + spec: str + The OpenAPI spec in the JSON format + js_url: str + The URL to the Swagger UI JavaScript file + css_url: str + The URL to the Swagger UI CSS file + """ + return f""" + + +
+ +