diff --git a/Makefile b/Makefile index 4bb1e78390c..65df3162f0e 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,7 @@ complexity-baseline: $(info Maintenability index) poetry run radon mi aws_lambda_powertools $(info Cyclomatic complexity index) - poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools --exclude aws_lambda_powertools/shared/json_encoder.py,aws_lambda_powertools/utilities/validation/base.py + poetry run xenon --max-absolute C --max-modules A --max-average A aws_lambda_powertools --exclude aws_lambda_powertools/shared/json_encoder.py,aws_lambda_powertools/utilities/validation/base.py,aws_lambda_powertools/event_handler/api_gateway.py # # Use `poetry version /` for version bump diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index d4cff69423d..c8e4248fda4 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -18,7 +18,12 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError -from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION +from aws_lambda_powertools.event_handler.openapi.config import OpenAPIConfig +from aws_lambda_powertools.event_handler.openapi.constants import ( + DEFAULT_API_VERSION, + DEFAULT_OPENAPI_TITLE, + DEFAULT_OPENAPI_VERSION, +) from aws_lambda_powertools.event_handler.openapi.exceptions import ( RequestValidationError, ResponseValidationError, @@ -1537,6 +1542,7 @@ def __init__( self.context: dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] self._response_builder_class = ResponseBuilder[BaseProxyEvent] + self.openapi_config = OpenAPIConfig() # starting an empty dataclass self._has_response_validation_error = response_validation_error_http_code is not None self._response_validation_error_http_code = self._validate_response_validation_error_http_code( response_validation_error_http_code, @@ -1580,16 +1586,12 @@ def _validate_response_validation_error_http_code( msg = f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code." raise ValueError(msg) from None - return ( - response_validation_error_http_code - if response_validation_error_http_code - else HTTPStatus.UNPROCESSABLE_ENTITY - ) + return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY def get_openapi_schema( self, *, - title: str = "Powertools API", + title: str = DEFAULT_OPENAPI_TITLE, version: str = DEFAULT_API_VERSION, openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: str | None = None, @@ -1641,6 +1643,29 @@ def get_openapi_schema( The OpenAPI schema as a pydantic model. """ + # DEPRECATION: Will be removed in v4.0.0. Use configure_api() instead. + # Maintained for backwards compatibility. + # See: https://github.com/aws-powertools/powertools-lambda-python/issues/6122 + if title == DEFAULT_OPENAPI_TITLE and self.openapi_config.title: + title = self.openapi_config.title + + if version == DEFAULT_API_VERSION and self.openapi_config.version: + version = self.openapi_config.version + + if openapi_version == DEFAULT_OPENAPI_VERSION and self.openapi_config.openapi_version: + openapi_version = self.openapi_config.openapi_version + + summary = summary or self.openapi_config.summary + description = description or self.openapi_config.description + tags = tags or self.openapi_config.tags + servers = servers or self.openapi_config.servers + terms_of_service = terms_of_service or self.openapi_config.terms_of_service + contact = contact or self.openapi_config.contact + license_info = license_info or self.openapi_config.license_info + security_schemes = security_schemes or self.openapi_config.security_schemes + security = security or self.openapi_config.security + openapi_extensions = openapi_extensions or self.openapi_config.openapi_extensions + from aws_lambda_powertools.event_handler.openapi.compat import ( GenerateJsonSchema, get_compat_model_name_map, @@ -1739,7 +1764,7 @@ def _get_openapi_servers(servers: list[Server] | None) -> list[Server]: # If the 'servers' property is not provided or is an empty array, # the default behavior is to return a Server Object with a URL value of "/". - return servers if servers else [Server(url="/")] + return servers or [Server(url="/")] @staticmethod def _get_openapi_security( @@ -1771,7 +1796,7 @@ def _determine_openapi_version(openapi_version: str): def get_openapi_json_schema( self, *, - title: str = "Powertools API", + title: str = DEFAULT_OPENAPI_TITLE, version: str = DEFAULT_API_VERSION, openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: str | None = None, @@ -1822,6 +1847,7 @@ def get_openapi_json_schema( str The OpenAPI schema as a JSON serializable dict. """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json return model_json( @@ -1845,11 +1871,94 @@ def get_openapi_json_schema( indent=2, ) + def configure_openapi( + self, + title: str = DEFAULT_OPENAPI_TITLE, + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, + summary: str | None = None, + description: str | None = None, + tags: list[Tag | str] | None = None, + servers: list[Server] | None = None, + terms_of_service: str | None = None, + contact: Contact | None = None, + license_info: License | None = None, + security_schemes: dict[str, SecurityScheme] | None = None, + security: list[dict[str, list[str]]] | None = None, + openapi_extensions: dict[str, Any] | None = None, + ): + """Configure OpenAPI specification settings for the API. + + Sets up the OpenAPI documentation configuration that can be later used + when enabling Swagger UI or generating OpenAPI specifications. + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.0.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: list[Tag, str], optional + A list of tags used by the specification with additional metadata. + servers: list[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: License, optional + The license information for the exposed API. + security_schemes: dict[str, SecurityScheme]], optional + A declaration of the security schemes available to be used in the specification. + security: list[dict[str, list[str]]], optional + A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. + + Example + -------- + >>> api.configure_openapi( + ... title="My API", + ... version="1.0.0", + ... description="API for managing resources", + ... contact=Contact( + ... name="API Support", + ... email="support@example.com" + ... ) + ... ) + + See Also + -------- + enable_swagger : Method to enable Swagger UI using these configurations + OpenAPIConfig : Data class containing all OpenAPI configuration options + """ + self.openapi_config = OpenAPIConfig( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + security_schemes=security_schemes, + security=security, + openapi_extensions=openapi_extensions, + ) + def enable_swagger( self, *, path: str = "/swagger", - title: str = "Powertools for AWS Lambda (Python) API", + title: str = DEFAULT_OPENAPI_TITLE, version: str = DEFAULT_API_VERSION, openapi_version: str = DEFAULT_OPENAPI_VERSION, summary: str | None = None, @@ -1912,6 +2021,7 @@ def enable_swagger( openapi_extensions: dict[str, Any], optional Additional OpenAPI extensions as a dictionary. """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json from aws_lambda_powertools.event_handler.openapi.models import Server from aws_lambda_powertools.event_handler.openapi.swagger_ui import ( @@ -2156,10 +2266,7 @@ def _get_base_path(self) -> str: @staticmethod def _has_debug(debug: bool | None = None) -> bool: # It might have been explicitly switched off (debug=False) - if debug is not None: - return debug - - return powertools_dev_is_set() + return debug if debug is not None else powertools_dev_is_set() @staticmethod def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): @@ -2272,7 +2379,7 @@ def _path_starts_with(path: str, prefix: str): if not isinstance(prefix, str) or prefix == "": return False - return path.startswith(prefix + "/") + return path.startswith(f"{prefix}/") def _handle_not_found(self, method: str, path: str) -> ResponseBuilder: """Called when no matching route was found and includes support for the cors preflight response""" @@ -2543,8 +2650,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> list[ModelField]: if route.dependant.response_extra_models: responses_from_routes.extend(route.dependant.response_extra_models) - flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes) - return flat_models + return list( + responses_from_routes + request_fields_from_routes + body_fields_from_routes, + ) class Router(BaseRouter): diff --git a/aws_lambda_powertools/event_handler/openapi/config.py b/aws_lambda_powertools/event_handler/openapi/config.py new file mode 100644 index 00000000000..597362d1ef9 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/config.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from aws_lambda_powertools.event_handler.openapi.constants import ( + DEFAULT_API_VERSION, + DEFAULT_OPENAPI_TITLE, + DEFAULT_OPENAPI_VERSION, +) + +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.openapi.models import ( + Contact, + License, + SecurityScheme, + Server, + Tag, + ) + + +@dataclass +class OpenAPIConfig: + """Configuration class for OpenAPI specification. + + This class holds all the necessary configuration parameters to generate an OpenAPI specification. + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.0.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: list[Tag, str], optional + A list of tags used by the specification with additional metadata. + servers: list[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: License, optional + The license information for the exposed API. + security_schemes: dict[str, SecurityScheme]], optional + A declaration of the security schemes available to be used in the specification. + security: list[dict[str, list[str]]], optional + A declaration of which security mechanisms are applied globally across the API. + openapi_extensions: Dict[str, Any], optional + Additional OpenAPI extensions as a dictionary. + + Example + -------- + >>> config = OpenAPIConfig( + ... title="My API", + ... version="1.0.0", + ... description="This is my API description", + ... contact=Contact(name="API Support", email="support@example.com"), + ... servers=[Server(url="https://api.example.com/v1")] + ... ) + """ + + title: str = DEFAULT_OPENAPI_TITLE + version: str = DEFAULT_API_VERSION + openapi_version: str = DEFAULT_OPENAPI_VERSION + summary: str | None = None + description: str | None = None + tags: list[Tag | str] | None = None + servers: list[Server] | None = None + terms_of_service: str | None = None + contact: Contact | None = None + license_info: License | None = None + security_schemes: dict[str, SecurityScheme] | None = None + security: list[dict[str, list[str]]] | None = None + openapi_extensions: dict[str, Any] | None = None diff --git a/aws_lambda_powertools/event_handler/openapi/constants.py b/aws_lambda_powertools/event_handler/openapi/constants.py index f5d72d47f7e..debe1d56736 100644 --- a/aws_lambda_powertools/event_handler/openapi/constants.py +++ b/aws_lambda_powertools/event_handler/openapi/constants.py @@ -1,2 +1,3 @@ DEFAULT_API_VERSION = "1.0.0" DEFAULT_OPENAPI_VERSION = "3.1.0" +DEFAULT_OPENAPI_TITLE = "Powertools for AWS Lambda (Python) API" diff --git a/docs/core/event_handler/_openapi_customization_metadata.md b/docs/core/event_handler/_openapi_customization_metadata.md index 5a96db582cb..a69f53cd84d 100644 --- a/docs/core/event_handler/_openapi_customization_metadata.md +++ b/docs/core/event_handler/_openapi_customization_metadata.md @@ -1,6 +1,6 @@ -Defining and customizing OpenAPI metadata gives detailed, top-level information about your API. Here's the method to set and tailor this metadata: +Defining and customizing OpenAPI metadata gives detailed, top-level information about your API. Use the method `app.configure_openapi` to set and tailor this metadata: | Field Name | Type | Description | | ------------------ | -------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 70eef0a2b86..4919598b3ec 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -1072,7 +1072,7 @@ Include extra parameters when exporting your OpenAPI specification to apply thes === "customizing_api_metadata.py" - ```python hl_lines="25-31" + ```python hl_lines="8-16" --8<-- "examples/event_handler_rest/src/customizing_api_metadata.py" ``` @@ -1108,7 +1108,7 @@ Security schemes are declared at the top-level first. You can reference them glo === "Global OpenAPI security schemes" - ```python title="security_schemes_global.py" hl_lines="32-42" + ```python title="security_schemes_global.py" hl_lines="17-27" --8<-- "examples/event_handler_rest/src/security_schemes_global.py" ``` @@ -1116,7 +1116,7 @@ Security schemes are declared at the top-level first. You can reference them glo === "Per Operation security" - ```python title="security_schemes_per_operation.py" hl_lines="17 32-41" + ```python title="security_schemes_per_operation.py" hl_lines="17-26 30" --8<-- "examples/event_handler_rest/src/security_schemes_per_operation.py" ``` @@ -1124,7 +1124,7 @@ Security schemes are declared at the top-level first. You can reference them glo === "Global security schemes and optional security per route" - ```python title="security_schemes_global_and_optional.py" hl_lines="22 37-46" + ```python title="security_schemes_global_and_optional.py" hl_lines="17-26 35" --8<-- "examples/event_handler_rest/src/security_schemes_global_and_optional.py" ``` diff --git a/examples/event_handler_rest/src/customizing_api_metadata.py b/examples/event_handler_rest/src/customizing_api_metadata.py index cd9ced455d2..9297045ea1a 100644 --- a/examples/event_handler_rest/src/customizing_api_metadata.py +++ b/examples/event_handler_rest/src/customizing_api_metadata.py @@ -5,6 +5,15 @@ from aws_lambda_powertools.utilities.typing import LambdaContext app = APIGatewayRestResolver(enable_validation=True) +app.configure_openapi( + title="TODO's API", + version="1.21.3", + summary="API to manage TODOs", + description="This API implements all the CRUD operations for the TODO app", + tags=["todos"], + servers=[Server(url="https://stg.example.org/orders", description="Staging server")], + contact=Contact(name="John Smith", email="john@smith.com"), +) @app.get("/todos/") @@ -20,14 +29,4 @@ def lambda_handler(event: dict, context: LambdaContext) -> dict: if __name__ == "__main__": - print( - app.get_openapi_json_schema( - title="TODO's API", - version="1.21.3", - summary="API to manage TODOs", - description="This API implements all the CRUD operations for the TODO app", - tags=["todos"], - servers=[Server(url="https://stg.example.org/orders", description="Staging server")], - contact=Contact(name="John Smith", email="john@smith.com"), - ), - ) + print(app.get_openapi_json_schema()) diff --git a/examples/event_handler_rest/src/security_schemes_global.py b/examples/event_handler_rest/src/security_schemes_global.py index 3a3ef5ce6f4..762bc077596 100644 --- a/examples/event_handler_rest/src/security_schemes_global.py +++ b/examples/event_handler_rest/src/security_schemes_global.py @@ -12,6 +12,20 @@ logger = Logger() app = APIGatewayRestResolver(enable_validation=True) +app.configure_openapi( + title="My API", + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, + security=[{"oauth": ["admin"]}], # (1)!) +) @app.get("/") @@ -26,19 +40,4 @@ def lambda_handler(event, context): if __name__ == "__main__": - print( - app.get_openapi_json_schema( - title="My API", - security_schemes={ - "oauth": OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", - tokenUrl="https://xxx.amazoncognito.com/oauth2/token", - ), - ), - ), - }, - security=[{"oauth": ["admin"]}], # (1)! - ), - ) + print(app.get_openapi_json_schema()) diff --git a/examples/event_handler_rest/src/security_schemes_global_and_optional.py b/examples/event_handler_rest/src/security_schemes_global_and_optional.py index 2a890efd5e4..84e5b0fdfcd 100644 --- a/examples/event_handler_rest/src/security_schemes_global_and_optional.py +++ b/examples/event_handler_rest/src/security_schemes_global_and_optional.py @@ -12,6 +12,19 @@ logger = Logger() app = APIGatewayRestResolver(enable_validation=True) +app.configure_openapi( + title="My API", + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, +) @app.get("/protected", security=[{"oauth": ["admin"]}]) @@ -31,18 +44,4 @@ def lambda_handler(event, context): if __name__ == "__main__": - print( - app.get_openapi_json_schema( - title="My API", - security_schemes={ - "oauth": OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", - tokenUrl="https://xxx.amazoncognito.com/oauth2/token", - ), - ), - ), - }, - ), - ) + print(app.get_openapi_json_schema()) diff --git a/examples/event_handler_rest/src/security_schemes_per_operation.py b/examples/event_handler_rest/src/security_schemes_per_operation.py index 66770a787c7..04b5a4ba830 100644 --- a/examples/event_handler_rest/src/security_schemes_per_operation.py +++ b/examples/event_handler_rest/src/security_schemes_per_operation.py @@ -12,6 +12,19 @@ logger = Logger() app = APIGatewayRestResolver(enable_validation=True) +app.configure_openapi( + title="My API", + security_schemes={ + "oauth": OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", + tokenUrl="https://xxx.amazoncognito.com/oauth2/token", + ), + ), + ), + }, +) @app.get("/", security=[{"oauth": ["admin"]}]) # (1)! @@ -26,18 +39,4 @@ def lambda_handler(event, context): if __name__ == "__main__": - print( - app.get_openapi_json_schema( - title="My API", - security_schemes={ - "oauth": OAuth2( - flows=OAuthFlows( - authorizationCode=OAuthFlowAuthorizationCode( - authorizationUrl="https://xxx.amazoncognito.com/oauth2/authorize", - tokenUrl="https://xxx.amazoncognito.com/oauth2/token", - ), - ), - ), - }, - ), - ) + print(app.get_openapi_json_schema()) diff --git a/tests/functional/event_handler/_pydantic/test_openapi_config.py b/tests/functional/event_handler/_pydantic/test_openapi_config.py new file mode 100644 index 00000000000..9fc2dd1ce7b --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_openapi_config.py @@ -0,0 +1,87 @@ +import json + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + + +def test_export_openapi_schema_with_custom_configuration(): + # GIVEN an API Gateway resolver with OpenAPI validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # GIVEN custom OpenAPI configuration + openapi_title = "My API" + openapi_myapi_version = "1.1.1-alpha" + app.configure_openapi(title=openapi_title, version=openapi_myapi_version) + + # WHEN we have a simple handler + @app.get("/") + def handler(): + pass + + # WHEN we get the schema + schema = app.get_openapi_schema() + + # THEN the schema should contain our custom configuration + assert schema.info.title == openapi_title + assert schema.info.version == openapi_myapi_version + + +def test_prioritize_direct_parameters_over_stored_configuration(): + + # GIVEN + stored_config = { + "title": "Stored API Title", + "version": "1.0.0", + } + + direct_params = { + "title": "Direct API Title", + "version": "2.0.0", + } + + # GIVEN an API Gateway resolver with OpenAPI validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + app.configure_openapi(**stored_config) + + # WHEN we have a simple handler + @app.get("/") + def handler(): + pass + + # WHEN we get the schema with direct params + schema = app.get_openapi_schema(**direct_params) + + # THEN direct parameters must override stored configuration + assert schema.info.title == direct_params["title"] + assert schema.info.version == direct_params["version"] + + +def test_export_openapi_schema_with_custom_configuration_and_json_export(): + # GIVEN an API Gateway resolver with OpenAPI validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # GIVEN custom OpenAPI configuration + openapi_title = "My API" + openapi_myapi_version = "1.1.1-alpha" + openapi_version = "3.1.2" + openapi_description = "My descrition" + app.configure_openapi( + title=openapi_title, + version=openapi_myapi_version, + openapi_version=openapi_version, + description=openapi_description, + ) + + # WHEN we have a simple handler + @app.get("/") + def handler(): + pass + + # WHEN we get the schema + schema = json.loads(app.get_openapi_json_schema()) + + # THEN the schema should contain our custom configuration + assert schema["info"]["title"] == openapi_title + assert schema["info"]["version"] == openapi_myapi_version + assert schema["openapi"] == openapi_version + assert schema["info"]["description"] == openapi_description diff --git a/tests/functional/event_handler/_pydantic/test_openapi_params.py b/tests/functional/event_handler/_pydantic/test_openapi_params.py index 2cf77a7de08..5bcda896858 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_params.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_params.py @@ -32,7 +32,7 @@ def handler(): raise NotImplementedError() schema = app.get_openapi_schema() - assert schema.info.title == "Powertools API" + assert schema.info.title == "Powertools for AWS Lambda (Python) API" assert schema.info.version == "1.0.0" assert len(schema.paths.keys()) == 1