Skip to content

Commit b8b4632

Browse files
feat(event_handler): add OpenAPI extensions (#4703)
* Initial commit OpenAPI Extensions * Polishing the PR with best practicies - Comments * Polishing the PR with best practicies - Tests * Polishing the PR with best practicies - make pydanticv2 happy * Polishing the PR with best practicies - using model_validator to be more specific * Temporary mypy disabling * Make mypy happy? * Make mypy happy? * Polishing the PR with best practicies - adding e2e tests * Adding docstring * Adding documentation * Addressing Simon's feedback * Addressing Simon's feedback * Addressing Simon's feedback * Adding more tests * Adding more tests * Adding more tests
1 parent 9fc7669 commit b8b4632

File tree

17 files changed

+610
-11
lines changed

17 files changed

+610
-11
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+43
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ def __init__(
323323
operation_id: Optional[str] = None,
324324
include_in_schema: bool = True,
325325
security: Optional[List[Dict[str, List[str]]]] = None,
326+
openapi_extensions: Optional[Dict[str, Any]] = None,
326327
middlewares: Optional[List[Callable[..., Response]]] = None,
327328
):
328329
"""
@@ -360,6 +361,8 @@ def __init__(
360361
Whether or not to include this route in the OpenAPI schema
361362
security: List[Dict[str, List[str]]], optional
362363
The OpenAPI security for this route
364+
openapi_extensions: Dict[str, Any], optional
365+
Additional OpenAPI extensions as a dictionary.
363366
middlewares: Optional[List[Callable[..., Response]]]
364367
The list of route middlewares to be called in order.
365368
"""
@@ -383,6 +386,7 @@ def __init__(
383386
self.tags = tags or []
384387
self.include_in_schema = include_in_schema
385388
self.security = security
389+
self.openapi_extensions = openapi_extensions
386390
self.middlewares = middlewares or []
387391
self.operation_id = operation_id or self._generate_operation_id()
388392

@@ -534,6 +538,10 @@ def _get_openapi_path(
534538
if self.security:
535539
operation["security"] = self.security
536540

541+
# Add OpenAPI extensions if present
542+
if self.openapi_extensions:
543+
operation.update(self.openapi_extensions)
544+
537545
# Add the parameters to the OpenAPI operation
538546
if parameters:
539547
all_parameters = {(param["in"], param["name"]): param for param in parameters}
@@ -939,6 +947,7 @@ def route(
939947
operation_id: Optional[str] = None,
940948
include_in_schema: bool = True,
941949
security: Optional[List[Dict[str, List[str]]]] = None,
950+
openapi_extensions: Optional[Dict[str, Any]] = None,
942951
middlewares: Optional[List[Callable[..., Any]]] = None,
943952
):
944953
raise NotImplementedError()
@@ -998,6 +1007,7 @@ def get(
9981007
operation_id: Optional[str] = None,
9991008
include_in_schema: bool = True,
10001009
security: Optional[List[Dict[str, List[str]]]] = None,
1010+
openapi_extensions: Optional[Dict[str, Any]] = None,
10011011
middlewares: Optional[List[Callable[..., Any]]] = None,
10021012
):
10031013
"""Get route decorator with GET `method`
@@ -1036,6 +1046,7 @@ def lambda_handler(event, context):
10361046
operation_id,
10371047
include_in_schema,
10381048
security,
1049+
openapi_extensions,
10391050
middlewares,
10401051
)
10411052

@@ -1053,6 +1064,7 @@ def post(
10531064
operation_id: Optional[str] = None,
10541065
include_in_schema: bool = True,
10551066
security: Optional[List[Dict[str, List[str]]]] = None,
1067+
openapi_extensions: Optional[Dict[str, Any]] = None,
10561068
middlewares: Optional[List[Callable[..., Any]]] = None,
10571069
):
10581070
"""Post route decorator with POST `method`
@@ -1092,6 +1104,7 @@ def lambda_handler(event, context):
10921104
operation_id,
10931105
include_in_schema,
10941106
security,
1107+
openapi_extensions,
10951108
middlewares,
10961109
)
10971110

@@ -1109,6 +1122,7 @@ def put(
11091122
operation_id: Optional[str] = None,
11101123
include_in_schema: bool = True,
11111124
security: Optional[List[Dict[str, List[str]]]] = None,
1125+
openapi_extensions: Optional[Dict[str, Any]] = None,
11121126
middlewares: Optional[List[Callable[..., Any]]] = None,
11131127
):
11141128
"""Put route decorator with PUT `method`
@@ -1148,6 +1162,7 @@ def lambda_handler(event, context):
11481162
operation_id,
11491163
include_in_schema,
11501164
security,
1165+
openapi_extensions,
11511166
middlewares,
11521167
)
11531168

@@ -1165,6 +1180,7 @@ def delete(
11651180
operation_id: Optional[str] = None,
11661181
include_in_schema: bool = True,
11671182
security: Optional[List[Dict[str, List[str]]]] = None,
1183+
openapi_extensions: Optional[Dict[str, Any]] = None,
11681184
middlewares: Optional[List[Callable[..., Any]]] = None,
11691185
):
11701186
"""Delete route decorator with DELETE `method`
@@ -1203,6 +1219,7 @@ def lambda_handler(event, context):
12031219
operation_id,
12041220
include_in_schema,
12051221
security,
1222+
openapi_extensions,
12061223
middlewares,
12071224
)
12081225

@@ -1220,6 +1237,7 @@ def patch(
12201237
operation_id: Optional[str] = None,
12211238
include_in_schema: bool = True,
12221239
security: Optional[List[Dict[str, List[str]]]] = None,
1240+
openapi_extensions: Optional[Dict[str, Any]] = None,
12231241
middlewares: Optional[List[Callable]] = None,
12241242
):
12251243
"""Patch route decorator with PATCH `method`
@@ -1261,6 +1279,7 @@ def lambda_handler(event, context):
12611279
operation_id,
12621280
include_in_schema,
12631281
security,
1282+
openapi_extensions,
12641283
middlewares,
12651284
)
12661285

@@ -1278,6 +1297,7 @@ def head(
12781297
operation_id: Optional[str] = None,
12791298
include_in_schema: bool = True,
12801299
security: Optional[List[Dict[str, List[str]]]] = None,
1300+
openapi_extensions: Optional[Dict[str, Any]] = None,
12811301
middlewares: Optional[List[Callable]] = None,
12821302
):
12831303
"""Head route decorator with HEAD `method`
@@ -1318,6 +1338,7 @@ def lambda_handler(event, context):
13181338
operation_id,
13191339
include_in_schema,
13201340
security,
1341+
openapi_extensions,
13211342
middlewares,
13221343
)
13231344

