diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 29601247b48..87433b020d5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -34,7 +34,6 @@ from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError -from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, @@ -83,10 +82,14 @@ Contact, License, OpenAPI, + SecurityScheme, Server, Tag, ) from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( + OAuth2Config, + ) from aws_lambda_powertools.event_handler.openapi.types import ( TypeModelOrEnum, ) @@ -282,6 +285,7 @@ def __init__( tags: Optional[List[str]], operation_id: Optional[str], include_in_schema: bool, + security: Optional[List[Dict[str, List[str]]]], middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -317,6 +321,8 @@ def __init__( The OpenAPI operationId for this route include_in_schema: bool Whether or not to include this route in the OpenAPI schema + security: List[Dict[str, List[str]]], optional + The OpenAPI security for this route middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -339,6 +345,7 @@ def __init__( self.response_description = response_description self.tags = tags or [] self.include_in_schema = include_in_schema + self.security = security self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -486,6 +493,10 @@ def _get_openapi_path( ) parameters.extend(operation_params) + # Add security if present + if self.security: + operation["security"] = self.security + # Add the parameters to the OpenAPI operation if parameters: all_parameters = {(param["in"], param["name"]): param for param in parameters} @@ -885,6 +896,7 @@ def route( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -943,6 +955,7 @@ def get( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -980,6 +993,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -996,6 +1010,7 @@ def post( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -1034,6 +1049,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1050,6 +1066,7 @@ def put( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -1088,6 +1105,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1104,6 +1122,7 @@ def delete( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1141,6 +1160,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1157,6 +1177,7 @@ def patch( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1197,6 +1218,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1419,6 +1441,8 @@ def get_openapi_schema( terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, license_info: Optional["License"] = None, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, ) -> "OpenAPI": """ Returns the OpenAPI schema as a pydantic model. @@ -1445,6 +1469,10 @@ def get_openapi_schema( The contact information for the exposed API. license_info: License, optional The license information for the exposed API. + security_schemes: Dict[str, "SecurityScheme"]], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. Returns ------- @@ -1457,25 +1485,12 @@ def get_openapi_schema( get_compat_model_name_map, get_definitions, ) - from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server, Tag - from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Tag from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_TEMPLATE, ) - # Pydantic V2 has no support for OpenAPI schema 3.0 - if PYDANTIC_V2 and not openapi_version.startswith("3.1"): - warnings.warn( - "You are using Pydantic v2, which is incompatible with OpenAPI schema 3.0. Forcing OpenAPI 3.1", - stacklevel=2, - ) - openapi_version = "3.1.0" - elif not PYDANTIC_V2 and not openapi_version.startswith("3.0"): - warnings.warn( - "You are using Pydantic v1, which is incompatible with OpenAPI schema 3.1. Forcing OpenAPI 3.0", - stacklevel=2, - ) - openapi_version = "3.0.3" + openapi_version = self._determine_openapi_version(openapi_version) # Start with the bare minimum required for a valid OpenAPI schema info: Dict[str, Any] = {"title": title, "version": version} @@ -1490,13 +1505,12 @@ def get_openapi_schema( info.update({field: value for field, value in optional_fields.items() if value}) - output: Dict[str, Any] = {"openapi": openapi_version, "info": info} - if servers: - output["servers"] = servers - else: - # If the servers property is not provided, or is an empty array, the default value would be a Server Object - # with an url value of /. - output["servers"] = [Server(url="/")] + output: Dict[str, Any] = { + "openapi": openapi_version, + "info": info, + "servers": self._get_openapi_servers(servers), + "security": self._get_openapi_security(security, security_schemes), + } components: Dict[str, Dict[str, Any]] = {} paths: Dict[str, Dict[str, Any]] = {} @@ -1534,6 +1548,8 @@ def get_openapi_schema( if definitions: components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + if security_schemes: + components["securitySchemes"] = security_schemes if components: output["components"] = components if tags: @@ -1543,6 +1559,50 @@ def get_openapi_schema( return OpenAPI(**output) + @staticmethod + def _get_openapi_servers(servers: Optional[List["Server"]]) -> List["Server"]: + from aws_lambda_powertools.event_handler.openapi.models import Server + + # If the 'servers' property is not provided or is an empty array, + # the default behavior is to return a Server Object with a URL value of "/". + return servers if servers else [Server(url="/")] + + @staticmethod + def _get_openapi_security( + security: Optional[List[Dict[str, List[str]]]], + security_schemes: Optional[Dict[str, "SecurityScheme"]], + ) -> Optional[List[Dict[str, List[str]]]]: + if not security: + return None + + if not security_schemes: + raise ValueError("security_schemes must be provided if security is provided") + + # Check if all keys in security are present in the security_schemes + if any(key not in security_schemes for sec in security for key in sec): + raise ValueError("Some security schemes not found in security_schemes") + + return security + + @staticmethod + def _determine_openapi_version(openapi_version): + from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 + + # Pydantic V2 has no support for OpenAPI schema 3.0 + if PYDANTIC_V2 and not openapi_version.startswith("3.1"): + warnings.warn( + "You are using Pydantic v2, which is incompatible with OpenAPI schema 3.0. Forcing OpenAPI 3.1", + stacklevel=2, + ) + openapi_version = "3.1.0" + elif not PYDANTIC_V2 and not openapi_version.startswith("3.0"): + warnings.warn( + "You are using Pydantic v1, which is incompatible with OpenAPI schema 3.1. Forcing OpenAPI 3.0", + stacklevel=2, + ) + openapi_version = "3.0.3" + return openapi_version + def get_openapi_json_schema( self, *, @@ -1556,6 +1616,8 @@ def get_openapi_json_schema( terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, license_info: Optional["License"] = None, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, ) -> str: """ Returns the OpenAPI schema as a JSON serializable dict @@ -1582,6 +1644,10 @@ def get_openapi_json_schema( The contact information for the exposed API. license_info: License, optional The license information for the exposed API. + security_schemes: Dict[str, "SecurityScheme"]], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. Returns ------- @@ -1602,6 +1668,8 @@ def get_openapi_json_schema( terms_of_service=terms_of_service, contact=contact, license_info=license_info, + security_schemes=security_schemes, + security=security, ), by_alias=True, exclude_none=True, @@ -1625,6 +1693,9 @@ def enable_swagger( swagger_base_url: Optional[str] = None, middlewares: Optional[List[Callable[..., Response]]] = None, compress: bool = False, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, + oauth2_config: Optional["OAuth2Config"] = None, ): """ Returns the OpenAPI schema as a JSON serializable dict @@ -1659,12 +1730,34 @@ def enable_swagger( List of middlewares to be used for the swagger route. compress: bool, default = False Whether or not to enable gzip compression swagger route. + security_schemes: Dict[str, "SecurityScheme"], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. + oauth2_config: OAuth2Config, optional + The OAuth2 configuration for the Swagger UI. """ from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server + from aws_lambda_powertools.event_handler.openapi.swagger_ui import ( + generate_oauth2_redirect_html, + generate_swagger_html, + ) @self.get(path, middlewares=middlewares, include_in_schema=False, compress=compress) def swagger_handler(): + query_params = self.current_event.query_string_parameters or {} + + # Check for query parameters; if "format" is specified as "oauth2-redirect", + # send the oauth2-redirect HTML stanza so OAuth2 can be used + # Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html + if query_params.get("format") == "oauth2-redirect": + return Response( + status_code=200, + content_type="text/html", + body=generate_oauth2_redirect_html(), + ) + base_path = self._get_base_path() if swagger_base_url: @@ -1690,6 +1783,8 @@ def swagger_handler(): terms_of_service=terms_of_service, contact=contact, license_info=license_info, + security_schemes=security_schemes, + security=security, ) # The .replace('', '<\\/') part is necessary to prevent a potential issue where the JSON string contains @@ -1705,7 +1800,6 @@ def swagger_handler(): # Check for query parameters; if "format" is specified as "json", # respond with the JSON used in the OpenAPI spec # Example: https://www.example.com/swagger?format=json - query_params = self.current_event.query_string_parameters or {} if query_params.get("format") == "json": return Response( status_code=200, @@ -1719,6 +1813,7 @@ def swagger_handler(): swagger_js, swagger_css, swagger_base_url, + oauth2_config, ) return Response( @@ -1741,6 +1836,7 @@ def route( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -1767,6 +1863,7 @@ def register_resolver(func: Callable): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -2218,6 +2315,7 @@ def route( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): @@ -2239,6 +2337,7 @@ def register_route(func: Callable): frozen_tags, operation_id, include_in_schema, + security, ) # Collate Middleware for routes @@ -2318,6 +2417,7 @@ def route( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. @@ -2334,6 +2434,7 @@ def route( tags, operation_id, include_in_schema, + security, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 0ce0f3ff725..4d1a6096f32 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -102,6 +102,8 @@ def get( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + security = None + return super(BedrockAgentResolver, self).get( rule, cors, @@ -114,6 +116,7 @@ def get( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -134,6 +137,8 @@ def post( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().post( rule, cors, @@ -146,6 +151,7 @@ def post( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -166,6 +172,8 @@ def put( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().put( rule, cors, @@ -178,6 +186,7 @@ def put( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -198,6 +207,8 @@ def patch( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): + security = None + return super().patch( rule, cors, @@ -210,6 +221,7 @@ def patch( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -230,6 +242,8 @@ def delete( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().delete( rule, cors, @@ -242,6 +256,7 @@ def delete( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index ace398ec532..04345ddaad7 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -441,12 +441,13 @@ class SecurityBase(BaseModel): description: Optional[str] = None if PYDANTIC_V2: - model_config = {"extra": "allow"} + model_config = {"extra": "allow", "populate_by_name": True} else: class Config: extra = "allow" + allow_population_by_field_name = True class APIKeyIn(Enum): diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py index e69de29bb2d..bc6eda8abb3 100644 --- a/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py @@ -0,0 +1,13 @@ +from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import ( + generate_swagger_html, +) +from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( + OAuth2Config, + generate_oauth2_redirect_html, +) + +__all__ = [ + "generate_swagger_html", + "generate_oauth2_redirect_html", + "OAuth2Config", +] diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py index fcd644f39f5..8b748d9338a 100644 --- a/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py @@ -1,4 +1,16 @@ -def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: str, swagger_base_url: str) -> str: +from typing import Optional + +from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import OAuth2Config + + +def generate_swagger_html( + spec: str, + path: str, + swagger_js: str, + swagger_css: str, + swagger_base_url: str, + oauth2_config: Optional[OAuth2Config], +) -> str: """ Generate Swagger UI HTML page @@ -8,10 +20,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st The OpenAPI spec path: str The path to the Swagger documentation - js_url: str - The URL to the Swagger UI JavaScript file - css_url: str - The URL to the Swagger UI CSS file + swagger_js: str + Swagger UI JavaScript source code or URL + swagger_css: str + Swagger UI CSS source code or URL + swagger_base_url: str + The base URL for Swagger UI + oauth2_config: OAuth2Config, optional + The OAuth2 configuration. """ # If Swagger base URL is present, generate HTML content with linked CSS and JavaScript files @@ -23,6 +39,11 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st swagger_css_content = f"" swagger_js_content = f"" + # Prepare oauth2 config + oauth2_content = ( + f"ui.initOAuth({oauth2_config.json(exclude_none=True, exclude_unset=True)});" if oauth2_config else "" + ) + return f""" @@ -45,6 +66,9 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st {swagger_js_content} """.strip() diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py new file mode 100644 index 00000000000..29250ae0056 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py @@ -0,0 +1,158 @@ +# ruff: noqa: E501 +import warnings +from typing import Dict, Optional, Sequence + +from pydantic import BaseModel, Field, validator + +from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 +from aws_lambda_powertools.shared.functions import powertools_dev_is_set + + +# Based on https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/ +class OAuth2Config(BaseModel): + """ + OAuth2 configuration for Swagger UI + """ + + # The client ID for the OAuth2 application + clientId: Optional[str] = Field(alias="client_id", default=None) + + # The client secret for the OAuth2 application. This is sensitive information and requires the explicit presence + # of the POWERTOOLS_DEV environment variable. + clientSecret: Optional[str] = Field(alias="client_secret", default=None) + + # The realm in which the OAuth2 application is registered. Optional. + realm: Optional[str] = Field(default=None) + + # The name of the OAuth2 application + appName: str = Field(alias="app_name") + + # The scopes that the OAuth2 application requires. Defaults to an empty list. + scopes: Sequence[str] = Field(default=[]) + + # Additional query string parameters to be included in the OAuth2 request. Defaults to an empty dictionary. + additionalQueryStringParams: Dict[str, str] = Field(alias="additional_query_string_params", default={}) + + # Whether to use basic authentication with the access code grant type. Defaults to False. + useBasicAuthenticationWithAccessCodeGrant: bool = Field( + alias="use_basic_authentication_with_access_code_grant", + default=False, + ) + + # Whether to use PKCE with the authorization code grant type. Defaults to False. + usePkceWithAuthorizationCodeGrant: bool = Field(alias="use_pkce_with_authorization_code_grant", default=False) + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + else: + + class Config: + extra = "allow" + allow_population_by_field_name = True + + @validator("clientSecret", always=True) + def client_secret_only_on_dev(cls, v: Optional[str]) -> Optional[str]: + if not v: + return None + + if not powertools_dev_is_set(): + raise ValueError( + "cannot use client_secret without POWERTOOLS_DEV mode. See " + "https://docs.powertools.aws.dev/lambda/python/latest/#optimizing-for-non-production-environments", + ) + else: + warnings.warn( + "OAuth2Config is using client_secret and POWERTOOLS_DEV is set. This reveals sensitive information. " + "DO NOT USE THIS OUTSIDE LOCAL DEVELOPMENT", + stacklevel=2, + ) + return v + + +def generate_oauth2_redirect_html() -> str: + """ + Generates the HTML content for the OAuth2 redirect page. + + Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html + """ + return """ + + +
+