diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 29601247b48..87433b020d5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -34,7 +34,6 @@ from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError -from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, @@ -83,10 +82,14 @@ Contact, License, OpenAPI, + SecurityScheme, Server, Tag, ) from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( + OAuth2Config, + ) from aws_lambda_powertools.event_handler.openapi.types import ( TypeModelOrEnum, ) @@ -282,6 +285,7 @@ def __init__( tags: Optional[List[str]], operation_id: Optional[str], include_in_schema: bool, + security: Optional[List[Dict[str, List[str]]]], middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -317,6 +321,8 @@ def __init__( The OpenAPI operationId for this route include_in_schema: bool Whether or not to include this route in the OpenAPI schema + security: List[Dict[str, List[str]]], optional + The OpenAPI security for this route middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ @@ -339,6 +345,7 @@ def __init__( self.response_description = response_description self.tags = tags or [] self.include_in_schema = include_in_schema + self.security = security self.middlewares = middlewares or [] self.operation_id = operation_id or self._generate_operation_id() @@ -486,6 +493,10 @@ def _get_openapi_path( ) parameters.extend(operation_params) + # Add security if present + if self.security: + operation["security"] = self.security + # Add the parameters to the OpenAPI operation if parameters: all_parameters = {(param["in"], param["name"]): param for param in parameters} @@ -885,6 +896,7 @@ def route( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -943,6 +955,7 @@ def get( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -980,6 +993,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -996,6 +1010,7 @@ def post( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -1034,6 +1049,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1050,6 +1066,7 @@ def put( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -1088,6 +1105,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1104,6 +1122,7 @@ def delete( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -1141,6 +1160,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1157,6 +1177,7 @@ def patch( tags: Optional[List[str]] = None, operation_id: Optional[str] = None, include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -1197,6 +1218,7 @@ def lambda_handler(event, context): tags, operation_id, include_in_schema, + security, middlewares, ) @@ -1419,6 +1441,8 @@ def get_openapi_schema( terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, license_info: Optional["License"] = None, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, ) -> "OpenAPI": """ Returns the OpenAPI schema as a pydantic model. @@ -1445,6 +1469,10 @@ def get_openapi_schema( The contact information for the exposed API. license_info: License, optional The license information for the exposed API. + security_schemes: Dict[str, "SecurityScheme"]], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. Returns ------- @@ -1457,25 +1485,12 @@ def get_openapi_schema( get_compat_model_name_map, get_definitions, ) - from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server, Tag - from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Tag from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_TEMPLATE, ) - # Pydantic V2 has no support for OpenAPI schema 3.0 - if PYDANTIC_V2 and not openapi_version.startswith("3.1"): - warnings.warn( - "You are using Pydantic v2, which is incompatible with OpenAPI schema 3.0. Forcing OpenAPI 3.1", - stacklevel=2, - ) - openapi_version = "3.1.0" - elif not PYDANTIC_V2 and not openapi_version.startswith("3.0"): - warnings.warn( - "You are using Pydantic v1, which is incompatible with OpenAPI schema 3.1. Forcing OpenAPI 3.0", - stacklevel=2, - ) - openapi_version = "3.0.3" + openapi_version = self._determine_openapi_version(openapi_version) # Start with the bare minimum required for a valid OpenAPI schema info: Dict[str, Any] = {"title": title, "version": version} @@ -1490,13 +1505,12 @@ def get_openapi_schema( info.update({field: value for field, value in optional_fields.items() if value}) - output: Dict[str, Any] = {"openapi": openapi_version, "info": info} - if servers: - output["servers"] = servers - else: - # If the servers property is not provided, or is an empty array, the default value would be a Server Object - # with an url value of /. - output["servers"] = [Server(url="/")] + output: Dict[str, Any] = { + "openapi": openapi_version, + "info": info, + "servers": self._get_openapi_servers(servers), + "security": self._get_openapi_security(security, security_schemes), + } components: Dict[str, Dict[str, Any]] = {} paths: Dict[str, Dict[str, Any]] = {} @@ -1534,6 +1548,8 @@ def get_openapi_schema( if definitions: components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + if security_schemes: + components["securitySchemes"] = security_schemes if components: output["components"] = components if tags: @@ -1543,6 +1559,50 @@ def get_openapi_schema( return OpenAPI(**output) + @staticmethod + def _get_openapi_servers(servers: Optional[List["Server"]]) -> List["Server"]: + from aws_lambda_powertools.event_handler.openapi.models import Server + + # If the 'servers' property is not provided or is an empty array, + # the default behavior is to return a Server Object with a URL value of "/". + return servers if servers else [Server(url="/")] + + @staticmethod + def _get_openapi_security( + security: Optional[List[Dict[str, List[str]]]], + security_schemes: Optional[Dict[str, "SecurityScheme"]], + ) -> Optional[List[Dict[str, List[str]]]]: + if not security: + return None + + if not security_schemes: + raise ValueError("security_schemes must be provided if security is provided") + + # Check if all keys in security are present in the security_schemes + if any(key not in security_schemes for sec in security for key in sec): + raise ValueError("Some security schemes not found in security_schemes") + + return security + + @staticmethod + def _determine_openapi_version(openapi_version): + from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 + + # Pydantic V2 has no support for OpenAPI schema 3.0 + if PYDANTIC_V2 and not openapi_version.startswith("3.1"): + warnings.warn( + "You are using Pydantic v2, which is incompatible with OpenAPI schema 3.0. Forcing OpenAPI 3.1", + stacklevel=2, + ) + openapi_version = "3.1.0" + elif not PYDANTIC_V2 and not openapi_version.startswith("3.0"): + warnings.warn( + "You are using Pydantic v1, which is incompatible with OpenAPI schema 3.1. Forcing OpenAPI 3.0", + stacklevel=2, + ) + openapi_version = "3.0.3" + return openapi_version + def get_openapi_json_schema( self, *, @@ -1556,6 +1616,8 @@ def get_openapi_json_schema( terms_of_service: Optional[str] = None, contact: Optional["Contact"] = None, license_info: Optional["License"] = None, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, ) -> str: """ Returns the OpenAPI schema as a JSON serializable dict @@ -1582,6 +1644,10 @@ def get_openapi_json_schema( The contact information for the exposed API. license_info: License, optional The license information for the exposed API. + security_schemes: Dict[str, "SecurityScheme"]], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. Returns ------- @@ -1602,6 +1668,8 @@ def get_openapi_json_schema( terms_of_service=terms_of_service, contact=contact, license_info=license_info, + security_schemes=security_schemes, + security=security, ), by_alias=True, exclude_none=True, @@ -1625,6 +1693,9 @@ def enable_swagger( swagger_base_url: Optional[str] = None, middlewares: Optional[List[Callable[..., Response]]] = None, compress: bool = False, + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, + security: Optional[List[Dict[str, List[str]]]] = None, + oauth2_config: Optional["OAuth2Config"] = None, ): """ Returns the OpenAPI schema as a JSON serializable dict @@ -1659,12 +1730,34 @@ def enable_swagger( List of middlewares to be used for the swagger route. compress: bool, default = False Whether or not to enable gzip compression swagger route. + security_schemes: Dict[str, "SecurityScheme"], optional + A declaration of the security schemes available to be used in the specification. + security: List[Dict[str, List[str]]], optional + A declaration of which security mechanisms are applied globally across the API. + oauth2_config: OAuth2Config, optional + The OAuth2 configuration for the Swagger UI. """ from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server + from aws_lambda_powertools.event_handler.openapi.swagger_ui import ( + generate_oauth2_redirect_html, + generate_swagger_html, + ) @self.get(path, middlewares=middlewares, include_in_schema=False, compress=compress) def swagger_handler(): + query_params = self.current_event.query_string_parameters or {} + + # Check for query parameters; if "format" is specified as "oauth2-redirect", + # send the oauth2-redirect HTML stanza so OAuth2 can be used + # Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html + if query_params.get("format") == "oauth2-redirect": + return Response( + status_code=200, + content_type="text/html", + body=generate_oauth2_redirect_html(), + ) + base_path = self._get_base_path() if swagger_base_url: @@ -1690,6 +1783,8 @@ def swagger_handler(): terms_of_service=terms_of_service, contact=contact, license_info=license_info, + security_schemes=security_schemes, + security=security, ) # The .replace(' Callable[[Callable[..., Any]], Callable[..., Any]]: + security = None + return super(BedrockAgentResolver, self).get( rule, cors, @@ -114,6 +116,7 @@ def get( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -134,6 +137,8 @@ def post( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().post( rule, cors, @@ -146,6 +151,7 @@ def post( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -166,6 +172,8 @@ def put( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().put( rule, cors, @@ -178,6 +186,7 @@ def put( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -198,6 +207,8 @@ def patch( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable]] = None, ): + security = None + return super().patch( rule, cors, @@ -210,6 +221,7 @@ def patch( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) @@ -230,6 +242,8 @@ def delete( # type: ignore[override] include_in_schema: bool = True, middlewares: Optional[List[Callable[..., Any]]] = None, ): + security = None + return super().delete( rule, cors, @@ -242,6 +256,7 @@ def delete( # type: ignore[override] tags, operation_id, include_in_schema, + security, middlewares, ) diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py index ace398ec532..04345ddaad7 100644 --- a/aws_lambda_powertools/event_handler/openapi/models.py +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -441,12 +441,13 @@ class SecurityBase(BaseModel): description: Optional[str] = None if PYDANTIC_V2: - model_config = {"extra": "allow"} + model_config = {"extra": "allow", "populate_by_name": True} else: class Config: extra = "allow" + allow_population_by_field_name = True class APIKeyIn(Enum): diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py index e69de29bb2d..bc6eda8abb3 100644 --- a/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/__init__.py @@ -0,0 +1,13 @@ +from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import ( + generate_swagger_html, +) +from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import ( + OAuth2Config, + generate_oauth2_redirect_html, +) + +__all__ = [ + "generate_swagger_html", + "generate_oauth2_redirect_html", + "OAuth2Config", +] diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py index fcd644f39f5..8b748d9338a 100644 --- a/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py @@ -1,4 +1,16 @@ -def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: str, swagger_base_url: str) -> str: +from typing import Optional + +from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import OAuth2Config + + +def generate_swagger_html( + spec: str, + path: str, + swagger_js: str, + swagger_css: str, + swagger_base_url: str, + oauth2_config: Optional[OAuth2Config], +) -> str: """ Generate Swagger UI HTML page @@ -8,10 +20,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st The OpenAPI spec path: str The path to the Swagger documentation - js_url: str - The URL to the Swagger UI JavaScript file - css_url: str - The URL to the Swagger UI CSS file + swagger_js: str + Swagger UI JavaScript source code or URL + swagger_css: str + Swagger UI CSS source code or URL + swagger_base_url: str + The base URL for Swagger UI + oauth2_config: OAuth2Config, optional + The OAuth2 configuration. """ # If Swagger base URL is present, generate HTML content with linked CSS and JavaScript files @@ -23,6 +39,11 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st swagger_css_content = f"" swagger_js_content = f"" + # Prepare oauth2 config + oauth2_content = ( + f"ui.initOAuth({oauth2_config.json(exclude_none=True, exclude_unset=True)});" if oauth2_config else "" + ) + return f""" @@ -45,6 +66,9 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st {swagger_js_content} """.strip() diff --git a/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py b/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py new file mode 100644 index 00000000000..29250ae0056 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py @@ -0,0 +1,158 @@ +# ruff: noqa: E501 +import warnings +from typing import Dict, Optional, Sequence + +from pydantic import BaseModel, Field, validator + +from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2 +from aws_lambda_powertools.shared.functions import powertools_dev_is_set + + +# Based on https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/ +class OAuth2Config(BaseModel): + """ + OAuth2 configuration for Swagger UI + """ + + # The client ID for the OAuth2 application + clientId: Optional[str] = Field(alias="client_id", default=None) + + # The client secret for the OAuth2 application. This is sensitive information and requires the explicit presence + # of the POWERTOOLS_DEV environment variable. + clientSecret: Optional[str] = Field(alias="client_secret", default=None) + + # The realm in which the OAuth2 application is registered. Optional. + realm: Optional[str] = Field(default=None) + + # The name of the OAuth2 application + appName: str = Field(alias="app_name") + + # The scopes that the OAuth2 application requires. Defaults to an empty list. + scopes: Sequence[str] = Field(default=[]) + + # Additional query string parameters to be included in the OAuth2 request. Defaults to an empty dictionary. + additionalQueryStringParams: Dict[str, str] = Field(alias="additional_query_string_params", default={}) + + # Whether to use basic authentication with the access code grant type. Defaults to False. + useBasicAuthenticationWithAccessCodeGrant: bool = Field( + alias="use_basic_authentication_with_access_code_grant", + default=False, + ) + + # Whether to use PKCE with the authorization code grant type. Defaults to False. + usePkceWithAuthorizationCodeGrant: bool = Field(alias="use_pkce_with_authorization_code_grant", default=False) + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + else: + + class Config: + extra = "allow" + allow_population_by_field_name = True + + @validator("clientSecret", always=True) + def client_secret_only_on_dev(cls, v: Optional[str]) -> Optional[str]: + if not v: + return None + + if not powertools_dev_is_set(): + raise ValueError( + "cannot use client_secret without POWERTOOLS_DEV mode. See " + "https://docs.powertools.aws.dev/lambda/python/latest/#optimizing-for-non-production-environments", + ) + else: + warnings.warn( + "OAuth2Config is using client_secret and POWERTOOLS_DEV is set. This reveals sensitive information. " + "DO NOT USE THIS OUTSIDE LOCAL DEVELOPMENT", + stacklevel=2, + ) + return v + + +def generate_oauth2_redirect_html() -> str: + """ + Generates the HTML content for the OAuth2 redirect page. + + Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html + """ + return """ + + + + Swagger UI: OAuth2 Redirect + + + + + + """.strip() diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 1ca54ac57b4..aaf9352ebc0 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -991,6 +991,18 @@ To implement these customizations, include extra parameters when defining your r --8<-- "examples/event_handler_rest/src/customizing_api_operations.py" ``` +#### Customizing OpenAPI metadata + +--8<-- "docs/core/event_handler/_openapi_customization_metadata.md" + +Include extra parameters when exporting your OpenAPI specification to apply these customizations: + +=== "customizing_api_metadata.py" + + ```python hl_lines="25-31" + --8<-- "examples/event_handler_rest/src/customizing_api_metadata.py" + ``` + #### Customizing Swagger UI ???+note "Customizing the Swagger metadata" @@ -1014,16 +1026,44 @@ Below is an example configuration for serving Swagger UI from a custom path or C --8<-- "examples/event_handler_rest/src/customizing_swagger_middlewares.py" ``` -#### Customizing OpenAPI metadata +#### Security schemes ---8<-- "docs/core/event_handler/_openapi_customization_metadata.md" +???-info "Does Powertools implement any of the security schemes?" + No. Powertools adds support for generating OpenAPI documentation with [security schemes](https://swagger.io/docs/specification/authentication/), but it doesn't implement any of the security schemes itself, so you must implement the security mechanisms separately. -Include extra parameters when exporting your OpenAPI specification to apply these customizations: +OpenAPI uses the term security scheme for [authentication and authorization schemes](https://swagger.io/docs/specification/authentication/){target="_blank"}. +When you're describing your API, declare security schemes at the top level, and reference them globally or per operation. -=== "customizing_api_metadata.py" +=== "Global OpenAPI security schemes" - ```python hl_lines="25-31" - --8<-- "examples/event_handler_rest/src/customizing_api_metadata.py" + ```python title="security_schemes_global.py" hl_lines="32-42" + --8<-- "examples/event_handler_rest/src/security_schemes_global.py" + ``` + + 1. Using the oauth security scheme defined earlier, scoped to the "admin" role. + +=== "Per Operation security" + + ```python title="security_schemes_per_operation.py" hl_lines="17 32-41" + --8<-- "examples/event_handler_rest/src/security_schemes_per_operation.py" + ``` + + 1. Using the oauth security scheme defined bellow, scoped to the "admin" role. + +OpenAPI 3 lets you describe APIs protected using the following security schemes: + +| Security Scheme | Type | Description | +|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [HTTP auth](https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml){target="_blank"} | `HTTPBase` | HTTP authentication schemes using the Authorization header (e.g: [Basic auth](https://swagger.io/docs/specification/authentication/basic-authentication/){target="_blank"}, [Bearer](https://swagger.io/docs/specification/authentication/bearer-authentication/){target="_blank"}) | +| [API keys](https://swagger.io/docs/specification/authentication/api-keys/https://swagger.io/docs/specification/authentication/api-keys/){target="_blank"} (e.g: query strings, cookies) | `APIKey` | API keys in headers, query strings or [cookies](https://swagger.io/docs/specification/authentication/cookie-authentication/){target="_blank"}. | +| [OAuth 2](https://swagger.io/docs/specification/authentication/oauth2/){target="_blank"} | `OAuth2` | Authorization protocol that gives an API client limited access to user data on a web server. | +| [OpenID Connect Discovery](https://swagger.io/docs/specification/authentication/openid-connect-discovery/){target="_blank"} | `OpenIdConnect` | Identity layer built [on top of the OAuth 2.0 protocol](https://openid.net/developers/how-connect-works/){target="_blank"} and supported by some OAuth 2.0. | + +???-note "Using OAuth2 with the Swagger UI?" + You can use the `OAuth2Config` option to configure a default OAuth2 app on the generated Swagger UI. + + ```python hl_lines="10 15-18 22" + --8<-- "examples/event_handler_rest/src/swagger_with_oauth2.py" ``` ### Custom serializer diff --git a/examples/event_handler_rest/sam/swagger_ui_oauth2_template.yaml b/examples/event_handler_rest/sam/swagger_ui_oauth2_template.yaml new file mode 100644 index 00000000000..629fa02f88b --- /dev/null +++ b/examples/event_handler_rest/sam/swagger_ui_oauth2_template.yaml @@ -0,0 +1,84 @@ +AWSTemplateFormatVersion: "2010-09-09" +Transform: AWS::Serverless-2016-10-31 +Description: Sample SAM Template for Oauth2 Cognito User Pool + Swagger UI + +Globals: + Function: + Timeout: 5 + Runtime: python3.12 + Tracing: Active + Environment: + Variables: + LOG_LEVEL: INFO + POWERTOOLS_LOGGER_SAMPLE_RATE: 0.1 + POWERTOOLS_LOGGER_LOG_EVENT: true + POWERTOOLS_SERVICE_NAME: example + +Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ../src + Handler: swagger_ui_oauth2.lambda_handler + Environment: + Variables: + COGNITO_USER_POOL_DOMAIN: !Ref UserPoolDomain + Events: + AnyApiEvent: + Type: Api + Properties: + Path: /{proxy+} # Send requests on any path to the lambda function + Method: ANY # Send requests using any http method to the lambda function + + CognitoUserPool: + Type: AWS::Cognito::UserPool + Properties: + UserPoolName: PowertoolsUserPool + Policies: + PasswordPolicy: + MinimumLength: 8 + RequireLowercase: true + RequireNumbers: true + RequireSymbols: true + RequireUppercase: true + + CognitoUserPoolClient: + Type: AWS::Cognito::UserPoolClient + Properties: + ClientName: PowertoolsClient + UserPoolId: !Ref CognitoUserPool + GenerateSecret: true + RefreshTokenValidity: 30 + ExplicitAuthFlows: + - ALLOW_USER_PASSWORD_AUTH + - ALLOW_REFRESH_TOKEN_AUTH + SupportedIdentityProviders: + - COGNITO + CallbackURLs: + # NOTE: for this to work, your OAuth2 redirect url needs to precisely follow this format: + # https://.execute-api..amazonaws.com//swagger?format=oauth2-redirect + - !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/${ServerlessRestApi.Stage}/swagger?format=oauth2-redirect" + AllowedOAuthFlows: + - code + AllowedOAuthScopes: + - openid + - email + - profile + - aws.cognito.signin.user.admin + AllowedOAuthFlowsUserPoolClient: true + + UserPoolDomain: + Type: AWS::Cognito::UserPoolDomain + Properties: + Domain: powertools-swagger-oauth2 + UserPoolId: !Ref CognitoUserPool + +Outputs: + HelloWorldApiUrl: + Value: !Sub "https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/${ServerlessRestApi.Stage}/swagger" + + CognitoOAuthClientId: + Value: !GetAtt CognitoUserPoolClient.ClientId + + CognitoDomain: + Value: !Ref UserPoolDomain diff --git a/examples/event_handler_rest/src/security_schemes_global.py b/examples/event_handler_rest/src/security_schemes_global.py new file mode 100644 index 00000000000..3a3ef5ce6f4 --- /dev/null +++ b/examples/event_handler_rest/src/security_schemes_global.py @@ -0,0 +1,44 @@ +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, +) +from aws_lambda_powertools.event_handler.openapi.models import ( + OAuth2, + OAuthFlowAuthorizationCode, + OAuthFlows, +) + +tracer = Tracer() +logger = Logger() + +app = APIGatewayRestResolver(enable_validation=True) + + +@app.get("/") +def helloworld() -> dict: + return {"hello": "world"} + + +@logger.inject_lambda_context +@tracer.capture_lambda_handler +def lambda_handler(event, context): + return app.resolve(event, context) + + +if __name__ == "__main__": + print( + app.get_openapi_json_schema( + title="My API", + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, + security=[{"oauth": ["admin"]}], # (1)! + ), + ) diff --git a/examples/event_handler_rest/src/security_schemes_per_operation.py b/examples/event_handler_rest/src/security_schemes_per_operation.py new file mode 100644 index 00000000000..66770a787c7 --- /dev/null +++ b/examples/event_handler_rest/src/security_schemes_per_operation.py @@ -0,0 +1,43 @@ +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, +) +from aws_lambda_powertools.event_handler.openapi.models import ( + OAuth2, + OAuthFlowAuthorizationCode, + OAuthFlows, +) + +tracer = Tracer() +logger = Logger() + +app = APIGatewayRestResolver(enable_validation=True) + + +@app.get("/", security=[{"oauth": ["admin"]}]) # (1)! +def helloworld() -> dict: + return {"hello": "world"} + + +@logger.inject_lambda_context +@tracer.capture_lambda_handler +def lambda_handler(event, context): + return app.resolve(event, context) + + +if __name__ == "__main__": + print( + app.get_openapi_json_schema( + title="My API", + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, + ), + ) diff --git a/examples/event_handler_rest/src/swagger_ui_oauth2.py b/examples/event_handler_rest/src/swagger_ui_oauth2.py new file mode 100644 index 00000000000..1dc7f173735 --- /dev/null +++ b/examples/event_handler_rest/src/swagger_ui_oauth2.py @@ -0,0 +1,53 @@ +import os + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, + Response, +) +from aws_lambda_powertools.event_handler.openapi.models import ( + OAuth2, + OAuthFlowAuthorizationCode, + OAuthFlows, +) +from aws_lambda_powertools.event_handler.openapi.swagger_ui import OAuth2Config + +tracer = Tracer() +logger = Logger() + +region = os.getenv("AWS_REGION") +cognito_domain = os.getenv("COGNITO_USER_POOL_DOMAIN") + +app = APIGatewayRestResolver(enable_validation=True) +app.enable_swagger( + # NOTE: for this to work, your OAuth2 redirect url needs to precisely follow this format: + # https://.execute-api..amazonaws.com//swagger?format=oauth2-redirect + oauth2_config=OAuth2Config(app_name="OAuth2 Test"), + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl=f"https://{cognito_domain}.auth.{region}.amazoncognito.com/oauth2/authorize", + tokenUrl=f"https://{cognito_domain}.auth.{region}.amazoncognito.com/oauth2/token", + ), + ), + ), + }, + security=[{"oauth": []}], +) + + +@app.get("/") +def helloworld() -> Response[dict]: + logger.info("Hello, World!") + return Response( + status_code=200, + body={"message": "Hello, World"}, + content_type="application/json", + ) + + +@logger.inject_lambda_context(log_event=True) +@tracer.capture_lambda_handler +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/swagger_with_oauth2.py b/examples/event_handler_rest/src/swagger_with_oauth2.py new file mode 100644 index 00000000000..4a2a86cdd40 --- /dev/null +++ b/examples/event_handler_rest/src/swagger_with_oauth2.py @@ -0,0 +1,45 @@ +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, +) +from aws_lambda_powertools.event_handler.openapi.models import ( + OAuth2, + OAuthFlowAuthorizationCode, + OAuthFlows, +) +from aws_lambda_powertools.event_handler.openapi.swagger_ui import OAuth2Config + +tracer = Tracer() +logger = Logger() + +oauth2 = OAuth2Config( + client_id="xxxxxxxxxxxxxxxxxxxxxxxxxxxx", + app_name="OAuth2 app", +) + +app = APIGatewayRestResolver(enable_validation=True) +app.enable_swagger( + oauth2_config=oauth2, + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, + security=[{"oauth": []}], +) + + +@app.get("/") +def hello() -> str: + return "world" + + +@logger.inject_lambda_context +@tracer.capture_lambda_handler +def lambda_handler(event, context): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/test_openapi_security.py b/tests/functional/event_handler/test_openapi_security.py new file mode 100644 index 00000000000..7120a815edd --- /dev/null +++ b/tests/functional/event_handler/test_openapi_security.py @@ -0,0 +1,62 @@ +import pytest + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn + + +def test_openapi_top_level_security(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header), + }, + security=[{"apiKey": []}], + ) + + security = schema.security + assert security is not None + + assert len(security) == 1 + assert security[0] == {"apiKey": []} + + +def test_openapi_top_level_security_missing(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + with pytest.raises(ValueError): + app.get_openapi_schema( + security=[{"apiKey": []}], + ) + + +def test_openapi_operation_security(): + app = APIGatewayRestResolver() + + @app.get("/", security=[{"apiKey": []}]) + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header), + }, + ) + + security = schema.security + assert security is None + + operation = schema.paths["/"].get + security = operation.security + assert security is not None + + assert len(security) == 1 + assert security[0] == {"apiKey": []} diff --git a/tests/functional/event_handler/test_openapi_security_schemes.py b/tests/functional/event_handler/test_openapi_security_schemes.py new file mode 100644 index 00000000000..dc785ba56d0 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_security_schemes.py @@ -0,0 +1,112 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import ( + APIKey, + APIKeyIn, + HTTPBearer, + OAuth2, + OAuthFlowImplicit, + OAuthFlows, + OpenIdConnect, +) + + +def test_openapi_security_scheme_api_key(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header), + }, + ) + + security_schemes = schema.components.securitySchemes + assert security_schemes is not None + + assert "apiKey" in security_schemes + api_key_scheme = security_schemes["apiKey"] + assert api_key_scheme.type_.value == "apiKey" + assert api_key_scheme.name == "X-API-KEY" + assert api_key_scheme.description == "API Key" + assert api_key_scheme.in_.value == "header" + + +def test_openapi_security_scheme_http(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "bearerAuth": HTTPBearer( + description="JWT Token", + bearerFormat="JWT", + ), + }, + ) + + security_schemes = schema.components.securitySchemes + assert security_schemes is not None + + assert "bearerAuth" in security_schemes + http_scheme = security_schemes["bearerAuth"] + assert http_scheme.type_.value == "http" + assert http_scheme.scheme == "bearer" + assert http_scheme.bearerFormat == "JWT" + + +def test_openapi_security_scheme_oauth2(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "oauth2": OAuth2( + flows=OAuthFlows( + implicit=OAuthFlowImplicit( + authorizationUrl="https://example.com/oauth2/authorize", + ), + ), + ), + }, + ) + + security_schemes = schema.components.securitySchemes + assert security_schemes is not None + + assert "oauth2" in security_schemes + oauth2_scheme = security_schemes["oauth2"] + assert oauth2_scheme.type_.value == "oauth2" + assert oauth2_scheme.flows.implicit.authorizationUrl == "https://example.com/oauth2/authorize" + + +def test_openapi_security_scheme_open_id_connect(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema( + security_schemes={ + "openIdConnect": OpenIdConnect( + openIdConnectUrl="https://example.com/oauth2/authorize", + ), + }, + ) + + security_schemes = schema.components.securitySchemes + assert security_schemes is not None + + assert "openIdConnect" in security_schemes + open_id_connect_scheme = security_schemes["openIdConnect"] + assert open_id_connect_scheme.type_.value == "openIdConnect" + assert open_id_connect_scheme.openIdConnectUrl == "https://example.com/oauth2/authorize" diff --git a/tests/functional/event_handler/test_openapi_swagger.py b/tests/functional/event_handler/test_openapi_swagger.py index 82c9b4874d0..11ec0cf24da 100644 --- a/tests/functional/event_handler/test_openapi_swagger.py +++ b/tests/functional/event_handler/test_openapi_swagger.py @@ -1,7 +1,11 @@ import json +import warnings from typing import Dict +import pytest + from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.swagger_ui import OAuth2Config from tests.functional.utils import load_event LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") @@ -112,3 +116,26 @@ def test_openapi_swagger_with_rest_api_stage(): result = app(event, {}) assert result["statusCode"] == 200 assert "ui.specActions.updateUrl('/prod/swagger?format=json')" in result["body"] + + +def test_openapi_swagger_oauth2_without_powertools_dev(): + with pytest.raises(ValueError) as exc: + OAuth2Config(app_name="OAuth2 app", client_id="client_id", client_secret="verysecret") + + assert "cannot use client_secret without POWERTOOLS_DEV mode" in str(exc.value) + + +def test_openapi_swagger_oauth2_with_powertools_dev(monkeypatch): + monkeypatch.setenv("POWERTOOLS_DEV", "1") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("default") + + OAuth2Config(app_name="OAuth2 app", client_id="client_id", client_secret="verysecret") + + assert str(w[-1].message) == ( + "OAuth2Config is using client_secret and POWERTOOLS_DEV is set. This reveals sensitive information. " + "DO NOT USE THIS OUTSIDE LOCAL DEVELOPMENT" + ) + + monkeypatch.delenv("POWERTOOLS_DEV")