Skip to content

Commit 8a111d2

Browse files
Polishing the PR with best practicies - make pydanticv2 happy
1 parent 84d035a commit 8a111d2

File tree

1 file changed

+61
-31
lines changed
  • aws_lambda_powertools/event_handler/openapi

1 file changed

+61
-31
lines changed

aws_lambda_powertools/event_handler/openapi/models.py

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,6 @@
1313
"""
1414

1515

16-
class OpenapiExtensions(BaseModel):
17-
"""OpenAPI extensions, see https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#specification-extensions"""
18-
19-
openapi_extensions: Optional[Dict[str, Any]] = None
20-
21-
if PYDANTIC_V2:
22-
23-
@parser_openapi_extension()
24-
def serialize(self):
25-
# If the 'openapi_extensions' field is not None, return it
26-
if self.openapi_extensions:
27-
return self.openapi_extensions
28-
29-
else:
30-
31-
# If the 'openapi_extensions' field is present in the 'values' dictionary,
32-
# update the 'values' dictionary with the contents of 'openapi_extensions',
33-
# and then remove the 'openapi_extensions' field from the 'values' dictionary
34-
@parser_openapi_extension(pre=False, allow_reuse=True)
35-
def check_json(cls, values):
36-
if values.get("openapi_extensions"):
37-
values.update(values["openapi_extensions"])
38-
del values["openapi_extensions"]
39-
return values
40-
41-
4216
# https://swagger.io/specification/#contact-object
4317
class Contact(BaseModel):
4418
name: Optional[str] = None
@@ -103,16 +77,33 @@ class Config:
10377

10478

10579
# https://swagger.io/specification/#server-object
106-
class Server(OpenapiExtensions):
80+
class Server(BaseModel):
10781
url: Union[AnyUrl, str]
10882
description: Optional[str] = None
10983
variables: Optional[Dict[str, ServerVariable]] = None
84+
openapi_extensions: Optional[Dict[str, Any]] = None
11085

11186
if PYDANTIC_V2:
11287
model_config = {"extra": "allow"}
11388

89+
@parser_openapi_extension()
90+
def serialize(self):
91+
# If the 'openapi_extensions' field is not None, return it
92+
if self.openapi_extensions:
93+
return self.openapi_extensions
94+
11495
else:
11596

97+
# If the 'openapi_extensions' field is present in the 'values' dictionary,
98+
# update the 'values' dictionary with the contents of 'openapi_extensions',
99+
# and then remove the 'openapi_extensions' field from the 'values' dictionary
100+
@parser_openapi_extension(pre=False, allow_reuse=True)
101+
def check_json(cls, values):
102+
if values.get("openapi_extensions"):
103+
values.update(values["openapi_extensions"])
104+
del values["openapi_extensions"]
105+
return values
106+
116107
class Config:
117108
extra = "allow"
118109

@@ -405,7 +396,7 @@ class Config:
405396

406397

407398
# https://swagger.io/specification/#operation-object
408-
class Operation(OpenapiExtensions):
399+
class Operation(BaseModel):
409400
tags: Optional[List[str]] = None
410401
summary: Optional[str] = None
411402
description: Optional[str] = None
@@ -419,12 +410,23 @@ class Operation(OpenapiExtensions):
419410
deprecated: Optional[bool] = None
420411
security: Optional[List[Dict[str, List[str]]]] = None
421412
servers: Optional[List[Server]] = None
413+
openapi_extensions: Optional[Dict[str, Any]] = None
422414

423415
if PYDANTIC_V2:
424416
model_config = {"extra": "allow"}
425417

426418
else:
427419

420+
# If the 'openapi_extensions' field is present in the 'values' dictionary,
421+
# update the 'values' dictionary with the contents of 'openapi_extensions',
422+
# and then remove the 'openapi_extensions' field from the 'values' dictionary
423+
@parser_openapi_extension(pre=False, allow_reuse=True)
424+
def check_json(cls, values):
425+
if values.get("openapi_extensions"):
426+
values.update(values["openapi_extensions"])
427+
del values["openapi_extensions"]
428+
return values
429+
428430
class Config:
429431
extra = "allow"
430432

@@ -462,15 +464,32 @@ class SecuritySchemeType(Enum):
462464
openIdConnect = "openIdConnect"
463465

464466

465-
class SecurityBase(OpenapiExtensions):
467+
class SecurityBase(BaseModel):
466468
type_: SecuritySchemeType = Field(alias="type")
467469
description: Optional[str] = None
470+
openapi_extensions: Optional[Dict[str, Any]] = None
468471

469472
if PYDANTIC_V2:
470473
model_config = {"extra": "allow", "populate_by_name": True}
471474

475+
@parser_openapi_extension()
476+
def serialize(self):
477+
# If the 'openapi_extensions' field is not None, return it
478+
if self.openapi_extensions:
479+
return self.openapi_extensions
480+
472481
else:
473482

483+
# If the 'openapi_extensions' field is present in the 'values' dictionary,
484+
# update the 'values' dictionary with the contents of 'openapi_extensions',
485+
# and then remove the 'openapi_extensions' field from the 'values' dictionary
486+
@parser_openapi_extension(pre=False, allow_reuse=True)
487+
def check_json(cls, values):
488+
if values.get("openapi_extensions"):
489+
values.update(values["openapi_extensions"])
490+
del values["openapi_extensions"]
491+
return values
492+
474493
class Config:
475494
extra = "allow"
476495
allow_population_by_field_name = True
@@ -560,7 +579,7 @@ class OpenIdConnect(SecurityBase):
560579

561580

562581
# https://swagger.io/specification/#components-object
563-
class Components(OpenapiExtensions):
582+
class Components(BaseModel):
564583
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
565584
responses: Optional[Dict[str, Union[Response, Reference]]] = None
566585
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
@@ -583,7 +602,7 @@ class Config:
583602

584603

585604
# https://swagger.io/specification/#openapi-object
586-
class OpenAPI(OpenapiExtensions):
605+
class OpenAPI(BaseModel):
587606
openapi: str
588607
info: Info
589608
jsonSchemaDialect: Optional[str] = None
@@ -595,12 +614,23 @@ class OpenAPI(OpenapiExtensions):
595614
security: Optional[List[Dict[str, List[str]]]] = None
596615
tags: Optional[List[Tag]] = None
597616
externalDocs: Optional[ExternalDocumentation] = None
617+
openapi_extensions: Optional[Dict[str, Any]] = None
598618

599619
if PYDANTIC_V2:
600620
model_config = {"extra": "allow"}
601621

602622
else:
603623

624+
# If the 'openapi_extensions' field is present in the 'values' dictionary,
625+
# update the 'values' dictionary with the contents of 'openapi_extensions',
626+
# and then remove the 'openapi_extensions' field from the 'values' dictionary
627+
@parser_openapi_extension(pre=False, allow_reuse=True)
628+
def check_json(cls, values):
629+
if values.get("openapi_extensions"):
630+
values.update(values["openapi_extensions"])
631+
del values["openapi_extensions"]
632+
return values
633+
604634
class Config:
605635
extra = "allow"
606636

0 commit comments

Comments
 (0)