From d74068f4fc107ce7c62752ca0cdd7d37e1a07da3 Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Thu, 27 Jun 2024 10:16:15 +0000 Subject: [PATCH 1/6] bug(event_handler): fix cors no origin bug --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 7c4d676931e..3b9e36179bb 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -806,7 +806,10 @@ def __init__( def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" extracted_origin_header = extract_origin_header(event.resolved_headers_field) - self.response.headers.update(cors.to_dict(extracted_origin_header)) + if extracted_origin_header is None: + self.response.headers.update(cors.to_dict("*")) + else: + self.response.headers.update(cors.to_dict(extracted_origin_header)) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" From ac9b20e92bf1236072c27c15f49e9eef1c1c7523 Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Thu, 27 Jun 2024 10:24:58 +0000 Subject: [PATCH 2/6] create functional test --- .../required_dependencies/test_api_gateway.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/functional/event_handler/required_dependencies/test_api_gateway.py b/tests/functional/event_handler/required_dependencies/test_api_gateway.py index efd7edf0e5e..a1dff4bd218 100644 --- a/tests/functional/event_handler/required_dependencies/test_api_gateway.py +++ b/tests/functional/event_handler/required_dependencies/test_api_gateway.py @@ -354,6 +354,33 @@ def handler(event, context): assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] +def test_cors_no_origin(): + # GIVEN a function with cors=True + # AND http method set to GET + app = ApiGatewayResolver() + + @app.get("/my/path", cors=True) + def with_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + def handler(event, context): + return app.resolve(event, context) + + # remove origin header from request + del LOAD_GW_EVENT["multiValueHeaders"]["Origin"] + + # WHEN calling the event handler + result = handler(LOAD_GW_EVENT, None) + + # THEN the headers should include cors headers + assert "multiValueHeaders" in result + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.TEXT_HTML] + assert headers["Access-Control-Allow-Origin"] == ["*"] + assert "Access-Control-Allow-Credentials" not in headers + assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))] + + def test_cors_preflight_body_is_empty_not_null(): # GIVEN CORS is configured app = ALBResolver(cors=CORSConfig()) From 6e875e9a1ce4fd4b15f0fcd213a0d79225b837b8 Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Fri, 28 Jun 2024 14:59:50 +0000 Subject: [PATCH 3/6] fix cors --- .../event_handler/api_gateway.py | 18 ++++--- .../required_dependencies/test_api_gateway.py | 48 +++++++++++++++++-- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 3b9e36179bb..d5d7b71fd95 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -188,9 +188,12 @@ def __init__( allow_credentials: bool A boolean value that sets the value of `Access-Control-Allow-Credentials` """ - self._allowed_origins = [allow_origin] + + self.allowed_origins = [allow_origin] + if extra_origins: - self._allowed_origins.extend(extra_origins) + self.allowed_origins.extend(extra_origins) + self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) self.expose_headers = expose_headers or [] self.max_age = max_age @@ -205,7 +208,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: # If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"), # don't add any CORS headers - if origin not in self._allowed_origins and "*" not in self._allowed_origins: + if origin not in self.allowed_origins and "*" not in self.allowed_origins: return {} # The origin matched an allowed origin, so return the CORS headers @@ -218,7 +221,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers) if self.max_age is not None: headers["Access-Control-Max-Age"] = str(self.max_age) - if self.allow_credentials is True: + if origin != "*" and self.allow_credentials is True: headers["Access-Control-Allow-Credentials"] = "true" return headers @@ -806,10 +809,11 @@ def __init__( def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" extracted_origin_header = extract_origin_header(event.resolved_headers_field) - if extracted_origin_header is None: - self.response.headers.update(cors.to_dict("*")) - else: + + if extracted_origin_header in cors.allowed_origins: self.response.headers.update(cors.to_dict(extracted_origin_header)) + if extracted_origin_header is not None and "*" in cors.allowed_origins: + self.response.headers.update(cors.to_dict("*")) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" diff --git a/tests/functional/event_handler/required_dependencies/test_api_gateway.py b/tests/functional/event_handler/required_dependencies/test_api_gateway.py index a1dff4bd218..61d82bc999c 100644 --- a/tests/functional/event_handler/required_dependencies/test_api_gateway.py +++ b/tests/functional/event_handler/required_dependencies/test_api_gateway.py @@ -324,7 +324,7 @@ def handler(event, context): def test_cors(): # GIVEN a function with cors=True # AND http method set to GET - app = ApiGatewayResolver() + app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True)) @app.get("/my/path", cors=True) def with_cors() -> Response: @@ -345,7 +345,7 @@ def handler(event, context): headers = result["multiValueHeaders"] assert headers["Content-Type"] == [content_types.TEXT_HTML] assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"] - assert "Access-Control-Allow-Credentials" not in headers + assert "Access-Control-Allow-Credentials" in headers assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))] # THEN for routes without cors flag return no cors headers @@ -354,7 +354,7 @@ def handler(event, context): assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] -def test_cors_no_origin(): +def test_cors_no_request_origin(): # GIVEN a function with cors=True # AND http method set to GET app = ApiGatewayResolver() @@ -366,8 +366,41 @@ def with_cors() -> Response: def handler(event, context): return app.resolve(event, context) - # remove origin header from request - del LOAD_GW_EVENT["multiValueHeaders"]["Origin"] + event = LOAD_GW_EVENT.copy() + del event["headers"]["Origin"] + del event["multiValueHeaders"]["Origin"] + + # WHEN calling the event handler + result = handler(LOAD_GW_EVENT, None) + + # THEN the headers should include cors headers + assert "multiValueHeaders" in result + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.TEXT_HTML] + assert "Access-Control-Allow-Credentials" not in headers + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] + + +def test_cors_allow_all_request_origins(): + # GIVEN a function with cors=True + # AND http method set to GET + app = ApiGatewayResolver( + cors=CORSConfig( + allow_origin="*", + allow_credentials=True, + ), + ) + + @app.get("/my/path", cors=True) + def with_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + @app.get("/without-cors") + def without_cors() -> Response: + return Response(200, content_types.TEXT_HTML, "test") + + def handler(event, context): + return app.resolve(event, context) # WHEN calling the event handler result = handler(LOAD_GW_EVENT, None) @@ -380,6 +413,11 @@ def handler(event, context): assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))] + # THEN for routes without cors flag return no cors headers + mock_event = {"path": "/my/request", "httpMethod": "GET"} + result = handler(mock_event, None) + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] + def test_cors_preflight_body_is_empty_not_null(): # GIVEN CORS is configured From 7599b05732dc86f74e880e56e44ec85888e4d0cc Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Fri, 28 Jun 2024 15:06:39 +0000 Subject: [PATCH 4/6] fix test structure --- .../required_dependencies/test_api_gateway.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/functional/event_handler/required_dependencies/test_api_gateway.py b/tests/functional/event_handler/required_dependencies/test_api_gateway.py index 61d82bc999c..2a7a5fae0e6 100644 --- a/tests/functional/event_handler/required_dependencies/test_api_gateway.py +++ b/tests/functional/event_handler/required_dependencies/test_api_gateway.py @@ -48,6 +48,7 @@ def read_media(file_name: str) -> bytes: LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +LOAD_GW_EVENT_NO_ORIGIN = load_event("apiGatewayProxyEventNoOrigin.json") LOAD_GW_EVENT_TRAILING_SLASH = load_event("apiGatewayProxyEventPathTrailingSlash.json") @@ -366,12 +367,10 @@ def with_cors() -> Response: def handler(event, context): return app.resolve(event, context) - event = LOAD_GW_EVENT.copy() - del event["headers"]["Origin"] - del event["multiValueHeaders"]["Origin"] + event = LOAD_GW_EVENT_NO_ORIGIN # WHEN calling the event handler - result = handler(LOAD_GW_EVENT, None) + result = handler(event, None) # THEN the headers should include cors headers assert "multiValueHeaders" in result From d71f9adf66917cafbe6a5a529c3f7d3ea348fa9e Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Fri, 28 Jun 2024 15:06:53 +0000 Subject: [PATCH 5/6] add test event --- .../events/apiGatewayProxyEventNoOrigin.json | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/events/apiGatewayProxyEventNoOrigin.json diff --git a/tests/events/apiGatewayProxyEventNoOrigin.json b/tests/events/apiGatewayProxyEventNoOrigin.json new file mode 100644 index 00000000000..666022723ad --- /dev/null +++ b/tests/events/apiGatewayProxyEventNoOrigin.json @@ -0,0 +1,80 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.168.0.1/32", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": false +} \ No newline at end of file From cdbd01e14201d4f7ecfd3163877562f9aff6664b Mon Sep 17 00:00:00 2001 From: Simon Thulbourn Date: Fri, 5 Jul 2024 09:36:55 +0000 Subject: [PATCH 6/6] add allowed_origins method to CORSConfig --- .../event_handler/api_gateway.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 7b95f35cd45..2c829789e8c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -191,10 +193,10 @@ def __init__( A boolean value that sets the value of `Access-Control-Allow-Credentials` """ - self.allowed_origins = [allow_origin] + self._allowed_origins = [allow_origin] if extra_origins: - self.allowed_origins.extend(extra_origins) + self._allowed_origins.extend(extra_origins) self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) self.expose_headers = expose_headers or [] @@ -210,7 +212,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: # If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"), # don't add any CORS headers - if origin not in self.allowed_origins and "*" not in self.allowed_origins: + if origin not in self._allowed_origins and "*" not in self._allowed_origins: return {} # The origin matched an allowed origin, so return the CORS headers @@ -227,6 +229,14 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]: headers["Access-Control-Allow-Credentials"] = "true" return headers + def allowed_origin(self, extracted_origin: str) -> str | None: + if extracted_origin in self._allowed_origins: + return extracted_origin + if extracted_origin is not None and "*" in self._allowed_origins: + return "*" + + return None + @staticmethod def build_allow_methods(methods: Set[str]) -> str: """Build sorted comma delimited methods for Access-Control-Allow-Methods header @@ -812,10 +822,9 @@ def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" extracted_origin_header = extract_origin_header(event.resolved_headers_field) - if extracted_origin_header in cors.allowed_origins: - self.response.headers.update(cors.to_dict(extracted_origin_header)) - if extracted_origin_header is not None and "*" in cors.allowed_origins: - self.response.headers.update(cors.to_dict("*")) + origin = cors.allowed_origin(extracted_origin_header) + if origin is not None: + self.response.headers.update(cors.to_dict(origin)) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""