From f4bd418605fe6db7ddd8d1aeb3d2e8ab056be015 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 5 May 2021 15:47:38 -0700 Subject: [PATCH] feat(event-handle): allow for cors=None setting At the route level we should allow for cors=None which means that cors are enabled for this route when configured globally --- .../event_handler/api_gateway.py | 29 +++++++++++-------- .../event_handler/test_api_gateway.py | 18 +++++++++--- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a99394b10f7..2b1e1fc0900 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -55,11 +55,11 @@ def with_cors(): ) app = ApiGatewayResolver(cors=cors_config) - @app.get("/my/path", cors=True) + @app.get("/my/path") def with_cors(): return {"message": "Foo"} - @app.get("/another-one") + @app.get("/another-one", cors=False) def without_cors(): return {"message": "Foo"} """ @@ -249,9 +249,10 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors: self._proxy_type = proxy_type self._routes: List[Route] = [] self._cors = cors + self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} - def get(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def get(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Get route decorator with GET `method` Examples @@ -276,7 +277,7 @@ def lambda_handler(event, context): """ return self.route(rule, "GET", cors, compress, cache_control) - def post(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def post(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Post route decorator with POST `method` Examples @@ -302,7 +303,7 @@ def lambda_handler(event, context): """ return self.route(rule, "POST", cors, compress, cache_control) - def put(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def put(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Put route decorator with PUT `method` Examples @@ -317,7 +318,7 @@ def put(self, rule: str, cors: bool = True, compress: bool = False, cache_contro app = ApiGatewayResolver() @app.put("/put-call") - def simple_post(): + def simple_put(): put_data: dict = app.current_event.json_body return {"message": put_data["value"]} @@ -328,7 +329,7 @@ def lambda_handler(event, context): """ return self.route(rule, "PUT", cors, compress, cache_control) - def delete(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def delete(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Delete route decorator with DELETE `method` Examples @@ -353,7 +354,7 @@ def lambda_handler(event, context): """ return self.route(rule, "DELETE", cors, compress, cache_control) - def patch(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def patch(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Patch route decorator with PATCH `method` Examples @@ -381,13 +382,17 @@ def lambda_handler(event, context): """ return self.route(rule, "PATCH", cors, compress, cache_control) - def route(self, rule: str, method: str, cors: bool = True, compress: bool = False, cache_control: str = None): + def route(self, rule: str, method: str, cors: bool = None, compress: bool = False, cache_control: str = None): """Route decorator includes parameter `method`""" def register_resolver(func: Callable): logger.debug(f"Adding route using rule {rule} and method {method.upper()}") - self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) - if cors: + if cors is None: + cors_enabled = self._cors_enabled + else: + cors_enabled = cors + self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control)) + if cors_enabled: logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS") self._cors_methods.add(method.upper()) return func @@ -454,7 +459,7 @@ def _not_found(self, method: str) -> ResponseBuilder: logger.debug("CORS is enabled, updating headers.") headers.update(self._cors.to_dict()) - if method == "OPTIONS": # Pre-flight + if method == "OPTIONS": logger.debug("Pre-flight request detected. Returning CORS with null response") headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 354a89305e1..caaaeb1b97b 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -182,6 +182,10 @@ def test_cors(): def with_cors() -> Response: return Response(200, TEXT_HTML, "test") + @app.get("/without-cors") + def without_cors() -> Response: + return Response(200, TEXT_HTML, "test") + def handler(event, context): return app.resolve(event, context) @@ -196,6 +200,11 @@ def handler(event, context): assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) + # THEN for routes without cors flag return no cors headers + mock_event = {"path": "/my/request", "httpMethod": "GET"} + result = handler(mock_event, None) + assert "Access-Control-Allow-Origin" not in result["headers"] + def test_compress(): # GIVEN a function that has compress=True @@ -359,7 +368,7 @@ def test_custom_cors_config(): app = ApiGatewayResolver(cors=cors_config) event = {"path": "/cors", "httpMethod": "GET"} - @app.get("/cors", cors=True) + @app.get("/cors") def get_with_cors(): return {} @@ -370,7 +379,7 @@ def another_one(): # WHEN calling the event handler result = app(event, None) - # THEN return the custom cors headers + # THEN routes by default return the custom cors headers assert "headers" in result headers = result["headers"] assert headers["Content-Type"] == APPLICATION_JSON @@ -385,6 +394,7 @@ def another_one(): # AND custom cors was set on the app assert isinstance(app._cors, CORSConfig) assert app._cors is cors_config + # AND routes without cors don't include "Access-Control" headers event = {"path": "/another-one", "httpMethod": "GET"} result = app(event, None) @@ -426,11 +436,11 @@ def test_cors_preflight(): # AND cors is enabled app = ApiGatewayResolver(cors=CORSConfig()) - @app.get("/foo", cors=True) + @app.get("/foo") def foo_cors(): ... - @app.route(method="delete", rule="/foo", cors=True) + @app.route(method="delete", rule="/foo") def foo_delete_cors(): ...