Skip to content

fix(event_handler): do not skip middleware and exception handlers on 404 error #4492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
150 changes: 113 additions & 37 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
# The origin matched an allowed origin, so return the CORS headers
headers = {
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
"Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers),
}

if self.expose_headers:
Expand All @@ -222,6 +222,23 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
headers["Access-Control-Allow-Credentials"] = "true"
return headers

@staticmethod
def build_allow_methods(methods: Set[str]) -> str:
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header

Parameters
----------
methods : set[str]
Set of HTTP Methods

Returns
-------
set[str]
Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS`

"""
return ",".join(sorted(methods))


class Response(Generic[ResponseT]):
"""Response data class that provides greater control over what is returned from the proxy event"""
Expand Down Expand Up @@ -282,16 +299,16 @@ def __init__(
func: Callable,
cors: bool,
compress: bool,
cache_control: Optional[str],
summary: Optional[str],
description: Optional[str],
responses: Optional[Dict[int, OpenAPIResponse]],
response_description: Optional[str],
tags: Optional[List[str]],
operation_id: Optional[str],
include_in_schema: bool,
security: Optional[List[Dict[str, List[str]]]],
middlewares: Optional[List[Callable[..., Response]]],
cache_control: Optional[str] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
responses: Optional[Dict[int, OpenAPIResponse]] = None,
response_description: Optional[str] = None,
tags: Optional[List[str]] = None,
operation_id: Optional[str] = None,
include_in_schema: bool = True,
security: Optional[List[Dict[str, List[str]]]] = None,
middlewares: Optional[List[Callable[..., Response]]] = None,
):
"""

Expand Down Expand Up @@ -1406,7 +1423,6 @@ def _registered_api_adapter(
"""
route_args: Dict = app.context.get("_route_args", {})
logger.debug(f"Calling API Route Handler: {route_args}")

return app._to_response(next_middleware(**route_args))


Expand Down Expand Up @@ -1967,6 +1983,36 @@ def register_resolver(func: Callable):
def resolve(self, event, context) -> Dict[str, Any]:
"""Resolves the response based on the provide event and decorator routes

## Internals

Request processing chain is triggered by a Route object being called _(`_call_route` -> `__call__`)_:

1. **When a route is matched**
1.1. Exception handlers _(if any exception bubbled up and caught)_
1.2. Global middlewares _(before, and after on the way back)_
1.3. Path level middleware _(before, and after on the way back)_
1.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
1.5. Run actual route
2. **When a route is NOT matched**
2.1. Exception handlers _(if any exception bubbled up and caught)_
2.2. Global middlewares _(before, and after on the way back)_
2.3. Path level middleware _(before, and after on the way back)_
2.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
2.5. Run 404 route handler
3. **When a route is a pre-flight CORS (often not matched)**
3.1. Exception handlers _(if any exception bubbled up and caught)_
3.2. Global middlewares _(before, and after on the way back)_
3.3. Path level middleware _(before, and after on the way back)_
3.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
3.5. Return 204 with appropriate CORS headers
4. **When a route is matched with Data Validation enabled**
4.1. Exception handlers _(if any exception bubbled up and caught)_
4.2. Data Validation middleware _(before, and after on the way back)_
4.3. Global middlewares _(before, and after on the way back)_
4.4. Path level middleware _(before, and after on the way back)_
4.5. Middleware adapter to ensure Response is homogenous (_registered_api_adapter)
4.6. Run actual route

Parameters
----------
event: Dict[str, Any]
Expand Down Expand Up @@ -2090,7 +2136,9 @@ def _resolve(self) -> ResponseBuilder:
method = self.current_event.http_method.upper()
path = self._remove_prefix(self.current_event.path)

for route in self._static_routes + self._dynamic_routes:
registered_routes = self._static_routes + self._dynamic_routes

for route in registered_routes:
if method != route.method:
continue
match_results: Optional[Match] = route.rule.match(path)
Expand All @@ -2102,8 +2150,7 @@ def _resolve(self) -> ResponseBuilder:
route_keys = self._convert_matches_into_route_keys(match_results)
return self._call_route(route, route_keys) # pass fn args

logger.debug(f"No match found for path {path} and method {method}")
return self._not_found(method)
return self._handle_not_found(method=method, path=path)

def _remove_prefix(self, path: str) -> str:
"""Remove the configured prefix from the path"""
Expand Down Expand Up @@ -2141,36 +2188,65 @@ def _path_starts_with(path: str, prefix: str):

return path.startswith(prefix + "/")

def _not_found(self, method: str) -> ResponseBuilder:
def _handle_not_found(self, method: str, path: str) -> ResponseBuilder:
"""Called when no matching route was found and includes support for the cors preflight response"""
headers = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field)
headers.update(self._cors.to_dict(extracted_origin_header))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(
response=Response(status_code=204, content_type=None, headers=headers, body=""),
serializer=self._serializer,
)
logger.debug(f"No match found for path {path} and method {method}")

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer)
def not_found_handler():
"""Route handler for 404s

It handles in the following order:

1. Pre-flight CORS requests (OPTIONS)
2. Detects and calls custom HTTP 404 handler
3. Returns standard 404 along with CORS headers

return self._response_builder_class(
response=Response(
Returns
-------
Response
HTTP 404 response
"""
_headers: Dict[str, Any] = {}

# Pre-flight request? Return immediately to avoid browser error
if self._cors and method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with empty response")
_headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods)

return Response(status_code=204, content_type=None, headers=_headers, body="")

# Customer registered 404 route? Call it.
custom_not_found_handler = self._lookup_exception_handler(NotFoundError)
if custom_not_found_handler:
return custom_not_found_handler(NotFoundError())

# No CORS and no custom 404 fn? Default response
return Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
headers=headers,
headers=_headers,
body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"},
),
serializer=self._serializer,
)

# We create a route to trigger entire request chain (middleware+exception handlers)
route = Route(
rule=self._compile_regex(r".*"),
method=method,
path=path,
func=not_found_handler,
cors=self._cors_enabled,
compress=False,
)

# Add matched Route reference into the Resolver context
self.append_context(_route=route, _path=path)

# Kick-off request chain:
# -> exception_handlers()
# --> middlewares()
# ---> not_found_route()
return self._call_route(route=route, route_arguments={})

def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
Expand Down
58 changes: 58 additions & 0 deletions tests/functional/event_handler/test_api_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
APIGatewayHttpResolver,
ApiGatewayResolver,
APIGatewayRestResolver,
CORSConfig,
ProxyEventType,
Response,
Router,
Expand Down Expand Up @@ -506,3 +507,60 @@ def post_lambda():
result = resolver(event, {})
assert result["statusCode"] == 200
assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_global_middleware_not_found(app: ApiGatewayResolver, event):
# GIVEN global middleware is registered

def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
# add additional data to Router Context
ret = next_middleware(app)
ret.body = "middleware works"
return ret

app.use(middlewares=[middleware])

@app.get("/this/path/does/not/exist")
def nope() -> dict: ...

# WHEN calling the event handler for an unregistered route /my/path
result = app(event, {})

# THEN process event correctly as HTTP 404
# AND ensure middlewares are called
assert result["statusCode"] == 404
assert result["body"] == "middleware works"


def test_global_middleware_not_found_preflight():
# GIVEN global middleware is registered

app = ApiGatewayResolver(cors=CORSConfig(), proxy_type=ProxyEventType.APIGatewayProxyEvent)
event = {**API_REST_EVENT, "httpMethod": "OPTIONS"}

def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
# add additional data to Router Context
ret = next_middleware(app)
ret.body = "middleware works"
return ret

app.use(middlewares=[middleware])

@app.get("/this/path/does/not/exist")
def nope() -> dict: ...

# WHEN calling the event handler for an unregistered route /my/path OPTIONS
result = app(event, {})

# THEN process event correctly as HTTP 204 (not 404)
# AND ensure middlewares are called
assert result["statusCode"] == 204
assert result["body"] == "middleware works"