Skip to content

Commit e97f6ee

Browse files
Initial commit OpenAPI Extensions
1 parent abcb350 commit e97f6ee

File tree

6 files changed

+312
-8
lines changed

6 files changed

+312
-8
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def __init__(
323323
operation_id: Optional[str] = None,
324324
include_in_schema: bool = True,
325325
security: Optional[List[Dict[str, List[str]]]] = None,
326+
openapi_extensions: Optional[Dict[str, Any]] = None,
326327
middlewares: Optional[List[Callable[..., Response]]] = None,
327328
):
328329
"""
@@ -383,6 +384,7 @@ def __init__(
383384
self.tags = tags or []
384385
self.include_in_schema = include_in_schema
385386
self.security = security
387+
self.openapi_extensions = openapi_extensions
386388
self.middlewares = middlewares or []
387389
self.operation_id = operation_id or self._generate_operation_id()
388390

@@ -534,6 +536,10 @@ def _get_openapi_path(
534536
if self.security:
535537
operation["security"] = self.security
536538

539+
# Add OpenAPI extensions if present
540+
if self.openapi_extensions:
541+
operation.update(self.openapi_extensions)
542+
537543
# Add the parameters to the OpenAPI operation
538544
if parameters:
539545
all_parameters = {(param["in"], param["name"]): param for param in parameters}
@@ -939,6 +945,7 @@ def route(
939945
operation_id: Optional[str] = None,
940946
include_in_schema: bool = True,
941947
security: Optional[List[Dict[str, List[str]]]] = None,
948+
openapi_extensions: Optional[Dict[str, Any]] = None,
942949
middlewares: Optional[List[Callable[..., Any]]] = None,
943950
):
944951
raise NotImplementedError()
@@ -998,6 +1005,7 @@ def get(
9981005
operation_id: Optional[str] = None,
9991006
include_in_schema: bool = True,
10001007
security: Optional[List[Dict[str, List[str]]]] = None,
1008+
openapi_extensions: Optional[Dict[str, Any]] = None,
10011009
middlewares: Optional[List[Callable[..., Any]]] = None,
10021010
):
10031011
"""Get route decorator with GET `method`
@@ -1036,6 +1044,7 @@ def lambda_handler(event, context):
10361044
operation_id,
10371045
include_in_schema,
10381046
security,
1047+
openapi_extensions,
10391048
middlewares,
10401049
)
10411050

@@ -1053,6 +1062,7 @@ def post(
10531062
operation_id: Optional[str] = None,
10541063
include_in_schema: bool = True,
10551064
security: Optional[List[Dict[str, List[str]]]] = None,
1065+
openapi_extensions: Optional[Dict[str, Any]] = None,
10561066
middlewares: Optional[List[Callable[..., Any]]] = None,
10571067
):
10581068
"""Post route decorator with POST `method`
@@ -1092,6 +1102,7 @@ def lambda_handler(event, context):
10921102
operation_id,
10931103
include_in_schema,
10941104
security,
1105+
openapi_extensions,
10951106
middlewares,
10961107
)
10971108

@@ -1109,6 +1120,7 @@ def put(
11091120
operation_id: Optional[str] = None,
11101121
include_in_schema: bool = True,
11111122
security: Optional[List[Dict[str, List[str]]]] = None,
1123+
openapi_extensions: Optional[Dict[str, Any]] = None,
11121124
middlewares: Optional[List[Callable[..., Any]]] = None,
11131125
):
11141126
"""Put route decorator with PUT `method`
@@ -1148,6 +1160,7 @@ def lambda_handler(event, context):
11481160
operation_id,
11491161
include_in_schema,
11501162
security,
1163+
openapi_extensions,
11511164
middlewares,
11521165
)
11531166

@@ -1165,6 +1178,7 @@ def delete(
11651178
operation_id: Optional[str] = None,
11661179
include_in_schema: bool = True,
11671180
security: Optional[List[Dict[str, List[str]]]] = None,
1181+
openapi_extensions: Optional[Dict[str, Any]] = None,
11681182
middlewares: Optional[List[Callable[..., Any]]] = None,
11691183
):
11701184
"""Delete route decorator with DELETE `method`
@@ -1203,6 +1217,7 @@ def lambda_handler(event, context):
12031217
operation_id,
12041218
include_in_schema,
12051219
security,
1220+
openapi_extensions,
12061221
middlewares,
12071222
)
12081223

@@ -1220,6 +1235,7 @@ def patch(
12201235
operation_id: Optional[str] = None,
12211236
include_in_schema: bool = True,
12221237
security: Optional[List[Dict[str, List[str]]]] = None,
1238+
openapi_extensions: Optional[Dict[str, Any]] = None,
12231239
middlewares: Optional[List[Callable]] = None,
12241240
):
12251241
"""Patch route decorator with PATCH `method`
@@ -1261,6 +1277,7 @@ def lambda_handler(event, context):
12611277
operation_id,
12621278
include_in_schema,
12631279
security,
1280+
openapi_extensions,
12641281
middlewares,
12651282
)
12661283

