Skip to content

refactor(test): make CORS test consistent with expected behavior #4882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,20 +122,22 @@ class CORSConfig:
Examples
--------

Simple cors example using the default permissive cors, not this should only be used during early prototyping
Simple CORS example using the default permissive CORS, note that this should only be used during early prototyping.

```python
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.api_gateway import (
APIGatewayRestResolver, CORSConfig
)

app = APIGatewayRestResolver()
app = APIGatewayRestResolver(cors=CORSConfig())

@app.get("/my/path", cors=True)
@app.get("/my/path")
def with_cors():
return {"message": "Foo"}
```

Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors`
do not include any cors headers.
do not include any CORS headers.

```python
from aws_lambda_powertools.event_handler.api_gateway import (
Expand Down
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low hanging fruit to test the UserWarnings and reduce noise in the pytest output

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great attention to details @wurstnase! Thank you so much!! ❤️

Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,15 @@ def handler(event, context):


def test_cors():
# GIVEN a function with cors=True
# GIVEN a function
# AND http method set to GET
app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True))

@app.get("/my/path", cors=True)
@app.get("/my/path")
def with_cors() -> Response:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we can remove cors=True here. It's the default behavior when cors is set in ApiGatewayResolver. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I will check and test some other places to.

return Response(200, content_types.TEXT_HTML, "test")

@app.get("/without-cors")
@app.get("/without-cors", cors=False)
def without_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

Expand All @@ -350,17 +350,17 @@ def handler(event, context):
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]

# THEN for routes without cors flag return no cors headers
mock_event = {"path": "/my/request", "httpMethod": "GET"}
mock_event = {"path": "/without-cors", "httpMethod": "GET"}
result = handler(mock_event, None)
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]


def test_cors_no_request_origin():
# GIVEN a function with cors=True
# GIVEN a function
# AND http method set to GET
app = ApiGatewayResolver()
app = ApiGatewayResolver(cors=CORSConfig())

@app.get("/my/path", cors=True)
@app.get("/my/path")
def with_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

Expand All @@ -381,7 +381,7 @@ def handler(event, context):


def test_cors_allow_all_request_origins():
# GIVEN a function with cors=True
# GIVEN a function
# AND http method set to GET
app = ApiGatewayResolver(
cors=CORSConfig(
Expand All @@ -390,11 +390,11 @@ def test_cors_allow_all_request_origins():
),
)

@app.get("/my/path", cors=True)
@app.get("/my/path")
def with_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

@app.get("/without-cors")
@app.get("/without-cors", cors=False)
def without_cors() -> Response:
return Response(200, content_types.TEXT_HTML, "test")

Expand All @@ -413,7 +413,7 @@ def handler(event, context):
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]

# THEN for routes without cors flag return no cors headers
mock_event = {"path": "/my/request", "httpMethod": "GET"}
mock_event = {"path": "/without-cors", "httpMethod": "GET"}
result = handler(mock_event, None)
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]

Expand Down Expand Up @@ -811,7 +811,7 @@ def test_custom_preflight_response():
# AND the request matches this custom preflight route
app = ApiGatewayResolver(cors=CORSConfig())

@app.route(method="OPTIONS", rule="/some-call", cors=True)
@app.route(method="OPTIONS", rule="/some-call")
def custom_preflight():
return Response(
status_code=200,
Expand All @@ -820,7 +820,7 @@ def custom_preflight():
headers={"Access-Control-Allow-Methods": ["CUSTOM"]},
)

@app.route(method="CUSTOM", rule="/some-call", cors=True)
@app.route(method="CUSTOM", rule="/some-call")
def custom_method(): ...

# AND the request includes an origin
Expand Down Expand Up @@ -903,7 +903,7 @@ def internal_server_error():
assert result["body"] == json_dump(expected)

# GIVEN an ServiceError with a custom status code
@app.get(rule="/service-error", cors=True)
@app.get(rule="/service-error")
def service_error():
raise ServiceError(502, "Something went wrong!")

Expand Down Expand Up @@ -964,7 +964,8 @@ def raises_error():
def test_powertools_dev_sets_debug_mode(monkeypatch):
# GIVEN a debug mode environment variable is set
monkeypatch.setenv(constants.POWERTOOLS_DEV_ENV, "true")
app = ApiGatewayResolver()
with pytest.warns(UserWarning, match="POWERTOOLS_DEV environment variable is enabled."):
app = ApiGatewayResolver()

# WHEN calling app._debug
# THEN the debug mode is enabled
Expand Down Expand Up @@ -1428,7 +1429,8 @@ def get_func():
def get_func_another_duplicate():
raise RuntimeError()

app.include_router(router)
with pytest.warns(UserWarning, match="A route like this was already registered"):
app.include_router(router)

# WHEN calling the handler
result = app(LOAD_GW_EVENT, None)
Expand Down Expand Up @@ -1707,7 +1709,12 @@ def my_path():
@event_source(data_class=APIGatewayProxyEventV2)
def handler(event: APIGatewayProxyEventV2, context):
assert isinstance(event, APIGatewayProxyEventV2)
return app.resolve(event, context)

with pytest.warns(
UserWarning,
match="You don't need to serialize event to Event Source Data Class when using Event Handler",
):
return app.resolve(event, context)

# THEN
result = handler(load_event("apiGatewayProxyV2Event.json"), None)
Expand Down