Skip to content

Commit 463dd47

Browse files
authored
fix(event_handler): do not skip middleware and exception handlers on 404 error (#4492)
* fix(parameters): make cache aware of single vs multiple calls Signed-off-by: heitorlessa <[email protected]> * chore: cleanup, add test for single and nested Signed-off-by: heitorlessa <[email protected]> * chore(ci): add first centralized reusable workflow * chore: remove playground * fix(apigateway): do not skip middleware and exception handlers on 404 * Delete bla.py * fix: default value for include_in_schema * chore: test preflight on notfound with middlware * docs: internal request processing --------- Signed-off-by: heitorlessa <[email protected]>
1 parent e8dfebf commit 463dd47

File tree

2 files changed

+171
-37
lines changed

2 files changed

+171
-37
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+113-37
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
211211
# The origin matched an allowed origin, so return the CORS headers
212212
headers = {
213213
"Access-Control-Allow-Origin": origin,
214-
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
214+
"Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers),
215215
}
216216

217217
if self.expose_headers:
@@ -222,6 +222,23 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
222222
headers["Access-Control-Allow-Credentials"] = "true"
223223
return headers
224224

225+
@staticmethod
226+
def build_allow_methods(methods: Set[str]) -> str:
227+
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
228+
229+
Parameters
230+
----------
231+
methods : set[str]
232+
Set of HTTP Methods
233+
234+
Returns
235+
-------
236+
set[str]
237+
Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS`
238+
239+
"""
240+
return ",".join(sorted(methods))
241+
225242

226243
class Response(Generic[ResponseT]):
227244
"""Response data class that provides greater control over what is returned from the proxy event"""
@@ -282,16 +299,16 @@ def __init__(
282299
func: Callable,
283300
cors: bool,
284301
compress: bool,
285-
cache_control: Optional[str],
286-
summary: Optional[str],
287-
description: Optional[str],
288-
responses: Optional[Dict[int, OpenAPIResponse]],
289-
response_description: Optional[str],
290-
tags: Optional[List[str]],
291-
operation_id: Optional[str],
292-
include_in_schema: bool,
293-
security: Optional[List[Dict[str, List[str]]]],
294-
middlewares: Optional[List[Callable[..., Response]]],
302+
cache_control: Optional[str] = None,
303+
summary: Optional[str] = None,
304+
description: Optional[str] = None,
305+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
306+
response_description: Optional[str] = None,
307+
tags: Optional[List[str]] = None,
308+
operation_id: Optional[str] = None,
309+
include_in_schema: bool = True,
310+
security: Optional[List[Dict[str, List[str]]]] = None,
311+
middlewares: Optional[List[Callable[..., Response]]] = None,
295312
):
296313
"""
297314
@@ -1406,7 +1423,6 @@ def _registered_api_adapter(
14061423
"""
14071424
route_args: Dict = app.context.get("_route_args", {})
14081425
logger.debug(f"Calling API Route Handler: {route_args}")
1409-
14101426
return app._to_response(next_middleware(**route_args))
14111427

14121428

@@ -1967,6 +1983,36 @@ def register_resolver(func: Callable):
19671983
def resolve(self, event, context) -> Dict[str, Any]:
19681984
"""Resolves the response based on the provide event and decorator routes
19691985
1986+
## Internals
1987+
1988+
Request processing chain is triggered by a Route object being called _(`_call_route` -> `__call__`)_:
1989+
1990+
1. **When a route is matched**
1991+
1.1. Exception handlers _(if any exception bubbled up and caught)_
1992+
1.2. Global middlewares _(before, and after on the way back)_
1993+
1.3. Path level middleware _(before, and after on the way back)_
1994+
1.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
1995+
1.5. Run actual route
1996+
2. **When a route is NOT matched**
1997+
2.1. Exception handlers _(if any exception bubbled up and caught)_
1998+
2.2. Global middlewares _(before, and after on the way back)_
1999+
2.3. Path level middleware _(before, and after on the way back)_
2000+
2.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
2001+
2.5. Run 404 route handler
2002+
3. **When a route is a pre-flight CORS (often not matched)**
2003+
3.1. Exception handlers _(if any exception bubbled up and caught)_
2004+
3.2. Global middlewares _(before, and after on the way back)_
2005+
3.3. Path level middleware _(before, and after on the way back)_
2006+
3.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
2007+
3.5. Return 204 with appropriate CORS headers
2008+
4. **When a route is matched with Data Validation enabled**
2009+
4.1. Exception handlers _(if any exception bubbled up and caught)_
2010+
4.2. Data Validation middleware _(before, and after on the way back)_
2011+
4.3. Global middlewares _(before, and after on the way back)_
2012+
4.4. Path level middleware _(before, and after on the way back)_
2013+
4.5. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
2014+
4.6. Run actual route
2015+
19702016
Parameters
19712017
----------
19722018
event: Dict[str, Any]
@@ -2090,7 +2136,9 @@ def _resolve(self) -> ResponseBuilder:
20902136
method = self.current_event.http_method.upper()
20912137
path = self._remove_prefix(self.current_event.path)
20922138

2093-
for route in self._static_routes + self._dynamic_routes:
2139+
registered_routes = self._static_routes + self._dynamic_routes
2140+
2141+
for route in registered_routes:
20942142
if method != route.method:
20952143
continue
20962144
match_results: Optional[Match] = route.rule.match(path)
@@ -2102,8 +2150,7 @@ def _resolve(self) -> ResponseBuilder:
21022150
route_keys = self._convert_matches_into_route_keys(match_results)
21032151
return self._call_route(route, route_keys) # pass fn args
21042152

2105-
logger.debug(f"No match found for path {path} and method {method}")
2106-
return self._not_found(method)
2153+
return self._handle_not_found(method=method, path=path)
21072154

21082155
def _remove_prefix(self, path: str) -> str:
21092156
"""Remove the configured prefix from the path"""
@@ -2141,36 +2188,65 @@ def _path_starts_with(path: str, prefix: str):
21412188

21422189
return path.startswith(prefix + "/")
21432190

2144-
def _not_found(self, method: str) -> ResponseBuilder:
2191+
def _handle_not_found(self, method: str, path: str) -> ResponseBuilder:
21452192
"""Called when no matching route was found and includes support for the cors preflight response"""
2146-
headers = {}
2147-
if self._cors:
2148-
logger.debug("CORS is enabled, updating headers.")
2149-
extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field)
2150-
headers.update(self._cors.to_dict(extracted_origin_header))
2151-
2152-
if method == "OPTIONS":
2153-
logger.debug("Pre-flight request detected. Returning CORS with null response")
2154-
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
2155-
return ResponseBuilder(
2156-
response=Response(status_code=204, content_type=None, headers=headers, body=""),
2157-
serializer=self._serializer,
2158-
)
2193+
logger.debug(f"No match found for path {path} and method {method}")
21592194

2160-
handler = self._lookup_exception_handler(NotFoundError)
2161-
if handler:
2162-
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)
2195+
def not_found_handler():
2196+
"""Route handler for 404s
2197+
2198+
It handles in the following order:
2199+
2200+
1. Pre-flight CORS requests (OPTIONS)
2201+
2. Detects and calls custom HTTP 404 handler
2202+
3. Returns standard 404 along with CORS headers
21632203
2164-
return self._response_builder_class(
2165-
response=Response(
2204+
Returns
2205+
-------
2206+
Response
2207+
HTTP 404 response
2208+
"""
2209+
_headers: Dict[str, Any] = {}
2210+
2211+
# Pre-flight request? Return immediately to avoid browser error
2212+
if self._cors and method == "OPTIONS":
2213+
logger.debug("Pre-flight request detected. Returning CORS with empty response")
2214+
_headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods)
2215+
2216+
return Response(status_code=204, content_type=None, headers=_headers, body="")
2217+
2218+
# Customer registered 404 route? Call it.
2219+
custom_not_found_handler = self._lookup_exception_handler(NotFoundError)
2220+
if custom_not_found_handler:
2221+
return custom_not_found_handler(NotFoundError())
2222+
2223+
# No CORS and no custom 404 fn? Default response
2224+
return Response(
21662225
status_code=HTTPStatus.NOT_FOUND.value,
21672226
content_type=content_types.APPLICATION_JSON,
2168-
headers=headers,
2227+
headers=_headers,
21692228
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
2170-
),
2171-
serializer=self._serializer,
2229+
)
2230+
2231+
# We create a route to trigger entire request chain (middleware+exception handlers)
2232+
route = Route(
2233+
rule=self._compile_regex(r".*"),
2234+
method=method,
2235+
path=path,
2236+
func=not_found_handler,
2237+
cors=self._cors_enabled,
2238+
compress=False,
21722239
)
21732240

2241+
# Add matched Route reference into the Resolver context
2242+
self.append_context(_route=route, _path=path)
2243+
2244+
# Kick-off request chain:
2245+
# -> exception_handlers()
2246+
# --> middlewares()
2247+
# ---> not_found_route()
2248+
return self._call_route(route=route, route_arguments={})
2249+
21742250
def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
21752251
"""Actually call the matching route with any provided keyword arguments."""
21762252
try:

tests/functional/event_handler/test_api_middlewares.py

+58
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
APIGatewayHttpResolver,
88
ApiGatewayResolver,
99
APIGatewayRestResolver,
10+
CORSConfig,
1011
ProxyEventType,
1112
Response,
1213
Router,
@@ -506,3 +507,60 @@ def post_lambda():
506507
result = resolver(event, {})
507508
assert result["statusCode"] == 200
508509
assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501
510+
511+
512+
@pytest.mark.parametrize(
513+
"app, event",
514+
[
515+
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
516+
(APIGatewayRestResolver(), API_REST_EVENT),
517+
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
518+
],
519+
)
520+
def test_global_middleware_not_found(app: ApiGatewayResolver, event):
521+
# GIVEN global middleware is registered
522+
523+
def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
524+
# add additional data to Router Context
525+
ret = next_middleware(app)
526+
ret.body = "middleware works"
527+
return ret
528+
529+
app.use(middlewares=[middleware])
530+
531+
@app.get("/this/path/does/not/exist")
532+
def nope() -> dict: ...
533+
534+
# WHEN calling the event handler for an unregistered route /my/path
535+
result = app(event, {})
536+
537+
# THEN process event correctly as HTTP 404
538+
# AND ensure middlewares are called
539+
assert result["statusCode"] == 404
540+
assert result["body"] == "middleware works"
541+
542+
543+
def test_global_middleware_not_found_preflight():
544+
# GIVEN global middleware is registered
545+
546+
app = ApiGatewayResolver(cors=CORSConfig(), proxy_type=ProxyEventType.APIGatewayProxyEvent)
547+
event = {**API_REST_EVENT, "httpMethod": "OPTIONS"}
548+
549+
def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
550+
# add additional data to Router Context
551+
ret = next_middleware(app)
552+
ret.body = "middleware works"
553+
return ret
554+
555+
app.use(middlewares=[middleware])
556+
557+
@app.get("/this/path/does/not/exist")
558+
def nope() -> dict: ...
559+
560+
# WHEN calling the event handler for an unregistered route /my/path OPTIONS
561+
result = app(event, {})
562+
563+
# THEN process event correctly as HTTP 204 (not 404)
564+
# AND ensure middlewares are called
565+
assert result["statusCode"] == 204
566+
assert result["body"] == "middleware works"

0 commit comments

Comments
 (0)