Skip to content

feat(event_handler): allow multiple CORS origins #2279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def with_cors():

cors_config = CORSConfig(
allow_origin="https://wwww.example.com/",
extra_origins=["https://dev.example.com/"],
expose_headers=["x-exposed-response-header"],
allow_headers=["x-custom-request-header"],
max_age=100,
Expand All @@ -106,6 +107,7 @@ def without_cors():
def __init__(
self,
allow_origin: str = "*",
extra_origins: Optional[List[str]] = None,
allow_headers: Optional[List[str]] = None,
expose_headers: Optional[List[str]] = None,
max_age: Optional[int] = None,
Expand All @@ -117,6 +119,8 @@ def __init__(
allow_origin: str
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
only be used during development.
extra_origins: Optional[List[str]]
The list of additional allowed origins.
allow_headers: Optional[List[str]]
The list of additional allowed headers. This list is added to list of
built-in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
Expand All @@ -128,16 +132,29 @@ def __init__(
allow_credentials: bool
A boolean value that sets the value of `Access-Control-Allow-Credentials`
"""
self.allow_origin = allow_origin
self._allowed_origins = [allow_origin]
if extra_origins:
self._allowed_origins.extend(extra_origins)
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
self.expose_headers = expose_headers or []
self.max_age = max_age
self.allow_credentials = allow_credentials

def to_dict(self) -> Dict[str, str]:
def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

origin: str = ""?

Annotation says it's optional but we are not setting a default value.

You can drop the Optional (None), and simply set to an empty str

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean I would have to compare if origin == "" later on instead of not origin? Doesn't it sound wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL not will eventually translate to __bool__, so it works for empty strings too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed my mind again and I believe it's better to keep the Optional. Reason is, the caller will try to fetch an Origin from the headers. The Origin is not always present, so mypy would complain that I can't pass an Optional to the to_dict. So I think in this case the Optional makes sense.

"""Builds the configured Access-Control http headers"""

# If there's no Origin, don't add any CORS headers
if not origin:
return {}

# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
# don't add any CORS headers
if origin not in self._allowed_origins and "*" not in self._allowed_origins:
return {}

# The origin matched an allowed origin, so return the CORS headers
headers: Dict[str, str] = {
"Access-Control-Allow-Origin": self.allow_origin,
"Access-Control-Allow-Origin": origin,
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
}

Expand Down Expand Up @@ -207,9 +224,9 @@ def __init__(self, response: Response, route: Optional[Route] = None):
self.response = response
self.route = route

def _add_cors(self, cors: CORSConfig):
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
self.response.headers.update(cors.to_dict())
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))

def _add_cache_control(self, cache_control: str):
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
Expand All @@ -230,7 +247,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
if self.route is None:
return
if self.route.cors:
self._add_cors(cors or CORSConfig())
self._add_cors(event, cors or CORSConfig())
if self.route.cache_control:
self._add_cache_control(self.route.cache_control)
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
Expand Down Expand Up @@ -644,7 +661,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers: Dict[str, Union[str, List[str]]] = {}
if self._cors:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())
headers.update(self._cors.to_dict(self.current_event.get_header_value("Origin")))

if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
Expand Down
2 changes: 1 addition & 1 deletion aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_header_value(
class BaseProxyEvent(DictWrapper):
@property
def headers(self) -> Dict[str, str]:
return self["headers"]
return self.get("headers") or {}

@property
def query_string_parameters(self) -> Optional[Dict[str, str]]:
Expand Down
23 changes: 20 additions & 3 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ To address this API Gateway behavior, we use `strip_prefixes` parameter to accou

You can configure CORS at the `APIGatewayRestResolver` constructor via `cors` parameter using the `CORSConfig` class.

This will ensure that CORS headers are always returned as part of the response when your functions match the path invoked.
This will ensure that CORS headers are returned as part of the response when your functions match the path invoked and the `Origin`
matches one of the allowed values.

???+ tip
Optionally disable CORS on a per path basis with `cors=False` parameter.
Expand All @@ -297,6 +298,18 @@ This will ensure that CORS headers are always returned as part of the response w
--8<-- "examples/event_handler_rest/src/setting_cors_output.json"
```

=== "setting_cors_extra_origins.py"

```python hl_lines="5 11-12 34"
--8<-- "examples/event_handler_rest/src/setting_cors_extra_origins.py"
```

=== "setting_cors_extra_origins_output.json"

```json
--8<-- "examples/event_handler_rest/src/setting_cors_extra_origins_output.json"
```

#### Pre-flight

Pre-flight (OPTIONS) calls are typically handled at the API Gateway or Lambda Function URL level as per [our sample infrastructure](#required-resources), no Lambda integration is necessary. However, ALB expects you to handle pre-flight requests.
Expand All @@ -310,9 +323,13 @@ For convenience, these are the default values when using `CORSConfig` to enable
???+ warning
Always configure `allow_origin` when using in production.

???+ tip "Multiple origins?"
If you need to allow multiple origins, pass the additional origins using the `extra_origins` key.

| Key | Value | Note |
| -------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|----------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| **[allow_origin](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `str` | `*` | Only use the default value for development. **Never use `*` for production** unless your use case requires it |
| **[extra_origins](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin){target="_blank"}**: `List[str]` | `[]` | Additional origins to be allowed, in addition to the one specified in `allow_origin` |
| **[allow_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers){target="_blank"}**: `List[str]` | `[Authorization, Content-Type, X-Amz-Date, X-Api-Key, X-Amz-Security-Token]` | Additional headers will be appended to the default list for your convenience |
| **[expose_headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers){target="_blank"}**: `List[str]` | `[]` | Any additional header beyond the [safe listed by CORS specification](https://developer.mozilla.org/en-US/docs/Glossary/CORS-safelisted_response_header){target="_blank"}. |
| **[max_age](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age){target="_blank"}**: `int` | `` | Only for pre-flight requests if you choose to have your function to handle it instead of API Gateway |
Expand All @@ -331,7 +348,7 @@ You can use the `Response` class to have full control over the response. For exa

=== "fine_grained_responses.py"

```python hl_lines="9 28-32"
```python hl_lines="9 29-35"
--8<-- "examples/event_handler_rest/src/fine_grained_responses.py"
```

Expand Down
3 changes: 2 additions & 1 deletion examples/event_handler_rest/src/setting_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

tracer = Tracer()
logger = Logger()
cors_config = CORSConfig(allow_origin="https://example.com", max_age=300)
# CORS will match when Origin is only https://www.example.com
cors_config = CORSConfig(allow_origin="https://www.example.com", max_age=300)
app = APIGatewayRestResolver(cors=cors_config)


Expand Down
45 changes: 45 additions & 0 deletions examples/event_handler_rest/src/setting_cors_extra_origins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import requests
from requests import Response

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, CORSConfig
from aws_lambda_powertools.logging import correlation_paths
from aws_lambda_powertools.utilities.typing import LambdaContext

tracer = Tracer()
logger = Logger()
# CORS will match when Origin is https://www.example.com OR https://dev.example.com
cors_config = CORSConfig(allow_origin="https://www.example.com", extra_origins=["https://dev.example.com"], max_age=300)
app = APIGatewayRestResolver(cors=cors_config)


@app.get("/todos")
@tracer.capture_method
def get_todos():
todos: Response = requests.get("https://jsonplaceholder.typicode.com/todos")
todos.raise_for_status()

# for brevity, we'll limit to the first 10 only
return {"todos": todos.json()[:10]}


@app.get("/todos/<todo_id>")
@tracer.capture_method
def get_todo_by_id(todo_id: str): # value come as str
todos: Response = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
todos.raise_for_status()

return {"todos": todos.json()}


@app.get("/healthcheck", cors=False) # optionally removes CORS for a given route
@tracer.capture_method
def am_i_alive():
return {"am_i_alive": "yes"}


# You can continue to use other utilities just as before
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext) -> dict:
return app.resolve(event, context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"statusCode": 200,
"multiValueHeaders": {
"Content-Type": ["application/json"],
"Access-Control-Allow-Origin": ["https://www.example.com","https://dev.example.com"],
"Access-Control-Allow-Headers": ["Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,X-Api-Key"]
},
"body": "{\"todos\":[{\"userId\":1,\"id\":1,\"title\":\"delectus aut autem\",\"completed\":false},{\"userId\":1,\"id\":2,\"title\":\"quis ut nam facilis et officia qui\",\"completed\":false},{\"userId\":1,\"id\":3,\"title\":\"fugiat veniam minus\",\"completed\":false},{\"userId\":1,\"id\":4,\"title\":\"et porro tempora\",\"completed\":true},{\"userId\":1,\"id\":5,\"title\":\"laboriosam mollitia et enim quasi adipisci quia provident illum\",\"completed\":false},{\"userId\":1,\"id\":6,\"title\":\"qui ullam ratione quibusdam voluptatem quia omnis\",\"completed\":false},{\"userId\":1,\"id\":7,\"title\":\"illo expedita consequatur quia in\",\"completed\":false},{\"userId\":1,\"id\":8,\"title\":\"quo adipisci enim quam ut ab\",\"completed\":true},{\"userId\":1,\"id\":9,\"title\":\"molestiae perspiciatis ipsa\",\"completed\":false},{\"userId\":1,\"id\":10,\"title\":\"illo est ratione doloremque quia maiores aut\",\"completed\":true}]}",
"isBase64Encoded": false
}
12 changes: 9 additions & 3 deletions tests/e2e/event_handler/handlers/alb_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from aws_lambda_powertools.event_handler import ALBResolver, Response, content_types

app = ALBResolver()
from aws_lambda_powertools.event_handler import (
ALBResolver,
CORSConfig,
Response,
content_types,
)

cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = ALBResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
APIGatewayHttpResolver,
CORSConfig,
Response,
content_types,
)

app = APIGatewayHttpResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = APIGatewayHttpResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
APIGatewayRestResolver,
CORSConfig,
Response,
content_types,
)

app = APIGatewayRestResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = APIGatewayRestResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from aws_lambda_powertools.event_handler import (
CORSConfig,
LambdaFunctionUrlResolver,
Response,
content_types,
)

app = LambdaFunctionUrlResolver()
cors_config = CORSConfig(allow_origin="https://www.example.org", extra_origins=["https://dev.example.org"])
app = LambdaFunctionUrlResolver(cors=cors_config)

# The reason we use post is that whoever is writing tests can easily assert on the
# content being sent (body, headers, cookies, content-type) to reduce cognitive load.
Expand Down
Loading