Skip to content

Commit 6e875e9

Browse files
committed
fix cors
1 parent ac9b20e commit 6e875e9

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,12 @@ def __init__(
188188
allow_credentials: bool
189189
A boolean value that sets the value of `Access-Control-Allow-Credentials`
190190
"""
191-
self._allowed_origins = [allow_origin]
191+
192+
self.allowed_origins = [allow_origin]
193+
192194
if extra_origins:
193-
self._allowed_origins.extend(extra_origins)
195+
self.allowed_origins.extend(extra_origins)
196+
194197
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
195198
self.expose_headers = expose_headers or []
196199
self.max_age = max_age
@@ -205,7 +208,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
205208

206209
# If the origin doesn't match any of the allowed origins, and we don't allow all origins ("*"),
207210
# don't add any CORS headers
208-
if origin not in self._allowed_origins and "*" not in self._allowed_origins:
211+
if origin not in self.allowed_origins and "*" not in self.allowed_origins:
209212
return {}
210213

211214
# The origin matched an allowed origin, so return the CORS headers
@@ -218,7 +221,7 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
218221
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
219222
if self.max_age is not None:
220223
headers["Access-Control-Max-Age"] = str(self.max_age)
221-
if self.allow_credentials is True:
224+
if origin != "*" and self.allow_credentials is True:
222225
headers["Access-Control-Allow-Credentials"] = "true"
223226
return headers
224227

@@ -806,10 +809,11 @@ def __init__(
806809
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
807810
"""Update headers to include the configured Access-Control headers"""
808811
extracted_origin_header = extract_origin_header(event.resolved_headers_field)
809-
if extracted_origin_header is None:
810-
self.response.headers.update(cors.to_dict("*"))
811-
else:
812+
813+
if extracted_origin_header in cors.allowed_origins:
812814
self.response.headers.update(cors.to_dict(extracted_origin_header))
815+
if extracted_origin_header is not None and "*" in cors.allowed_origins:
816+
self.response.headers.update(cors.to_dict("*"))
813817

814818
def _add_cache_control(self, cache_control: str):
815819
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""

tests/functional/event_handler/required_dependencies/test_api_gateway.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def handler(event, context):
324324
def test_cors():
325325
# GIVEN a function with cors=True
326326
# AND http method set to GET
327-
app = ApiGatewayResolver()
327+
app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True))
328328

329329
@app.get("/my/path", cors=True)
330330
def with_cors() -> Response:
@@ -345,7 +345,7 @@ def handler(event, context):
345345
headers = result["multiValueHeaders"]
346346
assert headers["Content-Type"] == [content_types.TEXT_HTML]
347347
assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"]
348-
assert "Access-Control-Allow-Credentials" not in headers
348+
assert "Access-Control-Allow-Credentials" in headers
349349
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]
350350

351351
# THEN for routes without cors flag return no cors headers
@@ -354,7 +354,7 @@ def handler(event, context):
354354
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
355355

356356

357-
def test_cors_no_origin():
357+
def test_cors_no_request_origin():
358358
# GIVEN a function with cors=True
359359
# AND http method set to GET
360360
app = ApiGatewayResolver()
@@ -366,8 +366,41 @@ def with_cors() -> Response:
366366
def handler(event, context):
367367
return app.resolve(event, context)
368368

369-
# remove origin header from request
370-
del LOAD_GW_EVENT["multiValueHeaders"]["Origin"]
369+
event = LOAD_GW_EVENT.copy()
370+
del event["headers"]["Origin"]
371+
del event["multiValueHeaders"]["Origin"]
372+
373+
# WHEN calling the event handler
374+
result = handler(LOAD_GW_EVENT, None)
375+
376+
# THEN the headers should include cors headers
377+
assert "multiValueHeaders" in result
378+
headers = result["multiValueHeaders"]
379+
assert headers["Content-Type"] == [content_types.TEXT_HTML]
380+
assert "Access-Control-Allow-Credentials" not in headers
381+
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
382+
383+
384+
def test_cors_allow_all_request_origins():
385+
# GIVEN a function with cors=True
386+
# AND http method set to GET
387+
app = ApiGatewayResolver(
388+
cors=CORSConfig(
389+
allow_origin="*",
390+
allow_credentials=True,
391+
),
392+
)
393+
394+
@app.get("/my/path", cors=True)
395+
def with_cors() -> Response:
396+
return Response(200, content_types.TEXT_HTML, "test")
397+
398+
@app.get("/without-cors")
399+
def without_cors() -> Response:
400+
return Response(200, content_types.TEXT_HTML, "test")
401+
402+
def handler(event, context):
403+
return app.resolve(event, context)
371404

372405
# WHEN calling the event handler
373406
result = handler(LOAD_GW_EVENT, None)
@@ -380,6 +413,11 @@ def handler(event, context):
380413
assert "Access-Control-Allow-Credentials" not in headers
381414
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]
382415

416+
# THEN for routes without cors flag return no cors headers
417+
mock_event = {"path": "/my/request", "httpMethod": "GET"}
418+
result = handler(mock_event, None)
419+
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
420+
383421

384422
def test_cors_preflight_body_is_empty_not_null():
385423
# GIVEN CORS is configured

0 commit comments

Comments
 (0)