Skip to content

feat(event_handler): allow multiple CORS origins #2279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def with_cors():

cors_config = CORSConfig(
allow_origin="https://wwww.example.com/",
extra_origins=["https://www1.example.com/"],
expose_headers=["x-exposed-response-header"],
allow_headers=["x-custom-request-header"],
max_age=100,
Expand All @@ -106,6 +107,7 @@ def without_cors():
def __init__(
self,
allow_origin: str = "*",
extra_origins: Optional[List[str]] = None,
allow_headers: Optional[List[str]] = None,
expose_headers: Optional[List[str]] = None,
max_age: Optional[int] = None,
Expand All @@ -117,6 +119,8 @@ def __init__(
allow_origin: str
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
only be used during development.
extra_origins: Optional[List[str]]
The list of additional allowed origins.
allow_headers: Optional[List[str]]
The list of additional allowed headers. This list is added to list of
built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
Expand All @@ -128,16 +132,29 @@ def __init__(
allow_credentials: bool
A boolean value that sets the value of `Access-Control-Allow-Credentials`
"""
self.allow_origin = allow_origin
self.allowed_origins = [allow_origin]
if extra_origins:
self.allowed_origins.extend(extra_origins)
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
self.expose_headers = expose_headers or []
self.max_age = max_age
self.allow_credentials = allow_credentials

def to_dict(self) -> Dict[str, str]:
def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

origin: str = ""?

Annotation says it's optional but we are not setting a default value.

You can drop the Optional (None), and simply set to an empty str

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean I would have to compare if origin == "" later on instead of not origin? Doesn't it sound wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL not will eventually translate to __bool__, so it works for empty strings too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed my mind again and I believe it's better to keep the Optional. Reason is, the caller will try to fetch an Origin from the headers. The Origin is not always present, so mypy would complain that I can't pass an Optional to the to_dict. So I think in this case the Optional makes sense.

"""Builds the configured Access-Control http headers"""

# If there's no Origin, don't add any CORS headers
if not origin:
return {}

# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
# don't add any CORS headers
if origin not in self.allowed_origins and "*" not in self.allowed_origins:
return {}

# The origin matched an allowed origin, so return the CORS headers
headers: Dict[str, str] = {
"Access-Control-Allow-Origin": self.allow_origin,
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
}

Expand Down Expand Up @@ -207,9 +224,9 @@ def __init__(self, response: Response, route: Optional[Route] = None):
self.response = response
self.route = route

def _add_cors(self, cors: CORSConfig):
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
self.response.headers.update(cors.to_dict())
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))

def _add_cache_control(self, cache_control: str):
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
Expand All @@ -230,7 +247,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
if self.route is None:
return
if self.route.cors:
self._add_cors(cors or CORSConfig())
self._add_cors(event, cors or CORSConfig())
if self.route.cache_control:
self._add_cache_control(self.route.cache_control)
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
Expand Down Expand Up @@ -644,7 +661,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers: Dict[str, Union[str, List[str]]] = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())
headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin")))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_header_value(
class BaseProxyEvent(DictWrapper):
@property
def headers(self) -> Dict[str, str]:
return self["headers"]
return self.get("headers") or {}

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
Expand Down
8 changes: 6 additions & 2 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ To address this API Gateway behavior, we use `strip_prefixes` parameter to accou

You can configure CORS at the `APIGatewayRestResolver` constructor via `cors` parameter using the `CORSConfig` class.

This will ensure that CORS headers are always returned as part of the response when your functions match the path invoked.
This will ensure that CORS headers are returned as part of the response when your functions match the path invoked and the `Origin`
matches one of the allowed values.

???+ tip
Optionally disable CORS on a per path basis with `cors=False` parameter.
Expand Down Expand Up @@ -310,6 +311,9 @@ For convenience, these are the default values when using `CORSConfig` to enable
???+ warning
Always configure `allow_origin` when using in production.

???+ tip "Multiple allowed origins?"
If you require multiple allowed origins, pass the additional origins using the `extra_origins` key.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ruben, please add the extra_origins field in the table below.

How about adding a new tab with an example to show how to use this new field? What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm dizzy in the taxi but this reads odd, perhaps what you meant was:

Multiple origins?

If you need to allow multiple origins ...

As in, Allow-Origins is an explicit CORS terminology, but here we can be more flexible in wording

| Key | Value | Note |
| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it |
Expand All @@ -331,7 +335,7 @@ You can use the `Response` class to have full control over the response. For exa

=== "fine_grained_responses.py"

```python hl_lines="9 28-32"
```python hl_lines="9 29-35"
--8<-- "examples/event_handler_rest/src/fine_grained_responses.py"
```

Expand Down
3 changes: 2 additions & 1 deletion tests/events/apiGatewayProxyEvent.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
"httpMethod": "GET",
"headers": {
"Header1": "value1",
"Header2": "value2"
"Header2": "value2",
"Origin": "https://aws.amazon.com"
},
"multiValueHeaders": {
"Header1": [
Expand Down
47 changes: 39 additions & 8 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def handler(event, context):
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.TEXT_HTML]
assert headers["Access-Control-Allow-Origin"] == ["*"]
assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"]
assert "Access-Control-Allow-Credentials" not in headers
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]

Expand Down Expand Up @@ -533,6 +533,36 @@ def rest_func() -> Response:
assert result["body"] == "Not found"


def test_cors_multi_origin():
# GIVEN a custom cors configuration with multiple origins
cors_config = CORSConfig(allow_origin="https://origin1", extra_origins=["https://origin2", "https://origin3"])
app = ApiGatewayResolver(cors=cors_config)

@app.get("/cors")
def get_with_cors():
return {}

# WHEN calling the event handler with the correct Origin
event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://origin3"}}
result = app(event, None)

# THEN routes by default return the custom cors headers
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.APPLICATION_JSON]
assert headers["Access-Control-Allow-Origin"] == ["https://origin3"]

# WHEN calling the event handler with the wrong origin
event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://wrong.origin"}}
result = app(event, None)

# THEN routes by default return the custom cors headers
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.APPLICATION_JSON]
assert "Access-Control-Allow-Origin" not in headers


def test_custom_cors_config():
# GIVEN a custom cors configuration
allow_header = ["foo2"]
Expand All @@ -544,7 +574,7 @@ def test_custom_cors_config():
allow_credentials=True,
)
app = ApiGatewayResolver(cors=cors_config)
event = {"path": "/cors", "httpMethod": "GET"}
event = {"path": "/cors", "httpMethod": "GET", "headers": {"Origin": "https://foo1"}}

@app.get("/cors")
def get_with_cors():
Expand All @@ -561,7 +591,7 @@ def another_one():
assert "multiValueHeaders" in result
headers = result["multiValueHeaders"]
assert headers["Content-Type"] == [content_types.APPLICATION_JSON]
assert headers["Access-Control-Allow-Origin"] == [cors_config.allow_origin]
assert headers["Access-Control-Allow-Origin"] == [cors_config.allowed_origins[0]]
expected_allows_headers = [",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))]
assert headers["Access-Control-Allow-Headers"] == expected_allows_headers
assert headers["Access-Control-Expose-Headers"] == [",".join(cors_config.expose_headers)]
Expand Down Expand Up @@ -604,9 +634,9 @@ def test_no_matches_with_cors():
result = app({"path": "/another-one", "httpMethod": "GET"}, None)

# THEN return a 404
# AND cors headers are returned
# AND cors headers are NOT returned (because no Origin header was passed in)
assert result["statusCode"] == 404
assert "Access-Control-Allow-Origin" in result["multiValueHeaders"]
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
assert "Not found" in result["body"]


Expand All @@ -628,7 +658,7 @@ def post_no_cors():
...

# WHEN calling the handler
result = app({"path": "/foo", "httpMethod": "OPTIONS"}, None)
result = app({"path": "/foo", "httpMethod": "OPTIONS", "headers": {"Origin": "http://example.org"}}, None)

# THEN return no content
# AND include Access-Control-Allow-Methods of the cors methods used
Expand Down Expand Up @@ -660,7 +690,7 @@ def custom_method():
...

# WHEN calling the handler
result = app({"path": "/some-call", "httpMethod": "OPTIONS"}, None)
result = app({"path": "/some-call", "httpMethod": "OPTIONS", "headers": {"Origin": "https://example.org"}}, None)

# THEN return the custom preflight response
assert result["statusCode"] == 200
Expand Down Expand Up @@ -747,7 +777,8 @@ def service_error():
# AND status code equals 502
assert result["statusCode"] == 502
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
assert "Access-Control-Allow-Origin" in result["multiValueHeaders"]
# Because no Origin was passed in, there is not Allow-Origin on the output
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
expected = {"statusCode": 502, "message": "Something went wrong!"}
assert result["body"] == json_dump(expected)

Expand Down