Skip to content

Commit cbbcc4d

Browse files
author
Michael Brewer
authored
feat(event-handle): allow for cors=None setting (#421)
1 parent d044463 commit cbbcc4d

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def with_cors():
5555
)
5656
app = ApiGatewayResolver(cors=cors_config)
5757
58-
@app.get("/my/path", cors=True)
58+
@app.get("/my/path")
5959
def with_cors():
6060
return {"message": "Foo"}
6161
62-
@app.get("/another-one")
62+
@app.get("/another-one", cors=False)
6363
def without_cors():
6464
return {"message": "Foo"}
6565
"""
@@ -249,9 +249,10 @@ def __init__(self, proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent, cors:
249249
self._proxy_type = proxy_type
250250
self._routes: List[Route] = []
251251
self._cors = cors
252+
self._cors_enabled: bool = cors is not None
252253
self._cors_methods: Set[str] = {"OPTIONS"}
253254

254-
def get(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
255+
def get(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
255256
"""Get route decorator with GET `method`
256257
257258
Examples
@@ -276,7 +277,7 @@ def lambda_handler(event, context):
276277
"""
277278
return self.route(rule, "GET", cors, compress, cache_control)
278279

279-
def post(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
280+
def post(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
280281
"""Post route decorator with POST `method`
281282
282283
Examples
@@ -302,7 +303,7 @@ def lambda_handler(event, context):
302303
"""
303304
return self.route(rule, "POST", cors, compress, cache_control)
304305

305-
def put(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
306+
def put(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
306307
"""Put route decorator with PUT `method`
307308
308309
Examples
@@ -317,7 +318,7 @@ def put(self, rule: str, cors: bool = True, compress: bool = False, cache_contro
317318
app = ApiGatewayResolver()
318319
319320
@app.put("/put-call")
320-
def simple_post():
321+
def simple_put():
321322
put_data: dict = app.current_event.json_body
322323
return {"message": put_data["value"]}
323324
@@ -328,7 +329,7 @@ def lambda_handler(event, context):
328329
"""
329330
return self.route(rule, "PUT", cors, compress, cache_control)
330331

331-
def delete(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
332+
def delete(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
332333
"""Delete route decorator with DELETE `method`
333334
334335
Examples
@@ -353,7 +354,7 @@ def lambda_handler(event, context):
353354
"""
354355
return self.route(rule, "DELETE", cors, compress, cache_control)
355356

356-
def patch(self, rule: str, cors: bool = True, compress: bool = False, cache_control: str = None):
357+
def patch(self, rule: str, cors: bool = None, compress: bool = False, cache_control: str = None):
357358
"""Patch route decorator with PATCH `method`
358359
359360
Examples
@@ -381,13 +382,17 @@ def lambda_handler(event, context):
381382
"""
382383
return self.route(rule, "PATCH", cors, compress, cache_control)
383384

384-
def route(self, rule: str, method: str, cors: bool = True, compress: bool = False, cache_control: str = None):
385+
def route(self, rule: str, method: str, cors: bool = None, compress: bool = False, cache_control: str = None):
385386
"""Route decorator includes parameter `method`"""
386387

387388
def register_resolver(func: Callable):
388389
logger.debug(f"Adding route using rule {rule} and method {method.upper()}")
389-
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
390-
if cors:
390+
if cors is None:
391+
cors_enabled = self._cors_enabled
392+
else:
393+
cors_enabled = cors
394+
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
395+
if cors_enabled:
391396
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
392397
self._cors_methods.add(method.upper())
393398
return func
@@ -454,7 +459,7 @@ def _not_found(self, method: str) -> ResponseBuilder:
454459
logger.debug("CORS is enabled, updating headers.")
455460
headers.update(self._cors.to_dict())
456461

457-
if method == "OPTIONS": # Pre-flight
462+
if method == "OPTIONS":
458463
logger.debug("Pre-flight request detected. Returning CORS with null response")
459464
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
460465
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))

tests/functional/event_handler/test_api_gateway.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ def test_cors():
182182
def with_cors() -> Response:
183183
return Response(200, TEXT_HTML, "test")
184184

185+
@app.get("/without-cors")
186+
def without_cors() -> Response:
187+
return Response(200, TEXT_HTML, "test")
188+
185189
def handler(event, context):
186190
return app.resolve(event, context)
187191

@@ -196,6 +200,11 @@ def handler(event, context):
196200
assert "Access-Control-Allow-Credentials" not in headers
197201
assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS))
198202

203+
# THEN for routes without cors flag return no cors headers
204+
mock_event = {"path": "/my/request", "httpMethod": "GET"}
205+
result = handler(mock_event, None)
206+
assert "Access-Control-Allow-Origin" not in result["headers"]
207+
199208

200209
def test_compress():
201210
# GIVEN a function that has compress=True
@@ -359,7 +368,7 @@ def test_custom_cors_config():
359368
app = ApiGatewayResolver(cors=cors_config)
360369
event = {"path": "/cors", "httpMethod": "GET"}
361370

362-
@app.get("/cors", cors=True)
371+
@app.get("/cors")
363372
def get_with_cors():
364373
return {}
365374

@@ -370,7 +379,7 @@ def another_one():
370379
# WHEN calling the event handler
371380
result = app(event, None)
372381

373-
# THEN return the custom cors headers
382+
# THEN routes by default return the custom cors headers
374383
assert "headers" in result
375384
headers = result["headers"]
376385
assert headers["Content-Type"] == APPLICATION_JSON
@@ -385,6 +394,7 @@ def another_one():
385394
# AND custom cors was set on the app
386395
assert isinstance(app._cors, CORSConfig)
387396
assert app._cors is cors_config
397+
388398
# AND routes without cors don't include "Access-Control" headers
389399
event = {"path": "/another-one", "httpMethod": "GET"}
390400
result = app(event, None)
@@ -426,11 +436,11 @@ def test_cors_preflight():
426436
# AND cors is enabled
427437
app = ApiGatewayResolver(cors=CORSConfig())
428438

429-
@app.get("/foo", cors=True)
439+
@app.get("/foo")
430440
def foo_cors():
431441
...
432442

433-
@app.route(method="delete", rule="/foo", cors=True)
443+
@app.route(method="delete", rule="/foo")
434444
def foo_delete_cors():
435445
...
436446

0 commit comments

Comments
 (0)