@@ -1541,6 +1562,7 @@ def get_openapi_schema(
15411562
license_info: Optional["License"] = None,
15421563
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
15431564
security: Optional[List[Dict[str, List[str]]]] = None,
1565+
openapi_extensions: Optional[Dict[str, Any]] = None,
15441566
) -> "OpenAPI":
15451567
"""
15461568
Returns the OpenAPI schema as a pydantic model.
@@ -1571,6 +1593,8 @@ def get_openapi_schema(
15711593
A declaration of the security schemes available to be used in the specification.
15721594
security: List[Dict[str, List[str]]], optional
15731595
A declaration of which security mechanisms are applied globally across the API.
1596+
openapi_extensions: Dict[str, Any], optional
1597+
Additional OpenAPI extensions as a dictionary.
15741598
15751599
Returns
15761600
-------
@@ -1603,11 +1627,15 @@ def get_openapi_schema(
16031627

16041628
info.update({field: value for field, value in optional_fields.items() if value})
16051629

1630+
if not isinstance(openapi_extensions, Dict):
1631+
openapi_extensions = {}
1632+
16061633
output: Dict[str, Any] = {
16071634
"openapi": openapi_version,
16081635
"info": info,
16091636
"servers": self._get_openapi_servers(servers),
16101637
"security": self._get_openapi_security(security, security_schemes),
1638+
**openapi_extensions,
16111639
}
16121640

16131641
components: Dict[str, Dict[str, Any]] = {}
@@ -1726,6 +1754,7 @@ def get_openapi_json_schema(
17261754
license_info: Optional["License"] = None,
17271755
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
17281756
security: Optional[List[Dict[str, List[str]]]] = None,
1757+
openapi_extensions: Optional[Dict[str, Any]] = None,
17291758
) -> str:
17301759
"""
17311760
Returns the OpenAPI schema as a JSON serializable dict
@@ -1756,6 +1785,8 @@ def get_openapi_json_schema(
17561785
A declaration of the security schemes available to be used in the specification.
17571786
security: List[Dict[str, List[str]]], optional
17581787
A declaration of which security mechanisms are applied globally across the API.
1788+
openapi_extensions: Dict[str, Any], optional
1789+
Additional OpenAPI extensions as a dictionary.
17591790
17601791
Returns
17611792
-------
@@ -1778,6 +1809,7 @@ def get_openapi_json_schema(
17781809
license_info=license_info,
17791810
security_schemes=security_schemes,
17801811
security=security,
1812+
openapi_extensions=openapi_extensions,
17811813
),
17821814
by_alias=True,
17831815
exclude_none=True,
@@ -1805,6 +1837,7 @@ def enable_swagger(
18051837
security: Optional[List[Dict[str, List[str]]]] = None,
18061838
oauth2_config: Optional["OAuth2Config"] = None,
18071839
persist_authorization: bool = False,
1840+
openapi_extensions: Optional[Dict[str, Any]] = None,
18081841
):
18091842
"""
18101843
Returns the OpenAPI schema as a JSON serializable dict
@@ -1847,6 +1880,8 @@ def enable_swagger(
18471880
The OAuth2 configuration for the Swagger UI.
18481881
persist_authorization: bool, optional
18491882
Whether to persist authorization data on browser close/refresh.
1883+
openapi_extensions: Dict[str, Any], optional
1884+
Additional OpenAPI extensions as a dictionary.
18501885
"""
18511886
from aws_lambda_powertools.event_handler.openapi.compat import model_json
18521887
from aws_lambda_powertools.event_handler.openapi.models import Server
@@ -1896,6 +1931,7 @@ def swagger_handler():
18961931
license_info=license_info,
18971932
security_schemes=security_schemes,
18981933
security=security,
1934+
openapi_extensions=openapi_extensions,
18991935
)
19001936

