Skip to content

feat(event-handler): allow for cors=None setting #421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ def with_cors():
)
app = ApiGatewayResolver(cors=cors_config)

@app.get("/my/path", cors=True)
@app.get("/my/path")
def with_cors():
return {"message": "Foo"}

@app.get("/another-one")
@app.get("/another-one", cors=False)
def without_cors():
return {"message": "Foo"}
"""
Expand Down Expand Up @@ -249,9 +249,10 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors:
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}

def get(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def get(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Get route decorator with GET `method`

Examples
Expand All @@ -276,7 +277,7 @@ def lambda_handler(event, context):
"""
return self.route(rule, "GET", cors, compress, cache_control)

def post(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def post(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Post route decorator with POST `method`

Examples
Expand All @@ -302,7 +303,7 @@ def lambda_handler(event, context):
"""
return self.route(rule, "POST", cors, compress, cache_control)

def put(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def put(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Put route decorator with PUT `method`

Examples
Expand All @@ -317,7 +318,7 @@ def put(self, rule: str, cors: bool = True, compress: bool = False, cache_contro
app = ApiGatewayResolver()

@app.put("/put-call")
def simple_post():
def simple_put():
put_data: dict = app.current_event.json_body
return {"message": put_data["value"]}

Expand All @@ -328,7 +329,7 @@ def lambda_handler(event, context):
"""
return self.route(rule, "PUT", cors, compress, cache_control)

def delete(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def delete(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Delete route decorator with DELETE `method`

Examples
Expand All @@ -353,7 +354,7 @@ def lambda_handler(event, context):
"""
return self.route(rule, "DELETE", cors, compress, cache_control)

def patch(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def patch(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Patch route decorator with PATCH `method`

Examples
Expand Down Expand Up @@ -381,13 +382,17 @@ def lambda_handler(event, context):
"""
return self.route(rule, "PATCH", cors, compress, cache_control)

def route(self, rule: str, method: str, cors: bool = True, compress: bool = False, cache_control: str = None):
def route(self, rule: str, method: str, cors: bool = None, compress: bool = False, cache_control: str = None):
"""Route decorator includes parameter `method`"""

def register_resolver(func: Callable):
logger.debug(f"Adding route using rule {rule} and method {method.upper()}")
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
if cors:
if cors is None:
cors_enabled = self._cors_enabled
else:
cors_enabled = cors
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
if cors_enabled:
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
self._cors_methods.add(method.upper())
return func
Expand Down Expand Up @@ -454,7 +459,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
logger.debug("CORS is enabled, updating headers.")
headers.update(self._cors.to_dict())

if method == "OPTIONS": # Pre-flight
if method == "OPTIONS":
logger.debug("Pre-flight request detected. Returning CORS with null response")
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
Expand Down
18 changes: 14 additions & 4 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def test_cors():
def with_cors() -> Response:
return Response(200, TEXT_HTML, "test")

@app.get("/without-cors")
def without_cors() -> Response:
return Response(200, TEXT_HTML, "test")

def handler(event, context):
return app.resolve(event, context)

Expand All @@ -196,6 +200,11 @@ def handler(event, context):
assert "Access-Control-Allow-Credentials" not in headers
assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS))

# THEN for routes without cors flag return no cors headers
mock_event = {"path": "/my/request", "httpMethod": "GET"}
result = handler(mock_event, None)
assert "Access-Control-Allow-Origin" not in result["headers"]


def test_compress():
# GIVEN a function that has compress=True
Expand Down Expand Up @@ -359,7 +368,7 @@ def test_custom_cors_config():
app = ApiGatewayResolver(cors=cors_config)
event = {"path": "/cors", "httpMethod": "GET"}

@app.get("/cors", cors=True)
@app.get("/cors")
def get_with_cors():
return {}

Expand All @@ -370,7 +379,7 @@ def another_one():
# WHEN calling the event handler
result = app(event, None)

# THEN return the custom cors headers
# THEN routes by default return the custom cors headers
assert "headers" in result
headers = result["headers"]
assert headers["Content-Type"] == APPLICATION_JSON
Expand All @@ -385,6 +394,7 @@ def another_one():
# AND custom cors was set on the app
assert isinstance(app._cors, CORSConfig)
assert app._cors is cors_config

# AND routes without cors don't include "Access-Control" headers
event = {"path": "/another-one", "httpMethod": "GET"}
result = app(event, None)
Expand Down Expand Up @@ -426,11 +436,11 @@ def test_cors_preflight():
# AND cors is enabled
app = ApiGatewayResolver(cors=CORSConfig())

@app.get("/foo", cors=True)
@app.get("/foo")
def foo_cors():
...

@app.route(method="delete", rule="/foo", cors=True)
@app.route(method="delete", rule="/foo")
def foo_delete_cors():
...

Expand Down