@@ -1278,6 +1295,7 @@ def head(
12781295
operation_id: Optional[str] = None,
12791296
include_in_schema: bool = True,
12801297
security: Optional[List[Dict[str, List[str]]]] = None,
1298+
openapi_extensions: Optional[Dict[str, Any]] = None,
12811299
middlewares: Optional[List[Callable]] = None,
12821300
):
12831301
"""Head route decorator with HEAD `method`
@@ -1318,6 +1336,7 @@ def lambda_handler(event, context):
13181336
operation_id,
13191337
include_in_schema,
13201338
security,
1339+
openapi_extensions,
13211340
middlewares,
13221341
)
13231342

@@ -1541,6 +1560,7 @@ def get_openapi_schema(
15411560
license_info: Optional["License"] = None,
15421561
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
15431562
security: Optional[List[Dict[str, List[str]]]] = None,
1563+
openapi_extensions: Optional[Dict[str, Any]] = None,
15441564
) -> "OpenAPI":
15451565
"""
15461566
Returns the OpenAPI schema as a pydantic model.
@@ -1603,11 +1623,15 @@ def get_openapi_schema(
16031623

16041624
info.update({field: value for field, value in optional_fields.items() if value})
16051625

1626+
if not openapi_extensions:
1627+
openapi_extensions = {}
1628+
16061629
output: Dict[str, Any] = {
16071630
"openapi": openapi_version,
16081631
"info": info,
16091632
"servers": self._get_openapi_servers(servers),
16101633
"security": self._get_openapi_security(security, security_schemes),
1634+
**openapi_extensions,
16111635
}
16121636

16131637
components: Dict[str, Dict[str, Any]] = {}
@@ -1726,6 +1750,7 @@ def get_openapi_json_schema(
17261750
license_info: Optional["License"] = None,
17271751
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
17281752
security: Optional[List[Dict[str, List[str]]]] = None,
1753+
openapi_extensions: Optional[Dict[str, Any]] = None,
17291754
) -> str:
17301755
"""
17311756
Returns the OpenAPI schema as a JSON serializable dict
@@ -1778,6 +1803,7 @@ def get_openapi_json_schema(
17781803
license_info=license_info,
17791804
security_schemes=security_schemes,
17801805
security=security,
1806+
openapi_extensions=openapi_extensions,
17811807
),
17821808
by_alias=True,
17831809
exclude_none=True,
@@ -1805,6 +1831,7 @@ def enable_swagger(
18051831
security: Optional[List[Dict[str, List[str]]]] = None,
18061832
oauth2_config: Optional["OAuth2Config"] = None,
18071833
persist_authorization: bool = False,
1834+
openapi_extensions: Optional[Dict[str, Any]] = None,
18081835
):
18091836
"""
18101837
Returns the OpenAPI schema as a JSON serializable dict
@@ -1896,6 +1923,7 @@ def swagger_handler():
18961923
license_info=license_info,
18971924
security_schemes=security_schemes,
18981925
security=security,
1926+
openapi_extensions=openapi_extensions,
18991927
)
19001928

19011929
# The .replace('</', '<\\/') part is necessary to prevent a potential issue where the JSON string contains
@@ -1949,6 +1977,7 @@ def route(
19491977
operation_id: Optional[str] = None,
19501978
include_in_schema: bool = True,
19511979
security: Optional[List[Dict[str, List[str]]]] = None,
1980+
openapi_extensions: Optional[Dict[str, Any]] = None,
19521981
middlewares: Optional[List[Callable[..., Any]]] = None,
19531982
):
19541983
"""Route decorator includes parameter `method`"""
@@ -1976,6 +2005,7 @@ def register_resolver(func: Callable):
19762005
operation_id,
19772006
include_in_schema,
19782007
security,
2008+
openapi_extensions,
19792009
middlewares,
19802010
)
19812011

@@ -2489,6 +2519,7 @@ def route(
24892519
operation_id: Optional[str] = None,
24902520
include_in_schema: bool = True,
24912521
security: Optional[List[Dict[str, List[str]]]] = None,
2522+
openapi_extensions: Optional[Dict[str, Any]] = None,
24922523
middlewares: Optional[List[Callable[..., Any]]] = None,
24932524
):
24942525
def register_route(func: Callable):
@@ -2497,6 +2528,7 @@ def register_route(func: Callable):
24972528
frozen_responses = _FrozenDict(responses) if responses else None
24982529
frozen_tags = frozenset(tags) if tags else None
24992530
frozen_security = _FrozenListDict(security) if security else None
2531+
fronzen_openapi_extensions = _FrozenDict(openapi_extensions) if openapi_extensions else None
25002532

