Skip to content

Commit 80c1afa

Browse files
committed
fix(apigateway): do not skip middleware and exception handlers on 404
1 parent e9d9e79 commit 80c1afa

File tree

2 files changed

+114
-37
lines changed

2 files changed

+114
-37
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+83-37
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
211211
# The origin matched an allowed origin, so return the CORS headers
212212
headers = {
213213
"Access-Control-Allow-Origin": origin,
214-
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
214+
"Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers),
215215
}
216216

217217
if self.expose_headers:
@@ -222,6 +222,23 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
222222
headers["Access-Control-Allow-Credentials"] = "true"
223223
return headers
224224

225+
@staticmethod
226+
def build_allow_methods(methods: Set[str]) -> str:
227+
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
228+
229+
Parameters
230+
----------
231+
methods : set[str]
232+
Set of HTTP Methods
233+
234+
Returns
235+
-------
236+
set[str]
237+
Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS`
238+
239+
"""
240+
return ",".join(sorted(methods))
241+
225242

226243
class Response(Generic[ResponseT]):
227244
"""Response data class that provides greater control over what is returned from the proxy event"""
@@ -282,16 +299,16 @@ def __init__(
282299
func: Callable,
283300
cors: bool,
284301
compress: bool,
285-
cache_control: Optional[str],
286-
summary: Optional[str],
287-
description: Optional[str],
288-
responses: Optional[Dict[int, OpenAPIResponse]],
289-
response_description: Optional[str],
290-
tags: Optional[List[str]],
291-
operation_id: Optional[str],
292-
include_in_schema: bool,
293-
security: Optional[List[Dict[str, List[str]]]],
294-
middlewares: Optional[List[Callable[..., Response]]],
302+
cache_control: Optional[str] = None,
303+
summary: Optional[str] = None,
304+
description: Optional[str] = None,
305+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
306+
response_description: Optional[str] = None,
307+
tags: Optional[List[str]] = None,
308+
operation_id: Optional[str] = None,
309+
include_in_schema: bool = None,
310+
security: Optional[List[Dict[str, List[str]]]] = None,
311+
middlewares: Optional[List[Callable[..., Response]]] = None,
295312
):
296313
"""
297314
@@ -1406,7 +1423,6 @@ def _registered_api_adapter(
14061423
"""
14071424
route_args: Dict = app.context.get("_route_args", {})
14081425
logger.debug(f"Calling API Route Handler: {route_args}")
1409-
14101426
return app._to_response(next_middleware(**route_args))
14111427

14121428

@@ -2090,7 +2106,9 @@ def _resolve(self) -> ResponseBuilder:
20902106
method = self.current_event.http_method.upper()
20912107
path = self._remove_prefix(self.current_event.path)
20922108

2093-
for route in self._static_routes + self._dynamic_routes:
2109+
registered_routes = self._static_routes + self._dynamic_routes
2110+
2111+
for route in registered_routes:
20942112
if method != route.method:
20952113
continue
20962114
match_results: Optional[Match] = route.rule.match(path)
@@ -2102,8 +2120,7 @@ def _resolve(self) -> ResponseBuilder:
21022120
route_keys = self._convert_matches_into_route_keys(match_results)
21032121
return self._call_route(route, route_keys) # pass fn args
21042122

2105-
logger.debug(f"No match found for path {path} and method {method}")
2106-
return self._not_found(method)
2123+
return self._handle_not_found(method=method, path=path)
21072124

21082125
def _remove_prefix(self, path: str) -> str:
21092126
"""Remove the configured prefix from the path"""
@@ -2141,36 +2158,65 @@ def _path_starts_with(path: str, prefix: str):
21412158

21422159
return path.startswith(prefix + "/")
21432160

2144-
def _not_found(self, method: str) -> ResponseBuilder:
2161+
def _handle_not_found(self, method: str, path: str) -> ResponseBuilder:
21452162
"""Called when no matching route was found and includes support for the cors preflight response"""
2146-
headers = {}
2147-
if self._cors:
2148-
logger.debug("CORS is enabled, updating headers.")
2149-
extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field)
2150-
headers.update(self._cors.to_dict(extracted_origin_header))
2151-
2152-
if method == "OPTIONS":
2153-
logger.debug("Pre-flight request detected. Returning CORS with null response")
2154-
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
2155-
return ResponseBuilder(
2156-
response=Response(status_code=204, content_type=None, headers=headers, body=""),
2157-
serializer=self._serializer,
2158-
)
2163+
logger.debug(f"No match found for path {path} and method {method}")
21592164

2160-
handler = self._lookup_exception_handler(NotFoundError)
2161-
if handler:
2162-
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)
2165+
def not_found_handler():
2166+
"""Route handler for 404s
2167+
2168+
It handles in the following order:
2169+
2170+
1. Pre-flight CORS requests (OPTIONS)
2171+
2. Detects and calls custom HTTP 404 handler
2172+
3. Returns standard 404 along with CORS headers
2173+
2174+
Returns
2175+
-------
2176+
Response
2177+
HTTP 404 response
2178+
"""
2179+
_headers: Dict[str, Any] = {}
21632180

2164-
return self._response_builder_class(
2165-
response=Response(
2181+
# Pre-flight request? Return immediately to avoid browser error
2182+
if self._cors and method == "OPTIONS":
2183+
logger.debug("Pre-flight request detected. Returning CORS with empty response")
2184+
_headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods)
2185+
2186+
return Response(status_code=204, content_type=None, headers=_headers, body="")
2187+
2188+
# Customer registered 404 route? Call it.
2189+
custom_not_found_handler = self._lookup_exception_handler(NotFoundError)
2190+
if custom_not_found_handler:
2191+
return custom_not_found_handler(NotFoundError())
2192+
2193+
# No CORS and no custom 404 fn? Default response
2194+
return Response(
21662195
status_code=HTTPStatus.NOT_FOUND.value,
21672196
content_type=content_types.APPLICATION_JSON,
2168-
headers=headers,
2197+
headers=_headers,
21692198
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
2170-
),
2171-
serializer=self._serializer,
2199+
)
2200+
2201+
# We create a route to trigger entire request chain (middleware+exception handlers)
2202+
route = Route(
2203+
rule=self._compile_regex(r".*"),
2204+
method=method,
2205+
path=path,
2206+
func=not_found_handler,
2207+
cors=self._cors_enabled,
2208+
compress=False,
21722209
)
21732210

2211+
# Add matched Route reference into the Resolver context
2212+
self.append_context(_route=route, _path=path)
2213+
2214+
# Kick-off request chain:
2215+
# -> exception_handlers()
2216+
# --> middlewares()
2217+
# ---> not_found_route()
2218+
return self._call_route(route=route, route_arguments={})
2219+
21742220
def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
21752221
"""Actually call the matching route with any provided keyword arguments."""
21762222
try:

tests/functional/event_handler/test_api_middlewares.py

+31
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,34 @@ def post_lambda():
506506
result = resolver(event, {})
507507
assert result["statusCode"] == 200
508508
assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501
509+
510+
511+
@pytest.mark.parametrize(
512+
"app, event",
513+
[
514+
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
515+
(APIGatewayRestResolver(), API_REST_EVENT),
516+
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
517+
],
518+
)
519+
def test_global_middleware_not_found(app: ApiGatewayResolver, event):
520+
# GIVEN global middleware is registered
521+
522+
def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
523+
# add additional data to Router Context
524+
ret = next_middleware(app)
525+
ret.body = "middleware works"
526+
return ret
527+
528+
app.use(middlewares=[middleware])
529+
530+
@app.get("/this/path/does/not/exist")
531+
def nope() -> dict: ...
532+
533+
# WHEN calling the event handler for an unregistered route /my/path
534+
result = app(event, {})
535+
536+
# THEN process event correctly as HTTP 404
537+
# AND ensure middlewares are called
538+
assert result["statusCode"] == 404
539+
assert result["body"] == "middleware works"

0 commit comments

Comments
 (0)