Skip to content

Commit 270516f

Browse files
Polishing the PR with best practicies - using model_validator to be more specific
1 parent 8a111d2 commit 270516f

File tree

4 files changed

+57
-62
lines changed

4 files changed

+57
-62
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1623,7 +1623,7 @@ def get_openapi_schema(
16231623

16241624
info.update({field: value for field, value in optional_fields.items() if value})
16251625

1626-
if not openapi_extensions:
1626+
if not isinstance(openapi_extensions, Dict):
16271627
openapi_extensions = {}
16281628

16291629
output: Dict[str, Any] = {

aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
RequestErrorModel: Type[BaseModel] = create_model("Request")
4242

4343
if PYDANTIC_V2: # pragma: no cover # false positive; dropping in v3
44-
from pydantic import TypeAdapter, ValidationError, model_serializer as parser_openapi_extension
44+
from pydantic import TypeAdapter, ValidationError, model_validator as parser_openapi_extension
4545
from pydantic._internal._typing_extra import eval_type_lenient
4646
from pydantic.fields import FieldInfo
4747
from pydantic._internal._utils import lenient_issubclass

aws_lambda_powertools/event_handler/openapi/models.py

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,38 @@
1313
"""
1414

1515

16+
class OpenapiExtensions(BaseModel):
17+
openapi_extensions: Optional[Dict[str, Any]] = None
18+
19+
# This rule is valid for Pydantic v1 and v2
20+
# If the 'openapi_extensions' field is present in the 'values' dictionary,
21+
# update the 'values' dictionary with the contents of 'openapi_extensions',
22+
# and then remove the 'openapi_extensions' field from the 'values' dictionary
23+
24+
if PYDANTIC_V2:
25+
model_config = {"extra": "allow"}
26+
27+
@parser_openapi_extension(mode="before")
28+
def serialize_openapi_extension(self):
29+
if isinstance(self, dict) and self.get("openapi_extensions"):
30+
self.update(self.get("openapi_extensions"))
31+
self.pop("openapi_extensions", None)
32+
33+
return self
34+
35+
else:
36+
37+
@parser_openapi_extension(pre=False, allow_reuse=True)
38+
def serialize_openapi_extension(cls, values):
39+
if values.get("openapi_extensions"):
40+
values.update(values["openapi_extensions"])
41+
del values["openapi_extensions"]
42+
return values
43+
44+
class Config:
45+
extra = "allow"
46+
47+
1648
# https://swagger.io/specification/#contact-object
1749
class Contact(BaseModel):
1850
name: Optional[str] = None
@@ -77,33 +109,16 @@ class Config:
77109

78110

79111
# https://swagger.io/specification/#server-object
80-
class Server(BaseModel):
112+
class Server(OpenapiExtensions):
81113
url: Union[AnyUrl, str]
82114
description: Optional[str] = None
83115
variables: Optional[Dict[str, ServerVariable]] = None
84-
openapi_extensions: Optional[Dict[str, Any]] = None
85116

86117
if PYDANTIC_V2:
87118
model_config = {"extra": "allow"}
88119

89-
@parser_openapi_extension()
90-
def serialize(self):
91-
# If the 'openapi_extensions' field is not None, return it
92-
if self.openapi_extensions:
93-
return self.openapi_extensions
94-
95120
else:
96121

97-
# If the 'openapi_extensions' field is present in the 'values' dictionary,
98-
# update the 'values' dictionary with the contents of 'openapi_extensions',
99-
# and then remove the 'openapi_extensions' field from the 'values' dictionary
100-
@parser_openapi_extension(pre=False, allow_reuse=True)
101-
def check_json(cls, values):
102-
if values.get("openapi_extensions"):
103-
values.update(values["openapi_extensions"])
104-
del values["openapi_extensions"]
105-
return values
106-
107122
class Config:
108123
extra = "allow"
109124

@@ -396,7 +411,7 @@ class Config:
396411

397412

398413
# https://swagger.io/specification/#operation-object
399-
class Operation(BaseModel):
414+
class Operation(OpenapiExtensions):
400415
tags: Optional[List[str]] = None
401416
summary: Optional[str] = None
402417
description: Optional[str] = None
@@ -410,23 +425,12 @@ class Operation(BaseModel):
410425
deprecated: Optional[bool] = None
411426
security: Optional[List[Dict[str, List[str]]]] = None
412427
servers: Optional[List[Server]] = None
413-
openapi_extensions: Optional[Dict[str, Any]] = None
414428

415429
if PYDANTIC_V2:
416430
model_config = {"extra": "allow"}
417431

418432
else:
419433

420-
# If the 'openapi_extensions' field is present in the 'values' dictionary,
421-
# update the 'values' dictionary with the contents of 'openapi_extensions',
422-
# and then remove the 'openapi_extensions' field from the 'values' dictionary
423-
@parser_openapi_extension(pre=False, allow_reuse=True)
424-
def check_json(cls, values):
425-
if values.get("openapi_extensions"):
426-
values.update(values["openapi_extensions"])
427-
del values["openapi_extensions"]
428-
return values
429-
430434
class Config:
431435
extra = "allow"
432436

@@ -464,32 +468,15 @@ class SecuritySchemeType(Enum):
464468
openIdConnect = "openIdConnect"
465469

466470

467-
class SecurityBase(BaseModel):
471+
class SecurityBase(OpenapiExtensions):
468472
type_: SecuritySchemeType = Field(alias="type")
469473
description: Optional[str] = None
470-
openapi_extensions: Optional[Dict[str, Any]] = None
471474

472475
if PYDANTIC_V2:
473476
model_config = {"extra": "allow", "populate_by_name": True}
474477

475-
@parser_openapi_extension()
476-
def serialize(self):
477-
# If the 'openapi_extensions' field is not None, return it
478-
if self.openapi_extensions:
479-
return self.openapi_extensions
480-
481478
else:
482479

483-
# If the 'openapi_extensions' field is present in the 'values' dictionary,
484-
# update the 'values' dictionary with the contents of 'openapi_extensions',
485-
# and then remove the 'openapi_extensions' field from the 'values' dictionary
486-
@parser_openapi_extension(pre=False, allow_reuse=True)
487-
def check_json(cls, values):
488-
if values.get("openapi_extensions"):
489-
values.update(values["openapi_extensions"])
490-
del values["openapi_extensions"]
491-
return values
492-
493480
class Config:
494481
extra = "allow"
495482
allow_population_by_field_name = True
@@ -602,7 +589,7 @@ class Config:
602589

603590

604591
# https://swagger.io/specification/#openapi-object
605-
class OpenAPI(BaseModel):
592+
class OpenAPI(OpenapiExtensions):
606593
openapi: str
607594
info: Info
608595
jsonSchemaDialect: Optional[str] = None
@@ -614,23 +601,12 @@ class OpenAPI(BaseModel):
614601
security: Optional[List[Dict[str, List[str]]]] = None
615602
tags: Optional[List[Tag]] = None
616603
externalDocs: Optional[ExternalDocumentation] = None
617-
openapi_extensions: Optional[Dict[str, Any]] = None
618604

619605
if PYDANTIC_V2:
620606
model_config = {"extra": "allow"}
621607

622608
else:
623609

624-
# If the 'openapi_extensions' field is present in the 'values' dictionary,
625-
# update the 'values' dictionary with the contents of 'openapi_extensions',
626-
# and then remove the 'openapi_extensions' field from the 'values' dictionary
627-
@parser_openapi_extension(pre=False, allow_reuse=True)
628-
def check_json(cls, values):
629-
if values.get("openapi_extensions"):
630-
values.update(values["openapi_extensions"])
631-
del values["openapi_extensions"]
632-
return values
633-
634610
class Config:
635611
extra = "allow"
636612

tests/functional/event_handler/_pydantic/test_openapi_extensions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def test_openapi_extension_server_level():
6565
# THEN the OpenAPI schema must contain the "x-amazon-apigateway-endpoint-configuration" at the server level
6666
assert "x-amazon-apigateway-endpoint-configuration" in schema["servers"][0]
6767
assert schema["servers"][0]["x-amazon-apigateway-endpoint-configuration"] == endpoint_config
68+
assert schema["servers"][0]["url"] == server_config["url"]
69+
assert schema["servers"][0]["description"] == server_config["description"]
6870

6971

7072
def test_openapi_extension_security_scheme_level_with_api_key():
@@ -102,6 +104,9 @@ def test_openapi_extension_security_scheme_level_with_api_key():
102104
assert "x-amazon-apigateway-authtype" in schema["components"]["securitySchemes"]["apiKey"]
103105
assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authtype"] == "custom"
104106
assert schema["components"]["securitySchemes"]["apiKey"]["x-amazon-apigateway-authorizer"] == authorizer_config
107+
assert schema["components"]["securitySchemes"]["apiKey"]["name"] == api_key_config["name"]
108+
assert schema["components"]["securitySchemes"]["apiKey"]["description"] == api_key_config["description"]
109+
assert schema["components"]["securitySchemes"]["apiKey"]["in"] == "header"
105110

106111

107112
def test_openapi_extension_security_scheme_level_with_oauth2():
@@ -142,6 +147,19 @@ def test_openapi_extension_security_scheme_level_with_oauth2():
142147
# THEN the OpenAPI schema must contain the "x-amazon-apigateway-authorizer" extension at the security scheme level
143148
assert "x-amazon-apigateway-authorizer" in schema["components"]["securitySchemes"]["oauth2"]
144149
assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"] == authorizer_config
150+
assert (
151+
schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["identitySource"]
152+
== "$request.header.Authorization"
153+
)
154+
assert schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][
155+
"audience"
156+
] == ["test"]
157+
assert (
158+
schema["components"]["securitySchemes"]["oauth2"]["x-amazon-apigateway-authorizer"]["jwtConfiguration"][
159+
"issuer"
160+
]
161+
== "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/"
162+
)
145163

146164

147165
def test_openapi_extension_operation_level(openapi_extension_integration_detail):
@@ -159,6 +177,7 @@ def lambda_handler():
159177
# THEN the OpenAPI schema must contain the "x-amazon-apigateway-integration" extension at the operation level
160178
assert "x-amazon-apigateway-integration" in schema["paths"]["/test"]["get"]
161179
assert schema["paths"]["/test"]["get"]["x-amazon-apigateway-integration"] == openapi_extension_integration_detail
180+
assert schema["paths"]["/test"]["get"]["operationId"] == "lambda_handler_test_get"
162181

163182

164183
def test_openapi_extension_operation_level_multiple_paths(

0 commit comments

Comments
 (0)