Skip to content

Commit e4c236b

Browse files
fix(event_handler): security scheme unhashable list when working with router (#4421)
1 parent 88c8e91 commit e4c236b

File tree

6 files changed

+191
-38
lines changed

6 files changed

+191
-38
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from aws_lambda_powertools.event_handler import content_types
3434
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
3535
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
36-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
36+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
3737
from aws_lambda_powertools.event_handler.openapi.types import (
3838
COMPONENT_REF_PREFIX,
3939
METHODS_WITH_BODY,
@@ -43,7 +43,12 @@
4343
validation_error_definition,
4444
validation_error_response_definition,
4545
)
46-
from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header
46+
from aws_lambda_powertools.event_handler.util import (
47+
_FrozenDict,
48+
_FrozenListDict,
49+
_validate_openapi_security_parameters,
50+
extract_origin_header,
51+
)
4752
from aws_lambda_powertools.shared.cookies import Cookie
4853
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
4954
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -703,6 +708,7 @@ def _openapi_operation_parameters(
703708
from aws_lambda_powertools.event_handler.openapi.params import Param
704709

705710
parameters = []
711+
parameter: Dict[str, Any]
706712
for param in all_route_params:
707713
field_info = param.field_info
708714
field_info = cast(Param, field_info)
@@ -1588,6 +1594,16 @@ def get_openapi_schema(
15881594

15891595
# Add routes to the OpenAPI schema
15901596
for route in all_routes:
1597+
1598+
if route.security and not _validate_openapi_security_parameters(
1599+
security=route.security,
1600+
security_schemes=security_schemes,
1601+
):
1602+
raise SchemaValidationError(
1603+
"Security configuration was not found in security_schemas or security_schema was not defined. "
1604+
"See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes",
1605+
)
1606+
15911607
if not route.include_in_schema:
15921608
continue
15931609

@@ -1630,15 +1646,15 @@ def _get_openapi_security(
16301646
security: Optional[List[Dict[str, List[str]]]],
16311647
security_schemes: Optional[Dict[str, "SecurityScheme"]],
16321648
) -> Optional[List[Dict[str, List[str]]]]:
1649+
16331650
if not security:
16341651
return None
16351652

1636-
if not security_schemes:
1637-
raise ValueError("security_schemes must be provided if security is provided")
1638-
1639-
# Check if all keys in security are present in the security_schemes
1640-
if any(key not in security_schemes for sec in security for key in sec):
1641-
raise ValueError("Some security schemes not found in security_schemes")
1653+
if not _validate_openapi_security_parameters(security=security, security_schemes=security_schemes):
1654+
raise SchemaValidationError(
1655+
"Security configuration was not found in security_schemas or security_schema was not defined. "
1656+
"See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes",
1657+
)
16421658

16431659
return security
16441660

@@ -2386,6 +2402,7 @@ def register_route(func: Callable):
23862402
methods = (method,) if isinstance(method, str) else tuple(method)
23872403
frozen_responses = _FrozenDict(responses) if responses else None
23882404
frozen_tags = frozenset(tags) if tags else None
2405+
frozen_security = _FrozenListDict(security) if security else None
23892406

23902407
route_key = (
23912408
rule,
@@ -2400,7 +2417,7 @@ def register_route(func: Callable):
24002417
frozen_tags,
24012418
operation_id,
24022419
include_in_schema,
2403-
security,
2420+
frozen_security,
24042421
)
24052422

24062423
# Collate Middleware for routes

aws_lambda_powertools/event_handler/openapi/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@ class RequestValidationError(ValidationException):
2121
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
2222
super().__init__(errors)
2323
self.body = body
24+
25+
26+
class SchemaValidationError(ValidationException):
27+
"""
28+
Raised when the OpenAPI schema validation fails
29+
"""

aws_lambda_powertools/event_handler/util.py

+63-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, List, Optional
22

3+
from aws_lambda_powertools.event_handler.openapi.models import SecurityScheme
34
from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value
45

56

@@ -18,17 +19,45 @@ def __hash__(self):
1819
return hash(frozenset(self.keys()))
1920

2021

22+
class _FrozenListDict(List[Dict[str, List[str]]]):
23+
"""
24+
Freezes a list of dictionaries containing lists of strings.
25+
26+
This function takes a list of dictionaries where the values are lists of strings and converts it into
27+
a frozen set of frozen sets of frozen dictionaries. This is done by iterating over the input list,
28+
converting each dictionary's values (lists of strings) into frozen sets of strings, and then
29+
converting the resulting dictionary into a frozen dictionary. Finally, all these frozen dictionaries
30+
are collected into a frozen set of frozen sets.
31+
32+
This operation is useful when you want to ensure the immutability of the data structure and make it
33+
hashable, which is required for certain operations like using it as a key in a dictionary or as an
34+
element in a set.
35+
36+
Example: [{"TestAuth": ["test", "test1"]}]
37+
"""
38+
39+
def __hash__(self):
40+
hashable_items = []
41+
for item in self:
42+
hashable_items.extend((key, frozenset(value)) for key, value in item.items())
43+
return hash(frozenset(hashable_items))
44+
45+
2146
def extract_origin_header(resolver_headers: Dict[str, Any]):
2247
"""
2348
Extracts the 'origin' or 'Origin' header from the provided resolver headers.
2449
2550
The 'origin' or 'Origin' header can be either a single header or a multi-header.
2651
27-
Args:
28-
resolver_headers (Dict): A dictionary containing the headers.
52+
Parameters
53+
----------
54+
resolver_headers: Dict
55+
A dictionary containing the headers.
2956
30-
Returns:
31-
Optional[str]: The value(s) of the origin header or None.
57+
Returns
58+
-------
59+
Optional[str]
60+
The value(s) of the origin header or None.
3261
"""
3362
resolved_header = get_header_value(
3463
headers=resolver_headers,
@@ -40,3 +69,32 @@ def extract_origin_header(resolver_headers: Dict[str, Any]):
4069
return resolved_header[0]
4170

4271
return resolved_header
72+
73+
74+
def _validate_openapi_security_parameters(
75+
security: List[Dict[str, List[str]]],
76+
security_schemes: Optional[Dict[str, "SecurityScheme"]] = None,
77+
) -> bool:
78+
"""
79+
This function checks if all security requirements listed in the 'security'
80+
parameter are defined in the 'security_schemes' dictionary, as specified
81+
in the OpenAPI schema.
82+
83+
Parameters
84+
----------
85+
security: List[Dict[str, List[str]]]
86+
A list of security requirements
87+
security_schemes: Optional[Dict[str, "SecurityScheme"]]
88+
A dictionary mapping security scheme names to their corresponding security scheme objects.
89+
90+
Returns
91+
-------
92+
bool
93+
Whether list of security schemes match allowed security_schemes.
94+
"""
95+
96+
security_schemes = security_schemes or {}
97+
98+
security_schema_match = all(key in security_schemes for sec in security for key in sec)
99+
100+
return bool(security_schema_match and security_schemes)

docs/core/event_handler/api_gateway.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -1032,8 +1032,7 @@ Below is an example configuration for serving Swagger UI from a custom path or C
10321032
???-info "Does Powertools implement any of the security schemes?"
10331033
No. Powertools adds support for generating OpenAPI documentation with [security schemes](https://swagger.io/docs/specification/authentication/), but it doesn't implement any of the security schemes itself, so you must implement the security mechanisms separately.
10341034

1035-
OpenAPI uses the term security scheme for [authentication and authorization schemes](https://swagger.io/docs/specification/authentication/){target="_blank"}.
1036-
When you're describing your API, declare security schemes at the top level, and reference them globally or per operation.
1035+
Security schemes are declared at the top-level first. You can reference them globally or on a per path _(operation)_ level. **However**, if you reference security schemes that are not defined at the top-level it will lead to a `SchemaValidationError` _(invalid OpenAPI spec)_.
10371036

10381037
=== "Global OpenAPI security schemes"
10391038

tests/functional/event_handler/conftest.py

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import fastjsonschema
44
import pytest
55

6+
from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn
67
from tests.functional.utils import load_event
78

89

@@ -114,3 +115,8 @@ def openapi31_schema():
114115
data,
115116
use_formats=False,
116117
)
118+
119+
120+
@pytest.fixture
121+
def security_scheme():
122+
return {"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header)}
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
import pytest
22

33
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
4-
from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn
4+
from aws_lambda_powertools.event_handler.api_gateway import Router
5+
from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError
56

67

7-
def test_openapi_top_level_security():
8+
def test_openapi_top_level_security(security_scheme):
9+
# GIVEN an APIGatewayRestResolver instance
810
app = APIGatewayRestResolver()
911

1012
@app.get("/")
1113
def handler():
1214
raise NotImplementedError()
1315

14-
schema = app.get_openapi_schema(
15-
security_schemes={
16-
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
17-
},
18-
security=[{"apiKey": []}],
19-
)
16+
# WHEN the get_openapi_schema method is called with a security scheme
17+
schema = app.get_openapi_schema(security_schemes=security_scheme, security=[{"apiKey": []}])
2018

19+
# THEN the resulting schema should have security defined at the top level
2120
security = schema.security
2221
assert security is not None
2322

@@ -26,37 +25,105 @@ def handler():
2625

2726

2827
def test_openapi_top_level_security_missing():
28+
# GIVEN an APIGatewayRestResolver instance
2929
app = APIGatewayRestResolver()
3030

3131
@app.get("/")
3232
def handler():
3333
raise NotImplementedError()
3434

35-
with pytest.raises(ValueError):
35+
# WHEN the get_openapi_schema method is called with security defined without security schemes
36+
# THEN a SchemaValidationError should be raised
37+
with pytest.raises(SchemaValidationError):
3638
app.get_openapi_schema(
3739
security=[{"apiKey": []}],
3840
)
3941

4042

41-
def test_openapi_operation_security():
43+
def test_openapi_top_level_security_mismatch(security_scheme):
44+
# GIVEN an APIGatewayRestResolver instance
45+
app = APIGatewayRestResolver()
46+
47+
@app.get("/")
48+
def handler():
49+
raise NotImplementedError()
50+
51+
# WHEN the get_openapi_schema method is called with security defined security schemes as APIKey
52+
# AND top level security is defined as HTTPBearer
53+
# THEN a SchemaValidationError should be raised
54+
with pytest.raises(SchemaValidationError):
55+
app.get_openapi_schema(
56+
security_schemes=security_scheme,
57+
security=[{"HTTPBearer": []}],
58+
)
59+
60+
61+
def test_openapi_operation_level_security(security_scheme):
62+
# GIVEN an APIGatewayRestResolver instance
4263
app = APIGatewayRestResolver()
4364

4465
@app.get("/", security=[{"apiKey": []}])
4566
def handler():
4667
raise NotImplementedError()
4768

48-
schema = app.get_openapi_schema(
49-
security_schemes={
50-
"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header),
51-
},
52-
)
69+
# WHEN the get_openapi_schema method is called with security defined at the operation level
70+
schema = app.get_openapi_schema(security_schemes=security_scheme)
5371

54-
security = schema.security
55-
assert security is None
72+
# THEN the resulting schema should have security defined at the operation level, not the top level
73+
top_level_security = schema.security
74+
path_level_security = schema.paths["/"].get.security
75+
assert top_level_security is None
76+
assert path_level_security[0] == {"apiKey": []}
5677

57-
operation = schema.paths["/"].get
58-
security = operation.security
59-
assert security is not None
6078

61-
assert len(security) == 1
62-
assert security[0] == {"apiKey": []}
79+
def test_openapi_operation_level_security_missing():
80+
# GIVEN an APIGatewayRestResolver instance
81+
app = APIGatewayRestResolver()
82+
83+
# AND a route with a security scheme defined
84+
@app.get("/", security=[{"apiKey": []}])
85+
def handler():
86+
raise NotImplementedError()
87+
88+
# WHEN the get_openapi_schema method is called without security schemes defined
89+
# THEN a SchemaValidationError should be raised
90+
with pytest.raises(SchemaValidationError):
91+
app.get_openapi_schema()
92+
93+
94+
def test_openapi_operation_level_security_mismatch(security_scheme):
95+
# GIVEN an APIGatewayRestResolver instance
96+
app = APIGatewayRestResolver()
97+
98+
# AND a route with a security scheme using HTTPBearer
99+
@app.get("/", security=[{"HTTPBearer": []}])
100+
def handler():
101+
raise NotImplementedError()
102+
103+
# WHEN the get_openapi_schema method is called with security defined security schemes as APIKey
104+
# THEN a SchemaValidationError should be raised
105+
with pytest.raises(SchemaValidationError):
106+
app.get_openapi_schema(
107+
security_schemes=security_scheme,
108+
)
109+
110+
111+
def test_openapi_operation_level_security_with_router(security_scheme):
112+
# GIVEN an APIGatewayRestResolver instance with a Router
113+
app = APIGatewayRestResolver()
114+
router = Router()
115+
116+
@router.get("/", security=[{"apiKey": []}])
117+
def handler():
118+
raise NotImplementedError()
119+
120+
app.include_router(router)
121+
122+
# WHEN the get_openapi_schema method is called with security defined at the operation level in the Router
123+
schema = app.get_openapi_schema(security_schemes=security_scheme)
124+
125+
# THEN the resulting schema should have security defined at the operation level
126+
top_level_security = schema.security
127+
path_level_security = schema.paths["/"].get.security
128+
assert top_level_security is None
129+
assert path_level_security[0] == {"apiKey": []}

0 commit comments

Comments
 (0)