19011937
# The .replace('</', '<\\/') part is necessary to prevent a potential issue where the JSON string contains
@@ -1949,6 +1985,7 @@ def route(
19491985
operation_id: Optional[str] = None,
19501986
include_in_schema: bool = True,
19511987
security: Optional[List[Dict[str, List[str]]]] = None,
1988+
openapi_extensions: Optional[Dict[str, Any]] = None,
19521989
middlewares: Optional[List[Callable[..., Any]]] = None,
19531990
):
19541991
"""Route decorator includes parameter `method`"""
@@ -1976,6 +2013,7 @@ def register_resolver(func: Callable):
19762013
operation_id,
19772014
include_in_schema,
19782015
security,
2016+
openapi_extensions,
19792017
middlewares,
19802018
)
19812019

@@ -2489,6 +2527,7 @@ def route(
24892527
operation_id: Optional[str] = None,
24902528
include_in_schema: bool = True,
24912529
security: Optional[List[Dict[str, List[str]]]] = None,
2530+
openapi_extensions: Optional[Dict[str, Any]] = None,
24922531
middlewares: Optional[List[Callable[..., Any]]] = None,
24932532
):
24942533
def register_route(func: Callable):
@@ -2497,6 +2536,7 @@ def register_route(func: Callable):
24972536
frozen_responses = _FrozenDict(responses) if responses else None
24982537
frozen_tags = frozenset(tags) if tags else None
24992538
frozen_security = _FrozenListDict(security) if security else None
2539+
fronzen_openapi_extensions = _FrozenDict(openapi_extensions) if openapi_extensions else None
25002540

25012541
route_key = (
25022542
rule,
@@ -2512,6 +2552,7 @@ def register_route(func: Callable):
25122552
operation_id,
25132553
include_in_schema,
25142554
frozen_security,
2555+
fronzen_openapi_extensions,
25152556
)
25162557

