diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f82532f1f71..7c4d676931e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -211,7 +211,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: # The origin matched an allowed origin, so return the CORS headers headers = { "Access-Control-Allow-Origin": origin, - "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), + "Access-Control-Allow-Headers": CORSConfig.build_allow_methods(self.allow_headers), } if self.expose_headers: @@ -222,6 +222,23 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Allow-Credentials"] = "true" return headers + @staticmethod + def build_allow_methods(methods: Set[str]) -> str: + """Build sorted comma delimited methods for Access-Control-Allow-Methods header + + Parameters + ---------- + methods : set[str] + Set of HTTP Methods + + Returns + ------- + set[str] + Formatted string with all HTTP Methods allowed for CORS e.g., `GET, OPTIONS` + + """ + return ",".join(sorted(methods)) + class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" @@ -282,16 +299,16 @@ def __init__( func: Callable, cors: bool, compress: bool, - cache_control: Optional[str], - summary: Optional[str], - description: Optional[str], - responses: Optional[Dict[int, OpenAPIResponse]], - response_description: Optional[str], - tags: Optional[List[str]], - operation_id: Optional[str], - include_in_schema: bool, - security: Optional[List[Dict[str, List[str]]]], - middlewares: Optional[List[Callable[..., Response]]], + cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, + response_description: Optional[str] = None, + tags: Optional[List[str]] = None, + operation_id: Optional[str] = None, + include_in_schema: bool = True, + security: Optional[List[Dict[str, List[str]]]] = None, + middlewares: Optional[List[Callable[..., Response]]] = None, ): """ @@ -1406,7 +1423,6 @@ def _registered_api_adapter( """ route_args: Dict = app.context.get("_route_args", {}) logger.debug(f"Calling API Route Handler: {route_args}") - return app._to_response(next_middleware(**route_args)) @@ -1967,6 +1983,36 @@ def register_resolver(func: Callable): def resolve(self, event, context) -> Dict[str, Any]: """Resolves the response based on the provide event and decorator routes + ## Internals + + Request processing chain is triggered by a Route object being called _(`_call_route` -> `__call__`)_: + + 1. **When a route is matched** + 1.1. Exception handlers _(if any exception bubbled up and caught)_ + 1.2. Global middlewares _(before, and after on the way back)_ + 1.3. Path level middleware _(before, and after on the way back)_ + 1.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter) + 1.5. Run actual route + 2. **When a route is NOT matched** + 2.1. Exception handlers _(if any exception bubbled up and caught)_ + 2.2. Global middlewares _(before, and after on the way back)_ + 2.3. Path level middleware _(before, and after on the way back)_ + 2.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter) + 2.5. Run 404 route handler + 3. **When a route is a pre-flight CORS (often not matched)** + 3.1. Exception handlers _(if any exception bubbled up and caught)_ + 3.2. Global middlewares _(before, and after on the way back)_ + 3.3. Path level middleware _(before, and after on the way back)_ + 3.4. Middleware adapter to ensure Response is homogenous (_registered_api_adapter) + 3.5. Return 204 with appropriate CORS headers + 4. **When a route is matched with Data Validation enabled** + 4.1. Exception handlers _(if any exception bubbled up and caught)_ + 4.2. Data Validation middleware _(before, and after on the way back)_ + 4.3. Global middlewares _(before, and after on the way back)_ + 4.4. Path level middleware _(before, and after on the way back)_ + 4.5. Middleware adapter to ensure Response is homogenous (_registered_api_adapter) + 4.6. Run actual route + Parameters ---------- event: Dict[str, Any] @@ -2090,7 +2136,9 @@ def _resolve(self) -> ResponseBuilder: method = self.current_event.http_method.upper() path = self._remove_prefix(self.current_event.path) - for route in self._static_routes + self._dynamic_routes: + registered_routes = self._static_routes + self._dynamic_routes + + for route in registered_routes: if method != route.method: continue match_results: Optional[Match] = route.rule.match(path) @@ -2102,8 +2150,7 @@ def _resolve(self) -> ResponseBuilder: route_keys = self._convert_matches_into_route_keys(match_results) return self._call_route(route, route_keys) # pass fn args - logger.debug(f"No match found for path {path} and method {method}") - return self._not_found(method) + return self._handle_not_found(method=method, path=path) def _remove_prefix(self, path: str) -> str: """Remove the configured prefix from the path""" @@ -2141,36 +2188,65 @@ def _path_starts_with(path: str, prefix: str): return path.startswith(prefix + "/") - def _not_found(self, method: str) -> ResponseBuilder: + def _handle_not_found(self, method: str, path: str) -> ResponseBuilder: """Called when no matching route was found and includes support for the cors preflight response""" - headers = {} - if self._cors: - logger.debug("CORS is enabled, updating headers.") - extracted_origin_header = extract_origin_header(self.current_event.resolved_headers_field) - headers.update(self._cors.to_dict(extracted_origin_header)) - - if method == "OPTIONS": - logger.debug("Pre-flight request detected. Returning CORS with null response") - headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) - return ResponseBuilder( - response=Response(status_code=204, content_type=None, headers=headers, body=""), - serializer=self._serializer, - ) + logger.debug(f"No match found for path {path} and method {method}") - handler = self._lookup_exception_handler(NotFoundError) - if handler: - return self._response_builder_class(response=handler(NotFoundError()), serializer=self._serializer) + def not_found_handler(): + """Route handler for 404s + + It handles in the following order: + + 1. Pre-flight CORS requests (OPTIONS) + 2. Detects and calls custom HTTP 404 handler + 3. Returns standard 404 along with CORS headers - return self._response_builder_class( - response=Response( + Returns + ------- + Response + HTTP 404 response + """ + _headers: Dict[str, Any] = {} + + # Pre-flight request? Return immediately to avoid browser error + if self._cors and method == "OPTIONS": + logger.debug("Pre-flight request detected. Returning CORS with empty response") + _headers["Access-Control-Allow-Methods"] = CORSConfig.build_allow_methods(self._cors_methods) + + return Response(status_code=204, content_type=None, headers=_headers, body="") + + # Customer registered 404 route? Call it. + custom_not_found_handler = self._lookup_exception_handler(NotFoundError) + if custom_not_found_handler: + return custom_not_found_handler(NotFoundError()) + + # No CORS and no custom 404 fn? Default response + return Response( status_code=HTTPStatus.NOT_FOUND.value, content_type=content_types.APPLICATION_JSON, - headers=headers, + headers=_headers, body={"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}, - ), - serializer=self._serializer, + ) + + # We create a route to trigger entire request chain (middleware+exception handlers) + route = Route( + rule=self._compile_regex(r".*"), + method=method, + path=path, + func=not_found_handler, + cors=self._cors_enabled, + compress=False, ) + # Add matched Route reference into the Resolver context + self.append_context(_route=route, _path=path) + + # Kick-off request chain: + # -> exception_handlers() + # --> middlewares() + # ---> not_found_route() + return self._call_route(route=route, route_arguments={}) + def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" try: diff --git a/tests/functional/event_handler/test_api_middlewares.py b/tests/functional/event_handler/test_api_middlewares.py index 58bec259072..8d2cf5972bc 100644 --- a/tests/functional/event_handler/test_api_middlewares.py +++ b/tests/functional/event_handler/test_api_middlewares.py @@ -7,6 +7,7 @@ APIGatewayHttpResolver, ApiGatewayResolver, APIGatewayRestResolver, + CORSConfig, ProxyEventType, Response, Router, @@ -506,3 +507,60 @@ def post_lambda(): result = resolver(event, {}) assert result["statusCode"] == 200 assert result["multiValueHeaders"]["X-Correlation-Id"][0] == resolver.current_event.request_context.request_id # type: ignore[attr-defined] # noqa: E501 + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_global_middleware_not_found(app: ApiGatewayResolver, event): + # GIVEN global middleware is registered + + def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add additional data to Router Context + ret = next_middleware(app) + ret.body = "middleware works" + return ret + + app.use(middlewares=[middleware]) + + @app.get("/this/path/does/not/exist") + def nope() -> dict: ... + + # WHEN calling the event handler for an unregistered route /my/path + result = app(event, {}) + + # THEN process event correctly as HTTP 404 + # AND ensure middlewares are called + assert result["statusCode"] == 404 + assert result["body"] == "middleware works" + + +def test_global_middleware_not_found_preflight(): + # GIVEN global middleware is registered + + app = ApiGatewayResolver(cors=CORSConfig(), proxy_type=ProxyEventType.APIGatewayProxyEvent) + event = {**API_REST_EVENT, "httpMethod": "OPTIONS"} + + def middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware): + # add additional data to Router Context + ret = next_middleware(app) + ret.body = "middleware works" + return ret + + app.use(middlewares=[middleware]) + + @app.get("/this/path/does/not/exist") + def nope() -> dict: ... + + # WHEN calling the event handler for an unregistered route /my/path OPTIONS + result = app(event, {}) + + # THEN process event correctly as HTTP 204 (not 404) + # AND ensure middlewares are called + assert result["statusCode"] == 204 + assert result["body"] == "middleware works"