25012533
route_key = (
25022534
rule,
@@ -2512,6 +2544,7 @@ def register_route(func: Callable):
25122544
operation_id,
25132545
include_in_schema,
25142546
frozen_security,
2547+
fronzen_openapi_extensions,
25152548
)
25162549

25172550
# Collate Middleware for routes
@@ -2592,6 +2625,7 @@ def route(
25922625
operation_id: Optional[str] = None,
25932626
include_in_schema: bool = True,
25942627
security: Optional[List[Dict[str, List[str]]]] = None,
2628+
openapi_extensions: Optional[Dict[str, Any]] = None,
25952629
middlewares: Optional[List[Callable[..., Any]]] = None,
25962630
):
25972631
# NOTE: see #1552 for more context.
@@ -2609,6 +2643,7 @@ def route(
26092643
operation_id,
26102644
include_in_schema,
26112645
security,
2646+
openapi_extensions,
26122647
middlewares,
26132648
)
26142649

aws_lambda_powertools/event_handler/bedrock_agent.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def get( # type: ignore[override]
102102
include_in_schema: bool = True,
103103
middlewares: Optional[List[Callable[..., Any]]] = None,
104104
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
105+
106+
openapi_extensions = None
105107
security = None
106108

107109
return super(BedrockAgentResolver, self).get(
@@ -117,6 +119,7 @@ def get( # type: ignore[override]
117119
operation_id,
118120
include_in_schema,
119121
security,
122+
openapi_extensions,
120123
middlewares,
121124
)
122125

@@ -137,6 +140,7 @@ def post( # type: ignore[override]
137140
include_in_schema: bool = True,
138141
middlewares: Optional[List[Callable[..., Any]]] = None,
139142
):
143+
openapi_extensions = None
140144
security = None
141145

142146
return super().post(
@@ -152,6 +156,7 @@ def post( # type: ignore[override]
152156
operation_id,
153157
include_in_schema,
154158
security,
159+
openapi_extensions,
155160
middlewares,
156161
)
157162

@@ -172,6 +177,7 @@ def put( # type: ignore[override]
172177
include_in_schema: bool = True,
173178
middlewares: Optional[List[Callable[..., Any]]] = None,
174179
):
180+
openapi_extensions = None
175181
security = None
176182

177183
return super().put(
@@ -187,6 +193,7 @@ def put( # type: ignore[override]
187193
operation_id,
188194
include_in_schema,
189195
security,
196+
openapi_extensions,
190197
middlewares,
191198
)
192199

@@ -207,6 +214,7 @@ def patch( # type: ignore[override]
207214
include_in_schema: bool = True,
208215
middlewares: Optional[List[Callable]] = None,
209216
):
217+
openapi_extensions = None
210218
security = None
211219

212220
return super().patch(
@@ -222,6 +230,7 @@ def patch( # type: ignore[override]
222230
operation_id,
223231
include_in_schema,
224232
security,
233+
openapi_extensions,
225234
middlewares,
226235
)
227236

@@ -242,6 +251,7 @@ def delete( # type: ignore[override]
242251
include_in_schema: bool = True,
243252
middlewares: Optional[List[Callable[..., Any]]] = None,
244253
):
254+
openapi_extensions = None
245255
security = None
246256

247257
return super().delete(
@@ -257,6 +267,7 @@ def delete( # type: ignore[override]
257267
operation_id,
258268
include_in_schema,
259269
security,
270+
openapi_extensions,
260271
middlewares,
261272
)
262273

aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
RequestErrorModel: Type[BaseModel] = create_model("Request")
4242

4343
if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3
44-
from pydantic import TypeAdapter, ValidationError
44+
from pydantic import TypeAdapter, ValidationError, model_serializer as parser_openapi_extension
4545
from pydantic._internal._typing_extra import eval_type_lenient
4646
from pydantic.fields import FieldInfo
4747
from pydantic._internal._utils import lenient_issubclass
@@ -217,7 +217,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any:
217217
return model.model_dump_json(**kwargs)
218218

219219
else:
220-
from pydantic import BaseModel, ValidationError
220+
from pydantic import BaseModel, ValidationError, root_validator as parser_openapi_extension
221221
from pydantic.fields import (
222222
ModelField,
223223
Required,

0 commit comments

Comments
 (0)