25172558
# Collate Middleware for routes
@@ -2592,6 +2633,7 @@ def route(
25922633
operation_id: Optional[str] = None,
25932634
include_in_schema: bool = True,
25942635
security: Optional[List[Dict[str, List[str]]]] = None,
2636+
openapi_extensions: Optional[Dict[str, Any]] = None,
25952637
middlewares: Optional[List[Callable[..., Any]]] = None,
25962638
):
25972639
# NOTE: see #1552 for more context.
@@ -2609,6 +2651,7 @@ def route(
26092651
operation_id,
26102652
include_in_schema,
26112653
security,
2654+
openapi_extensions,
26122655
middlewares,
26132656
)
26142657

aws_lambda_powertools/event_handler/bedrock_agent.py

+11
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def get( # type: ignore[override]
102102
include_in_schema: bool = True,
103103
middlewares: Optional[List[Callable[..., Any]]] = None,
104104
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
105+
106+
openapi_extensions = None
105107
security = None
106108

107109
return super(BedrockAgentResolver, self).get(
@@ -117,6 +119,7 @@ def get( # type: ignore[override]
117119
operation_id,
118120
include_in_schema,
119121
security,
122+
openapi_extensions,
120123
middlewares,
121124
)
122125

@@ -137,6 +140,7 @@ def post( # type: ignore[override]
137140
include_in_schema: bool = True,
138141
middlewares: Optional[List[Callable[..., Any]]] = None,
139142
):
143+
openapi_extensions = None
140144
security = None
141145

142146
return super().post(
@@ -152,6 +156,7 @@ def post( # type: ignore[override]
152156
operation_id,
153157
include_in_schema,
154158
security,
159+
openapi_extensions,
155160
middlewares,
156161
)
157162

@@ -172,6 +177,7 @@ def put( # type: ignore[override]
172177
include_in_schema: bool = True,
173178
middlewares: Optional[List[Callable[..., Any]]] = None,
174179
):
180+
openapi_extensions = None
175181
security = None
176182

177183
return super().put(
@@ -187,6 +193,7 @@ def put( # type: ignore[override]
187193
operation_id,
188194
include_in_schema,
189195
security,
196+
openapi_extensions,
190197
middlewares,
191198
)
192199

@@ -207,6 +214,7 @@ def patch( # type: ignore[override]
207214
include_in_schema: bool = True,
208215
middlewares: Optional[List[Callable]] = None,
209216
):
217+
openapi_extensions = None
210218
security = None
211219

212220
return super().patch(
@@ -222,6 +230,7 @@ def patch( # type: ignore[override]
222230
operation_id,
223231
include_in_schema,
224232
security,
233+
openapi_extensions,
225234
middlewares,
226235
)
227236

@@ -242,6 +251,7 @@ def delete( # type: ignore[override]
242251
include_in_schema: bool = True,
243252
middlewares: Optional[List[Callable[..., Any]]] = None,
244253
):
254+
openapi_extensions = None
245255
security = None
246256

247257
return super().delete(
@@ -257,6 +267,7 @@ def delete( # type: ignore[override]
257267
operation_id,
258268
include_in_schema,
259269
security,
270+
openapi_extensions,
260271
middlewares,
261272
)
262273

aws_lambda_powertools/event_handler/openapi/compat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
RequestErrorModel: Type[BaseModel] = create_model("Request")
4242

4343
if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3
44-
from pydantic import TypeAdapter, ValidationError
44+
from pydantic import TypeAdapter, ValidationError, model_validator as parser_openapi_extension
4545
from pydantic._internal._typing_extra import eval_type_lenient
4646
from pydantic.fields import FieldInfo
4747
from pydantic._internal._utils import lenient_issubclass
@@ -217,7 +217,7 @@ def model_json(model: BaseModel, **kwargs: Any) -> Any:
217217
return model.model_dump_json(**kwargs)
218218

219219
else:
220-
from pydantic import BaseModel, ValidationError
220+
from pydantic import BaseModel, ValidationError, root_validator as parser_openapi_extension
221221
from pydantic.fields import (
222222
ModelField,
223223
Required,

0 commit comments

Comments
 (0)