Skip to content

Commit 50a477c

Browse files
committed
feat: generate OpenAPI spec from event handler
1 parent 04cf87f commit 50a477c

File tree

7 files changed

+1369
-22
lines changed

7 files changed

+1369
-22
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+139-22
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import re
55
import traceback
66
import warnings
7-
import zlib
87
from abc import ABC, abstractmethod
98
from enum import Enum
109
from functools import partial
@@ -17,14 +16,23 @@
1716
Match,
1817
Optional,
1918
Pattern,
19+
Sequence,
2020
Set,
2121
Tuple,
2222
Type,
2323
Union,
2424
)
2525

26+
import zlib
27+
from pydantic.fields import ModelField
28+
from pydantic.schema import get_flat_models_from_fields, get_model_name_map, model_process_schema
29+
2630
from aws_lambda_powertools.event_handler import content_types
2731
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
32+
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant
33+
from aws_lambda_powertools.event_handler.openapi.models import Contact, License, OpenAPI, Server, Tag
34+
from aws_lambda_powertools.event_handler.openapi.utils import get_flat_params
35+
from aws_lambda_powertools.event_handler.route import Route
2836
from aws_lambda_powertools.shared.cookies import Cookie
2937
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
3038
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -207,26 +215,6 @@ def __init__(
207215
self.headers.setdefault("Content-Type", content_type)
208216

209217

210-
class Route:
211-
"""Internally used Route Configuration"""
212-
213-
def __init__(
214-
self,
215-
method: str,
216-
rule: Pattern,
217-
func: Callable,
218-
cors: bool,
219-
compress: bool,
220-
cache_control: Optional[str],
221-
):
222-
self.method = method.upper()
223-
self.rule = rule
224-
self.func = func
225-
self.cors = cors
226-
self.compress = compress
227-
self.cache_control = cache_control
228-
229-
230218
class ResponseBuilder:
231219
"""Internally used Response builder"""
232220

@@ -554,6 +542,119 @@ def __init__(
554542
# Allow for a custom serializer or a concise json serialization
555543
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
556544

545+
def get_openapi_schema(
546+
self,
547+
*,
548+
title: str,
549+
version: str,
550+
openapi_version: str = "3.1.0",
551+
summary: Optional[str] = None,
552+
description: Optional[str] = None,
553+
tags: Optional[List[Tag]] = None,
554+
servers: Optional[List[Server]] = None,
555+
terms_of_service: Optional[str] = None,
556+
contact: Optional[Contact] = None,
557+
license_info: Optional[License] = None,
558+
) -> OpenAPI:
559+
info: Dict[str, Any] = {"title": title, "version": version}
560+
if summary:
561+
info["summary"] = summary
562+
if description:
563+
info["description"] = description
564+
if terms_of_service:
565+
info["termsOfService"] = terms_of_service
566+
if contact:
567+
info["contact"] = contact
568+
if license_info:
569+
info["license"] = license_info
570+
571+
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
572+
if servers:
573+
output["servers"] = servers
574+
else:
575+
# If the servers property is not provided, or is an empty array, the default value would be a Server Object
576+
# with a url value of /.
577+
output["servers"] = [Server(url="/")]
578+
579+
components: Dict[str, Dict[str, Any]] = {}
580+
paths: Dict[str, Dict[str, Any]] = {}
581+
operation_ids: Set[str] = set()
582+
583+
all_routes = self._dynamic_routes + self._static_routes
584+
all_fields = self._get_fields_from_routes(all_routes)
585+
models = get_flat_models_from_fields(all_fields, known_models=set())
586+
model_name_map = get_model_name_map(models)
587+
588+
definitions: Dict[str, Dict[str, Any]] = {}
589+
for model in models:
590+
m_schema, m_definitions, _ = model_process_schema(
591+
model,
592+
model_name_map=model_name_map,
593+
ref_prefix="#/components/schemas/",
594+
)
595+
definitions.update(m_definitions)
596+
model_name = model_name_map[model]
597+
if "description" in m_schema:
598+
m_schema["description"] = m_schema["description"].split("\f")[0]
599+
definitions[model_name] = m_schema
600+
601+
for route in all_routes:
602+
dependant = get_dependant(
603+
path=route.func.__name__,
604+
call=route.func,
605+
)
606+
607+
result = route._openapi_path(
608+
dependant=dependant,
609+
operation_ids=operation_ids,
610+
model_name_map=model_name_map,
611+
)
612+
if result:
613+
path, path_definitions = result
614+
if path:
615+
paths.setdefault(route.path, {}).update(path)
616+
if path_definitions:
617+
definitions.update(path_definitions)
618+
619+
if definitions:
620+
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
621+
if components:
622+
output["components"] = components
623+
if tags:
624+
output["tags"] = tags
625+
626+
output["paths"] = paths
627+
628+
return OpenAPI(**output) # .dict(by_alias=True, exclude_none=True)
629+
630+
def get_openapi_json_schema(
631+
self,
632+
*,
633+
title: str,
634+
version: str,
635+
openapi_version: str = "3.1.0",
636+
summary: Optional[str] = None,
637+
description: Optional[str] = None,
638+
tags: Optional[List[Tag]] = None,
639+
servers: Optional[List[Server]] = None,
640+
terms_of_service: Optional[str] = None,
641+
contact: Optional[Contact] = None,
642+
license_info: Optional[License] = None,
643+
) -> str:
644+
"""Returns the OpenAPI schema as a JSON serializable dict"""
645+
return self.get_openapi_schema(
646+
title=title,
647+
version=version,
648+
openapi_version=openapi_version,
649+
summary=summary,
650+
description=description,
651+
tags=tags,
652+
servers=servers,
653+
terms_of_service=terms_of_service,
654+
contact=contact,
655+
license_info=license_info,
656+
).json(by_alias=True, exclude_none=True, indent=2)
657+
557658
def route(
558659
self,
559660
rule: str,
@@ -573,7 +674,7 @@ def register_resolver(func: Callable):
573674
cors_enabled = cors
574675

575676
for item in methods:
576-
_route = Route(item, self._compile_regex(rule), func, cors_enabled, compress, cache_control)
677+
_route = Route(item, rule, self._compile_regex(rule), func, cors_enabled, compress, cache_control)
577678

578679
# The more specific route wins.
579680
# We store dynamic (/studies/{studyid}) and static routes (/studies/fetch) separately.
@@ -889,6 +990,22 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None
889990

890991
self.route(*new_route)(func)
891992

993+
@staticmethod
994+
def _get_fields_from_routes(routes: Sequence[Route]) -> List[ModelField]:
995+
responses_from_routes: List[ModelField] = []
996+
request_fields_from_routes: List[ModelField] = []
997+
998+
for route in routes:
999+
dependant = get_dependant(path=route.path, call=route.func)
1000+
params = get_flat_params(dependant)
1001+
request_fields_from_routes.extend(params)
1002+
1003+
if dependant.return_param:
1004+
responses_from_routes.append(dependant.return_param)
1005+
1006+
flat_models = list(responses_from_routes + request_fields_from_routes)
1007+
return flat_models
1008+
8921009

8931010
class Router(BaseRouter):
8941011
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from aws_lambda_powertools.event_handler.openapi.models import (
2+
Example,
3+
Info,
4+
MediaType,
5+
Operation,
6+
Reference,
7+
Response,
8+
Schema,
9+
)
10+
11+
__all__ = ["Info", "Operation", "Response", "MediaType", "Reference", "Schema", "Example"]
+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import inspect
2+
import re
3+
from typing import Any, Callable, Dict, ForwardRef, Optional, Set, cast
4+
5+
from pydantic.fields import ModelField
6+
from pydantic.typing import evaluate_forwardref
7+
8+
from aws_lambda_powertools.event_handler.openapi.params import Dependant, Param, ParamTypes, analyze_param
9+
10+
11+
def add_param_to_fields(
12+
*,
13+
field: ModelField,
14+
dependant: Dependant,
15+
) -> None:
16+
field_info = cast(Param, field.field_info)
17+
if field_info.in_ == ParamTypes.path:
18+
dependant.path_params.append(field)
19+
elif field_info.in_ == ParamTypes.query:
20+
dependant.query_params.append(field)
21+
elif field_info.in_ == ParamTypes.header:
22+
dependant.header_params.append(field)
23+
else:
24+
assert field_info.in_ == ParamTypes.cookie
25+
dependant.cookie_params.append(field)
26+
27+
28+
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
29+
if isinstance(annotation, str):
30+
annotation = ForwardRef(annotation)
31+
annotation = evaluate_forwardref(annotation, globalns, globalns)
32+
return annotation
33+
34+
35+
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
36+
signature = inspect.signature(call)
37+
globalns = getattr(call, "__global__", {})
38+
typed_params = [
39+
inspect.Parameter(
40+
name=param.name,
41+
kind=param.kind,
42+
default=param.default,
43+
annotation=get_typed_annotation(param.annotation, globalns),
44+
)
45+
for param in signature.parameters.values()
46+
]
47+
48+
if signature.return_annotation is not inspect.Signature.empty:
49+
return_param = inspect.Parameter(
50+
name="Return",
51+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
52+
default=None,
53+
annotation=get_typed_annotation(signature.return_annotation, globalns),
54+
)
55+
return inspect.Signature(typed_params, return_annotation=return_param.annotation)
56+
else:
57+
return inspect.Signature(typed_params)
58+
59+
60+
def get_path_param_names(path: str) -> Set[str]:
61+
return set(re.findall("{(.*?)}", path))
62+
63+
64+
def get_dependant(
65+
*,
66+
path: str,
67+
call: Callable[..., Any],
68+
name: Optional[str] = None,
69+
) -> Dependant:
70+
path_param_names = get_path_param_names(path)
71+
endpoint_signature = get_typed_signature(call)
72+
signature_params = endpoint_signature.parameters
73+
dependant = Dependant(
74+
call=call,
75+
name=name,
76+
path=path,
77+
)
78+
79+
for param_name, param in signature_params.items():
80+
is_path_param = param_name in path_param_names
81+
type_annotation, param_field = analyze_param(
82+
param_name=param_name,
83+
annotation=param.annotation,
84+
value=param.default,
85+
is_path_param=is_path_param,
86+
)
87+
assert param_field is not None
88+
89+
add_param_to_fields(field=param_field, dependant=dependant)
90+
91+
return_annotation = endpoint_signature.return_annotation
92+
if return_annotation is not inspect.Signature.empty:
93+
type_annotation, param_field = analyze_param(
94+
param_name="Return", annotation=return_annotation, value=None, is_path_param=False,
95+
)
96+
assert param_field is not None
97+
98+
dependant.return_param = param_field
99+
100+
return dependant

0 commit comments

Comments
 (0)