Skip to content

Commit b8390b0

Browse files
Adding more tests
1 parent e81deb7 commit b8390b0

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

aws_lambda_powertools/event_handler/openapi/models.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import AnyUrl, BaseModel, Field
55

66
from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild, parser_openapi_extension
7+
from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError
78
from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2
89
from aws_lambda_powertools.shared.types import Annotated, Literal
910

@@ -25,6 +26,7 @@ class OpenAPIExtensions(BaseModel):
2526

2627
# This rule is valid for Pydantic v1 and v2
2728
# If the 'openapi_extensions' field is present in the 'values' dictionary,
29+
# And if the extension starts with x-
2830
# update the 'values' dictionary with the contents of 'openapi_extensions',
2931
# and then remove the 'openapi_extensions' field from the 'values' dictionary
3032

@@ -34,8 +36,15 @@ class OpenAPIExtensions(BaseModel):
3436

3537
@parser_openapi_extension(mode="before")
3638
def serialize_openapi_extension_v2(self):
37-
if isinstance(self, dict) and self.get("openapi_extensions"):
38-
self.update(self.get("openapi_extensions"))
39+
openapi_extension_value = self.get("openapi_extensions")
40+
41+
if isinstance(self, dict) and openapi_extension_value:
42+
43+
for extension_key in openapi_extension_value:
44+
if not str(extension_key).startswith("x-"):
45+
raise SchemaValidationError("An OpenAPI extension key must start with x-")
46+
47+
self.update(openapi_extension_value)
3948
self.pop("openapi_extensions", None)
4049

4150
return self
@@ -44,9 +53,17 @@ def serialize_openapi_extension_v2(self):
4453

4554
@parser_openapi_extension(pre=False, allow_reuse=True)
4655
def serialize_openapi_extension_v1(cls, values):
47-
if values.get("openapi_extensions"):
56+
openapi_extension_value = values.get("openapi_extensions")
57+
58+
if openapi_extension_value:
59+
60+
for extension_key in openapi_extension_value:
61+
if not str(extension_key).startswith("x-"):
62+
raise SchemaValidationError("An OpenAPI extension key must start with x-")
63+
4864
values.update(values["openapi_extensions"])
4965
del values["openapi_extensions"]
66+
5067
return values
5168

5269
class Config:

tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError
34
from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions
45

56

@@ -12,6 +13,14 @@ def test_openapi_extensions_with_dict():
1213
assert extensions.dict(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}}
1314

1415

16+
@pytest.mark.usefixtures("pydanticv1_only")
17+
def test_openapi_extensions_with_invalid_key():
18+
# GIVEN we create an OpenAPIExtensions object with an invalid value
19+
with pytest.raises(SchemaValidationError):
20+
# THEN must raise an exception
21+
OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}})
22+
23+
1524
@pytest.mark.usefixtures("pydanticv1_only")
1625
def test_openapi_extensions_with_proxy_models():
1726

tests/unit/event_handler/_pydantic/test_openapi_models_pydantic_v2.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22

3+
from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError
34
from aws_lambda_powertools.event_handler.openapi.models import OpenAPIExtensions
45

56

@@ -12,6 +13,14 @@ def test_openapi_extensions_with_dict():
1213
assert extensions.model_dump(exclude_none=True) == {"x-amazon-apigateway": {"foo": "bar"}}
1314

1415

16+
@pytest.mark.usefixtures("pydanticv2_only")
17+
def test_openapi_extensions_with_invalid_key():
18+
# GIVEN we create an OpenAPIExtensions object with an invalid value
19+
with pytest.raises(SchemaValidationError):
20+
# THEN must raise an exception
21+
OpenAPIExtensions(openapi_extensions={"amazon-apigateway-invalid": {"foo": "bar"}})
22+
23+
1524
@pytest.mark.usefixtures("pydanticv2_only")
1625
def test_openapi_extensions_with_proxy_models():
1726

0 commit comments

Comments
 (0)