Skip to content

Commit 32456d6

Browse files
fix(event_handler): disable allow-credentials header when origin allow_origin is * (#4638)
* bug(event_handler): fix cors no origin bug * create functional test * fix cors * fix test structure * add test event * add allowed_origins method to CORSConfig --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent 46fe028 commit 32456d6

File tree

3 files changed

+163
-3
lines changed

3 files changed

+163
-3
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import base64
24
import json
35
import logging
@@ -190,9 +192,12 @@ def __init__(
190192
allow_credentials: bool
191193
A boolean value that sets the value of `Access-Control-Allow-Credentials`
192194
"""
195+
193196
self._allowed_origins = [allow_origin]
197+
194198
if extra_origins:
195199
self._allowed_origins.extend(extra_origins)
200+
196201
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
197202
self.expose_headers = expose_headers or []
198203
self.max_age = max_age
@@ -220,10 +225,18 @@ def to_dict(self, origin: Optional[str]) -> Dict[str, str]:
220225
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
221226
if self.max_age is not None:
222227
headers["Access-Control-Max-Age"] = str(self.max_age)
223-
if self.allow_credentials is True:
228+
if origin != "*" and self.allow_credentials is True:
224229
headers["Access-Control-Allow-Credentials"] = "true"
225230
return headers
226231

232+
def allowed_origin(self, extracted_origin: str) -> str | None:
233+
if extracted_origin in self._allowed_origins:
234+
return extracted_origin
235+
if extracted_origin is not None and "*" in self._allowed_origins:
236+
return "*"
237+
238+
return None
239+
227240
@staticmethod
228241
def build_allow_methods(methods: Set[str]) -> str:
229242
"""Build sorted comma delimited methods for Access-Control-Allow-Methods header
@@ -808,7 +821,10 @@ def __init__(
808821
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
809822
"""Update headers to include the configured Access-Control headers"""
810823
extracted_origin_header = extract_origin_header(event.resolved_headers_field)
811-
self.response.headers.update(cors.to_dict(extracted_origin_header))
824+
825+
origin = cors.allowed_origin(extracted_origin_header)
826+
if origin is not None:
827+
self.response.headers.update(cors.to_dict(origin))
812828

813829
def _add_cache_control(self, cache_control: str):
814830
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""

Diff for: tests/events/apiGatewayProxyEventNoOrigin.json

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
{
2+
"version": "1.0",
3+
"resource": "/my/path",
4+
"path": "/my/path",
5+
"httpMethod": "GET",
6+
"headers": {
7+
"Header1": "value1",
8+
"Header2": "value2"
9+
},
10+
"multiValueHeaders": {
11+
"Header1": [
12+
"value1"
13+
],
14+
"Header2": [
15+
"value1",
16+
"value2"
17+
]
18+
},
19+
"queryStringParameters": {
20+
"parameter1": "value1",
21+
"parameter2": "value"
22+
},
23+
"multiValueQueryStringParameters": {
24+
"parameter1": [
25+
"value1",
26+
"value2"
27+
],
28+
"parameter2": [
29+
"value"
30+
]
31+
},
32+
"requestContext": {
33+
"accountId": "123456789012",
34+
"apiId": "id",
35+
"authorizer": {
36+
"claims": null,
37+
"scopes": null
38+
},
39+
"domainName": "id.execute-api.us-east-1.amazonaws.com",
40+
"domainPrefix": "id",
41+
"extendedRequestId": "request-id",
42+
"httpMethod": "GET",
43+
"identity": {
44+
"accessKey": null,
45+
"accountId": null,
46+
"caller": null,
47+
"cognitoAuthenticationProvider": null,
48+
"cognitoAuthenticationType": null,
49+
"cognitoIdentityId": null,
50+
"cognitoIdentityPoolId": null,
51+
"principalOrgId": null,
52+
"sourceIp": "192.168.0.1/32",
53+
"user": null,
54+
"userAgent": "user-agent",
55+
"userArn": null,
56+
"clientCert": {
57+
"clientCertPem": "CERT_CONTENT",
58+
"subjectDN": "www.example.com",
59+
"issuerDN": "Example issuer",
60+
"serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1",
61+
"validity": {
62+
"notBefore": "May 28 12:30:02 2019 GMT",
63+
"notAfter": "Aug 5 09:36:04 2021 GMT"
64+
}
65+
}
66+
},
67+
"path": "/my/path",
68+
"protocol": "HTTP/1.1",
69+
"requestId": "id=",
70+
"requestTime": "04/Mar/2020:19:15:17 +0000",
71+
"requestTimeEpoch": 1583349317135,
72+
"resourceId": null,
73+
"resourcePath": "/my/path",
74+
"stage": "$default"
75+
},
76+
"pathParameters": null,
77+
"stageVariables": null,
78+
"body": "Hello from Lambda!",
79+
"isBase64Encoded": false
80+
}

Diff for: tests/functional/event_handler/required_dependencies/test_api_gateway.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def read_media(file_name: str) -> bytes:
4848

4949

5050
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
51+
LOAD_GW_EVENT_NO_ORIGIN = load_event("apiGatewayProxyEventNoOrigin.json")
5152
LOAD_GW_EVENT_TRAILING_SLASH = load_event("apiGatewayProxyEventPathTrailingSlash.json")
5253

5354

@@ -324,7 +325,7 @@ def handler(event, context):
324325
def test_cors():
325326
# GIVEN a function with cors=True
326327
# AND http method set to GET
327-
app = ApiGatewayResolver()
328+
app = ApiGatewayResolver(cors=CORSConfig("https://aws.amazon.com", allow_credentials=True))
328329

329330
@app.get("/my/path", cors=True)
330331
def with_cors() -> Response:
@@ -345,6 +346,69 @@ def handler(event, context):
345346
headers = result["multiValueHeaders"]
346347
assert headers["Content-Type"] == [content_types.TEXT_HTML]
347348
assert headers["Access-Control-Allow-Origin"] == ["https://aws.amazon.com"]
349+
assert "Access-Control-Allow-Credentials" in headers
350+
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]
351+
352+
# THEN for routes without cors flag return no cors headers
353+
mock_event = {"path": "/my/request", "httpMethod": "GET"}
354+
result = handler(mock_event, None)
355+
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
356+
357+
358+
def test_cors_no_request_origin():
359+
# GIVEN a function with cors=True
360+
# AND http method set to GET
361+
app = ApiGatewayResolver()
362+
363+
@app.get("/my/path", cors=True)
364+
def with_cors() -> Response:
365+
return Response(200, content_types.TEXT_HTML, "test")
366+
367+
def handler(event, context):
368+
return app.resolve(event, context)
369+
370+
event = LOAD_GW_EVENT_NO_ORIGIN
371+
372+
# WHEN calling the event handler
373+
result = handler(event, None)
374+
375+
# THEN the headers should include cors headers
376+
assert "multiValueHeaders" in result
377+
headers = result["multiValueHeaders"]
378+
assert headers["Content-Type"] == [content_types.TEXT_HTML]
379+
assert "Access-Control-Allow-Credentials" not in headers
380+
assert "Access-Control-Allow-Origin" not in result["multiValueHeaders"]
381+
382+
383+
def test_cors_allow_all_request_origins():
384+
# GIVEN a function with cors=True
385+
# AND http method set to GET
386+
app = ApiGatewayResolver(
387+
cors=CORSConfig(
388+
allow_origin="*",
389+
allow_credentials=True,
390+
),
391+
)
392+
393+
@app.get("/my/path", cors=True)
394+
def with_cors() -> Response:
395+
return Response(200, content_types.TEXT_HTML, "test")
396+
397+
@app.get("/without-cors")
398+
def without_cors() -> Response:
399+
return Response(200, content_types.TEXT_HTML, "test")
400+
401+
def handler(event, context):
402+
return app.resolve(event, context)
403+
404+
# WHEN calling the event handler
405+
result = handler(LOAD_GW_EVENT, None)
406+
407+
# THEN the headers should include cors headers
408+
assert "multiValueHeaders" in result
409+
headers = result["multiValueHeaders"]
410+
assert headers["Content-Type"] == [content_types.TEXT_HTML]
411+
assert headers["Access-Control-Allow-Origin"] == ["*"]
348412
assert "Access-Control-Allow-Credentials" not in headers
349413
assert headers["Access-Control-Allow-Headers"] == [",".join(sorted(CORSConfig._REQUIRED_HEADERS))]
350414

0 commit comments

Comments
 (0)