diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 78993f92c5e..05fbc1c06c1 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -84,6 +84,7 @@ def with_cors(): cors_config = CORSConfig( allow_origin="https://wwww.example.com/", + extra_origins=["https://dev.example.com/"], expose_headers=["x-exposed-response-header"], allow_headers=["x-custom-request-header"], max_age=100, @@ -106,6 +107,7 @@ def without_cors(): def __init__( self, allow_origin: str = "*", + extra_origins: Optional[List[str]] = None, allow_headers: Optional[List[str]] = None, expose_headers: Optional[List[str]] = None, max_age: Optional[int] = None, @@ -117,6 +119,8 @@ def __init__( allow_origin: str The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should only be used during development. + extra_origins: Optional[List[str]] + The list of additional allowed origins. allow_headers: Optional[List[str]] The list of additional allowed headers. This list is added to list of built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`, @@ -128,16 +132,29 @@ def __init__( allow_credentials: bool A boolean value that sets the value of `Access-Control-Allow-Credentials` """ - self.allow_origin = allow_origin + self._allowed_origins = [allow_origin] + if extra_origins: + self._allowed_origins.extend(extra_origins) self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) self.expose_headers = expose_headers or [] self.max_age = max_age self.allow_credentials = allow_credentials - def to_dict(self) -> Dict[str, str]: + def to_dict(self, origin: Optional[str]) -> Dict[str, str]: """Builds the configured Access-Control http headers""" + + # If there's no Origin, don't add any CORS headers + if not origin: + return {} + + # If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"), + # don't add any CORS headers + if origin not in self._allowed_origins and "*" not in self._allowed_origins: + return {} + + # The origin matched an allowed origin, so return the CORS headers headers: Dict[str, str] = { - "Access-Control-Allow-Origin": self.allow_origin, + "Access-Control-Allow-Origin": origin, "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), } @@ -207,9 +224,9 @@ def __init__(self, response: Response, route: Optional[Route] = None): self.response = response self.route = route - def _add_cors(self, cors: CORSConfig): + def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" - self.response.headers.update(cors.to_dict()) + self.response.headers.update(cors.to_dict(event.get_header_value("Origin"))) def _add_cache_control(self, cache_control: str): """Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used.""" @@ -230,7 +247,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): if self.route is None: return if self.route.cors: - self._add_cors(cors or CORSConfig()) + self._add_cors(event, cors or CORSConfig()) if self.route.cache_control: self._add_cache_control(self.route.cache_control) if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""): @@ -644,7 +661,7 @@ def _not_found(self, method: str) -> ResponseBuilder: headers: Dict[str, Union[str, List[str]]] = {} if self._cors: logger.debug("CORS is enabled, updating headers.") - headers.update(self._cors.to_dict()) + headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin"))) if method == "OPTIONS": logger.debug("Pre-flight request detected. Returning CORS with null response") diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index d1ce8f90a07..ce02a4c11b0 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -113,7 +113,7 @@ def get_header_value( class BaseProxyEvent(DictWrapper): @property def headers(self) -> Dict[str, str]: - return self["headers"] + return self.get("headers") or {} @property def query_string_parameters(self) -> Optional[Dict[str, str]]: diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 9348575535a..3dc6401ea8d 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -280,7 +280,8 @@ To address this API Gateway behavior, we use `strip_prefixes` parameter to accou You can configure CORS at the `APIGatewayRestResolver` constructor via `cors` parameter using the `CORSConfig` class. -This will ensure that CORS headers are always returned as part of the response when your functions match the path invoked. +This will ensure that CORS headers are returned as part of the response when your functions match the path invoked and the `Origin` +matches one of the allowed values. ???+ tip Optionally disable CORS on a per path basis with `cors=False` parameter. @@ -297,6 +298,18 @@ This will ensure that CORS headers are always returned as part of the response w --8<-- "examples/event_handler_rest/src/setting_cors_output.json" ``` +=== "setting_cors_extra_origins.py" + + ```python hl_lines="5 11-12 34" + --8<-- "examples/event_handler_rest/src/setting_cors_extra_origins.py" + ``` + +=== "setting_cors_extra_origins_output.json" + + ```json + --8<-- "examples/event_handler_rest/src/setting_cors_extra_origins_output.json" + ``` + #### Pre-flight Pre-flight (OPTIONS) calls are typically handled at the API Gateway or Lambda Function URL level as per [our sample infrastructure](#required-resources), no Lambda integration is necessary. However, ALB expects you to handle pre-flight requests. @@ -310,9 +323,13 @@ For convenience, these are the default values when using `CORSConfig` to enable ???+ warning Always configure `allow_origin` when using in production. +???+ tip "Multiple origins?" + If you need to allow multiple origins, pass the additional origins using the `extra_origins` key. + | Key | Value | Note | -| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +|----------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it | +| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` | | **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience | | **[expose_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers){target="_blank"}**: `List[str]` | `[]` | Any additional header beyond the [safe listed by CORS specification](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header){target="_blank"}. | | **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway | @@ -331,7 +348,7 @@ You can use the `Response` class to have full control over the response. For exa === "fine_grained_responses.py" - ```python hl_lines="9 28-32" + ```python hl_lines="9 29-35" --8<-- "examples/event_handler_rest/src/fine_grained_responses.py" ``` diff --git a/examples/event_handler_rest/src/setting_cors.py b/examples/event_handler_rest/src/setting_cors.py index 101e013e552..14470cf9d1e 100644 --- a/examples/event_handler_rest/src/setting_cors.py +++ b/examples/event_handler_rest/src/setting_cors.py @@ -8,7 +8,8 @@ tracer = Tracer() logger = Logger() -cors_config = CORSConfig(allow_origin="https://example.com", max_age=300) +# CORS will match when Origin is only https://www.example.com +cors_config = CORSConfig(allow_origin="https://www.example.com", max_age=300) app = APIGatewayRestResolver(cors=cors_config) diff --git a/examples/event_handler_rest/src/setting_cors_extra_origins.py b/examples/event_handler_rest/src/setting_cors_extra_origins.py new file mode 100644 index 00000000000..3afb2794ec6 --- /dev/null +++ b/examples/event_handler_rest/src/setting_cors_extra_origins.py @@ -0,0 +1,45 @@ +import requests +from requests import Response + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +# CORS will match when Origin is https://www.example.com OR https://dev.example.com +cors_config = CORSConfig(allow_origin="https://www.example.com", extra_origins=["https://dev.example.com"], max_age=300) +app = APIGatewayRestResolver(cors=cors_config) + + +@app.get("/todos") +@tracer.capture_method +def get_todos(): + todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return {"todos": todos.json()[:10]} + + +@app.get("/todos/") +@tracer.capture_method +def get_todo_by_id(todo_id: str): # value come as str + todos: Response = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}") + todos.raise_for_status() + + return {"todos": todos.json()} + + +@app.get("/healthcheck", cors=False) # optionally removes CORS for a given route +@tracer.capture_method +def am_i_alive(): + return {"am_i_alive": "yes"} + + +# You can continue to use other utilities just as before +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/setting_cors_extra_origins_output.json b/examples/event_handler_rest/src/setting_cors_extra_origins_output.json new file mode 100644 index 00000000000..c123435338c --- /dev/null +++ b/examples/event_handler_rest/src/setting_cors_extra_origins_output.json @@ -0,0 +1,10 @@ +{ + "statusCode": 200, + "multiValueHeaders": { + "Content-Type": ["application/json"], + "Access-Control-Allow-Origin": ["https://www.example.com","https://dev.example.com"], + "Access-Control-Allow-Headers": ["Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key"] + }, + "body": "{\"todos\":[{\"userId\":1,\"id\":1,\"title\":\"delectus aut autem\",\"completed\":false},{\"userId\":1,\"id\":2,\"title\":\"quis ut nam facilis et officia qui\",\"completed\":false},{\"userId\":1,\"id\":3,\"title\":\"fugiat veniam minus\",\"completed\":false},{\"userId\":1,\"id\":4,\"title\":\"et porro tempora\",\"completed\":true},{\"userId\":1,\"id\":5,\"title\":\"laboriosam mollitia et enim quasi adipisci quia provident illum\",\"completed\":false},{\"userId\":1,\"id\":6,\"title\":\"qui ullam ratione quibusdam voluptatem quia omnis\",\"completed\":false},{\"userId\":1,\"id\":7,\"title\":\"illo expedita consequatur quia in\",\"completed\":false},{\"userId\":1,\"id\":8,\"title\":\"quo adipisci enim quam ut ab\",\"completed\":true},{\"userId\":1,\"id\":9,\"title\":\"molestiae perspiciatis ipsa\",\"completed\":false},{\"userId\":1,\"id\":10,\"title\":\"illo est ratione doloremque quia maiores aut\",\"completed\":true}]}", + "isBase64Encoded": false +} diff --git a/tests/e2e/event_handler/handlers/alb_handler.py b/tests/e2e/event_handler/handlers/alb_handler.py index 26746284aee..ef1af1792ac 100644 --- a/tests/e2e/event_handler/handlers/alb_handler.py +++ b/tests/e2e/event_handler/handlers/alb_handler.py @@ -1,6 +1,12 @@ -from aws_lambda_powertools.event_handler import ALBResolver, Response, content_types - -app = ALBResolver() +from aws_lambda_powertools.event_handler import ( + ALBResolver, + CORSConfig, + Response, + content_types, +) + +cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"]) +app = ALBResolver(cors=cors_config) # The reason we use post is that whoever is writing tests can easily assert on the # content being sent (body, headers, cookies, content-type) to reduce cognitive load. diff --git a/tests/e2e/event_handler/handlers/api_gateway_http_handler.py b/tests/e2e/event_handler/handlers/api_gateway_http_handler.py index 1012af7b3fb..876d78ef67b 100644 --- a/tests/e2e/event_handler/handlers/api_gateway_http_handler.py +++ b/tests/e2e/event_handler/handlers/api_gateway_http_handler.py @@ -1,10 +1,12 @@ from aws_lambda_powertools.event_handler import ( APIGatewayHttpResolver, + CORSConfig, Response, content_types, ) -app = APIGatewayHttpResolver() +cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"]) +app = APIGatewayHttpResolver(cors=cors_config) # The reason we use post is that whoever is writing tests can easily assert on the # content being sent (body, headers, cookies, content-type) to reduce cognitive load. diff --git a/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py b/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py index d52e2728cab..d09bf6b82c9 100644 --- a/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py +++ b/tests/e2e/event_handler/handlers/api_gateway_rest_handler.py @@ -1,10 +1,12 @@ from aws_lambda_powertools.event_handler import ( APIGatewayRestResolver, + CORSConfig, Response, content_types, ) -app = APIGatewayRestResolver() +cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"]) +app = APIGatewayRestResolver(cors=cors_config) # The reason we use post is that whoever is writing tests can easily assert on the # content being sent (body, headers, cookies, content-type) to reduce cognitive load. diff --git a/tests/e2e/event_handler/handlers/lambda_function_url_handler.py b/tests/e2e/event_handler/handlers/lambda_function_url_handler.py index f90037afc75..e47035a971d 100644 --- a/tests/e2e/event_handler/handlers/lambda_function_url_handler.py +++ b/tests/e2e/event_handler/handlers/lambda_function_url_handler.py @@ -1,10 +1,12 @@ from aws_lambda_powertools.event_handler import ( + CORSConfig, LambdaFunctionUrlResolver, Response, content_types, ) -app = LambdaFunctionUrlResolver() +cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"]) +app = LambdaFunctionUrlResolver(cors=cors_config) # The reason we use post is that whoever is writing tests can easily assert on the # content being sent (body, headers, cookies, content-type) to reduce cognitive load. diff --git a/tests/e2e/event_handler/test_cors.py b/tests/e2e/event_handler/test_cors.py new file mode 100644 index 00000000000..5d2f140715f --- /dev/null +++ b/tests/e2e/event_handler/test_cors.py @@ -0,0 +1,252 @@ +import pytest +from requests import Request + +from tests.e2e.utils import data_fetcher +from tests.e2e.utils.auth import build_iam_auth + + +@pytest.fixture +def alb_basic_listener_endpoint(infrastructure: dict) -> str: + dns_name = infrastructure.get("ALBDnsName") + port = infrastructure.get("ALBBasicListenerPort", "") + return f"http://{dns_name}:{port}" + + +@pytest.fixture +def apigw_http_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayHTTPUrl", "") + + +@pytest.fixture +def apigw_rest_endpoint(infrastructure: dict) -> str: + return infrastructure.get("APIGatewayRestUrl", "") + + +@pytest.fixture +def lambda_function_url_endpoint(infrastructure: dict) -> str: + return infrastructure.get("LambdaFunctionUrl", "") + + +@pytest.mark.xdist_group(name="event_handler") +def test_alb_cors_with_correct_origin(alb_basic_listener_endpoint): + # GIVEN + url = f"{alb_basic_listener_endpoint}/todos" + headers = {"Origin": "https://www.example.org"} + + # WHEN + response = data_fetcher.get_http_response(Request(method="POST", url=url, headers=headers, json={})) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://www.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_alb_cors_with_correct_alternative_origin(alb_basic_listener_endpoint): + # GIVEN + url = f"{alb_basic_listener_endpoint}/todos" + headers = {"Origin": "https://dev.example.org"} + + # WHEN + response = data_fetcher.get_http_response(Request(method="POST", url=url, headers=headers, json={})) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://dev.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_alb_cors_with_unknown_origin(alb_basic_listener_endpoint): + # GIVEN + url = f"{alb_basic_listener_endpoint}/todos" + headers = {"Origin": "https://www.google.com"} + + # WHEN + response = data_fetcher.get_http_response(Request(method="POST", url=url, headers=headers, json={})) + + # THEN response does NOT have CORS headers + assert "Access-Control-Allow-Origin" not in response.headers + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_http_cors_with_correct_origin(apigw_http_endpoint): + # GIVEN + url = f"{apigw_http_endpoint}todos" + headers = {"Origin": "https://www.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="execute-api"), + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://www.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_http_cors_with_correct_alternative_origin(apigw_http_endpoint): + # GIVEN + url = f"{apigw_http_endpoint}todos" + headers = {"Origin": "https://dev.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="execute-api"), + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://dev.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_http_cors_with_unknown_origin(apigw_http_endpoint): + # GIVEN + url = f"{apigw_http_endpoint}todos" + headers = {"Origin": "https://www.google.com"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="execute-api"), + ) + ) + + # THEN response does NOT have CORS headers + assert "Access-Control-Allow-Origin" not in response.headers + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_rest_cors_with_correct_origin(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}todos" + headers = {"Origin": "https://www.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://www.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_rest_cors_with_correct_alternative_origin(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}todos" + headers = {"Origin": "https://dev.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://dev.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_api_gateway_rest_cors_with_unknown_origin(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}todos" + headers = {"Origin": "https://www.google.com"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + ) + ) + + # THEN response does NOT have CORS headers + assert "Access-Control-Allow-Origin" not in response.headers + + +@pytest.mark.xdist_group(name="event_handler") +def test_lambda_function_url_cors_with_correct_origin(lambda_function_url_endpoint): + # GIVEN + url = f"{lambda_function_url_endpoint}todos" + headers = {"Origin": "https://www.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="lambda"), + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://www.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_lambda_function_url_cors_with_correct_alternative_origin(lambda_function_url_endpoint): + # GIVEN + url = f"{lambda_function_url_endpoint}todos" + headers = {"Origin": "https://dev.example.org"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="lambda"), + ) + ) + + # THEN response has CORS headers + assert response.headers["Access-Control-Allow-Origin"] == "https://dev.example.org" + + +@pytest.mark.xdist_group(name="event_handler") +def test_lambda_function_url_cors_with_unknown_origin(lambda_function_url_endpoint): + # GIVEN + url = f"{lambda_function_url_endpoint}todos" + headers = {"Origin": "https://www.google.com"} + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="POST", + url=url, + headers=headers, + json={}, + auth=build_iam_auth(url=url, aws_service="lambda"), + ) + ) + + # THEN response does NOT have CORS headers + assert "Access-Control-Allow-Origin" not in response.headers diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json index 11833d21f2c..da814c91100 100644 --- a/tests/events/apiGatewayProxyEvent.json +++ b/tests/events/apiGatewayProxyEvent.json @@ -5,7 +5,8 @@ "httpMethod": "GET", "headers": { "Header1": "value1", - "Header2": "value2" + "Header2": "value2", + "Origin": "https://aws.amazon.com" }, "multiValueHeaders": { "Header1": [ diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index ad9f834dbb2..6faad88d7f1 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -343,7 +343,7 @@ def handler(event, context): assert "multiValueHeaders" in result headers = result["multiValueHeaders"] assert headers["Content-Type"] == [content_types.TEXT_HTML] - assert headers["Access-Control-Allow-Origin"] == ["*"] + assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"] assert "Access-Control-Allow-Credentials" not in headers assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))] @@ -533,6 +533,34 @@ def rest_func() -> Response: assert result["body"] == "Not found" +def test_cors_multi_origin(): + # GIVEN a custom cors configuration with multiple origins + cors_config = CORSConfig(allow_origin="https://origin1", extra_origins=["https://origin2", "https://origin3"]) + app = ApiGatewayResolver(cors=cors_config) + + @app.get("/cors") + def get_with_cors(): + return {} + + # WHEN calling the event handler with the correct Origin + event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://origin3"}} + result = app(event, None) + + # THEN routes by default return the custom cors headers + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.APPLICATION_JSON] + assert headers["Access-Control-Allow-Origin"] == ["https://origin3"] + + # WHEN calling the event handler with the wrong origin + event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://wrong.origin"}} + result = app(event, None) + + # THEN routes by default return the custom cors headers + headers = result["multiValueHeaders"] + assert headers["Content-Type"] == [content_types.APPLICATION_JSON] + assert "Access-Control-Allow-Origin" not in headers + + def test_custom_cors_config(): # GIVEN a custom cors configuration allow_header = ["foo2"] @@ -544,7 +572,7 @@ def test_custom_cors_config(): allow_credentials=True, ) app = ApiGatewayResolver(cors=cors_config) - event = {"path": "/cors", "httpMethod": "GET"} + event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://foo1"}} @app.get("/cors") def get_with_cors(): @@ -561,7 +589,7 @@ def another_one(): assert "multiValueHeaders" in result headers = result["multiValueHeaders"] assert headers["Content-Type"] == [content_types.APPLICATION_JSON] - assert headers["Access-Control-Allow-Origin"] == [cors_config.allow_origin] + assert headers["Access-Control-Allow-Origin"] == ["https://foo1"] expected_allows_headers = [",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))] assert headers["Access-Control-Allow-Headers"] == expected_allows_headers assert headers["Access-Control-Expose-Headers"] == [",".join(cors_config.expose_headers)] @@ -604,9 +632,9 @@ def test_no_matches_with_cors(): result = app({"path": "/another-one", "httpMethod": "GET"}, None) # THEN return a 404 - # AND cors headers are returned + # AND cors headers are NOT returned (because no Origin header was passed in) assert result["statusCode"] == 404 - assert "Access-Control-Allow-Origin" in result["multiValueHeaders"] + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] assert "Not found" in result["body"] @@ -628,7 +656,7 @@ def post_no_cors(): ... # WHEN calling the handler - result = app({"path": "/foo", "httpMethod": "OPTIONS"}, None) + result = app({"path": "/foo", "httpMethod": "OPTIONS", "headers": {"Origin": "http://example.org"}}, None) # THEN return no content # AND include Access-Control-Allow-Methods of the cors methods used @@ -659,8 +687,11 @@ def custom_preflight(): def custom_method(): ... + # AND the request includes an origin + headers = {"Origin": "https://example.org"} + # WHEN calling the handler - result = app({"path": "/some-call", "httpMethod": "OPTIONS"}, None) + result = app({"path": "/some-call", "httpMethod": "OPTIONS", "headers": headers}, None) # THEN return the custom preflight response assert result["statusCode"] == 200 @@ -747,7 +778,8 @@ def service_error(): # AND status code equals 502 assert result["statusCode"] == 502 assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] - assert "Access-Control-Allow-Origin" in result["multiValueHeaders"] + # Because no Origin was passed in, there is not Allow-Origin on the output + assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"] expected = {"statusCode": 502, "message": "Something went wrong!"} assert result["body"] == json_dump(expected)