Skip to content

Commit 318508f

Browse files
committed
feat(event-handler): Add a more complete implementation of cors
NOTE: Some of this is based on the cors behavior of Chalice, except where we actually return the preflight response
1 parent c5709bc commit 318508f

File tree

2 files changed

+182
-25
lines changed

2 files changed

+182
-25
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 97 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import zlib
55
from enum import Enum
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
77

88
from aws_lambda_powertools.shared.json_encoder import Encoder
99
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
@@ -18,30 +18,74 @@ class ProxyEventType(Enum):
1818
api_gateway = http_api_v1
1919

2020

21+
class CORSConfig(object):
22+
_REQUIRED_HEADERS = ["Content-Type", "X-Amz-Date", "Authorization", "X-Api-Key", "X-Amz-Security-Token"]
23+
24+
def __init__(
25+
self,
26+
allow_origin: str = "*",
27+
allow_headers: List[str] = None,
28+
expose_headers: List[str] = None,
29+
max_age: int = None,
30+
allow_credentials: bool = True,
31+
):
32+
self.allow_origin = allow_origin
33+
self.allow_headers = set((allow_headers or []) + self._REQUIRED_HEADERS)
34+
self.expose_headers = expose_headers or []
35+
self.max_age = max_age
36+
self.allow_credentials = allow_credentials
37+
38+
def to_dict(self) -> Dict[str, str]:
39+
headers = {
40+
"Access-Control-Allow-Origin": self.allow_origin,
41+
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
42+
}
43+
if self.expose_headers:
44+
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
45+
if self.max_age is not None:
46+
headers["Access-Control-Max-Age"] = str(self.max_age)
47+
if self.allow_credentials is True:
48+
headers["Access-Control-Allow-Credentials"] = "true"
49+
return headers
50+
51+
2152
class Route:
2253
def __init__(
23-
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
54+
self,
55+
method: str,
56+
rule: Any,
57+
func: Callable,
58+
cors: Union[bool, CORSConfig],
59+
compress: bool,
60+
cache_control: Optional[str],
2461
):
2562
self.method = method.upper()
2663
self.rule = rule
2764
self.func = func
28-
self.cors = cors
65+
self.cors: Optional[CORSConfig]
66+
if cors is True:
67+
self.cors = CORSConfig()
68+
elif isinstance(cors, CORSConfig):
69+
self.cors = cors
70+
else:
71+
self.cors = None
2972
self.compress = compress
3073
self.cache_control = cache_control
3174

3275

3376
class Response:
34-
def __init__(self, status_code: int, content_type: str, body: Union[str, bytes], headers: Dict = None):
77+
def __init__(
78+
self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None
79+
):
3580
self.status_code = status_code
3681
self.body = body
3782
self.base64_encoded = False
3883
self.headers: Dict = headers or {}
39-
self.headers.setdefault("Content-Type", content_type)
84+
if content_type:
85+
self.headers.setdefault("Content-Type", content_type)
4086

41-
def add_cors(self, method: str):
42-
self.headers["Access-Control-Allow-Origin"] = "*"
43-
self.headers["Access-Control-Allow-Methods"] = method
44-
self.headers["Access-Control-Allow-Credentials"] = "true"
87+
def add_cors(self, cors: CORSConfig):
88+
self.headers.update(cors.to_dict())
4589

4690
def add_cache_control(self, cache_control: str):
4791
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
@@ -54,15 +98,14 @@ def compress(self):
5498
self.body = gzip.compress(self.body) + gzip.flush()
5599

56100
def to_dict(self) -> Dict[str, Any]:
101+
result = {"statusCode": self.status_code, "headers": self.headers}
57102
if isinstance(self.body, bytes):
58103
self.base64_encoded = True
59104
self.body = base64.b64encode(self.body).decode()
60-
return {
61-
"statusCode": self.status_code,
62-
"headers": self.headers,
63-
"body": self.body,
64-
"isBase64Encoded": self.base64_encoded,
65-
}
105+
if self.body:
106+
result["isBase64Encoded"] = self.base64_encoded
107+
result["body"] = self.body
108+
return result
66109

67110

68111
class ApiGatewayResolver:
@@ -72,25 +115,43 @@ class ApiGatewayResolver:
72115
def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1):
73116
self._proxy_type = proxy_type
74117
self._routes: List[Route] = []
118+
self._cors: Optional[CORSConfig] = None
119+
self._cors_methods: Set[str] = set()
75120

76-
def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
121+
def get(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None):
77122
return self.route(rule, "GET", cors, compress, cache_control)
78123

79-
def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
124+
def post(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None):
80125
return self.route(rule, "POST", cors, compress, cache_control)
81126

82-
def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
127+
def put(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None):
83128
return self.route(rule, "PUT", cors, compress, cache_control)
84129

85-
def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
130+
def delete(
131+
self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None
132+
):
86133
return self.route(rule, "DELETE", cors, compress, cache_control)
87134

88-
def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
135+
def patch(
136+
self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None
137+
):
89138
return self.route(rule, "PATCH", cors, compress, cache_control)
90139

91-
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None):
140+
def route(
141+
self,
142+
rule: str,
143+
method: str,
144+
cors: Union[bool, CORSConfig] = False,
145+
compress: bool = False,
146+
cache_control: str = None,
147+
):
92148
def register_resolver(func: Callable):
93-
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
149+
route = Route(method, self._compile_regex(rule), func, cors, compress, cache_control)
150+
self._routes.append(route)
151+
if route.cors:
152+
if self._cors is None:
153+
self._cors = route.cors
154+
self._cors_methods.add(route.method)
94155
return func
95156

96157
return register_resolver
@@ -102,7 +163,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
102163
response = self.to_response(route.func(**args))
103164

104165
if route.cors:
105-
response.add_cors(route.method)
166+
response.add_cors(route.cors)
106167
if route.cache_control:
107168
response.add_cache_control(route.cache_control)
108169
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
@@ -135,6 +196,12 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent:
135196
return APIGatewayProxyEventV2(event)
136197
return ALBEvent(event)
137198

199+
@staticmethod
200+
def _preflight(allowed_methods: Set):
201+
allowed_methods.add("OPTIONS")
202+
headers = {"Access-Control-Allow-Methods": ",".join(sorted(allowed_methods))}
203+
return Response(204, None, None, headers)
204+
138205
def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]:
139206
for route in self._routes:
140207
if method != route.method:
@@ -143,6 +210,13 @@ def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]:
143210
if match:
144211
return route, match.groupdict()
145212

213+
if method == "OPTIONS" and self._cors is not None:
214+
# Most be the preflight options call
215+
return (
216+
Route("OPTIONS", None, self._preflight, self._cors, False, None),
217+
{"allowed_methods": self._cors_methods},
218+
)
219+
146220
raise ValueError(f"No route found for '{method}.{path}'")
147221

148222
def __call__(self, event, context) -> Any:

tests/functional/event_handler/test_api_gateway.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pytest
99

10-
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response
10+
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response
1111
from aws_lambda_powertools.shared.json_encoder import Encoder
1212
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
1313
from tests.functional.utils import load_event
@@ -187,8 +187,10 @@ def handler(event, context):
187187
headers = result["headers"]
188188
assert headers["Content-Type"] == TEXT_HTML
189189
assert headers["Access-Control-Allow-Origin"] == "*"
190-
assert headers["Access-Control-Allow-Methods"] == "GET"
191190
assert headers["Access-Control-Allow-Credentials"] == "true"
191+
# AND "Access-Control-Allow-Methods" is only included in the preflight cors headers
192+
assert "Access-Control-Allow-Methods" not in headers
193+
assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS))
192194

193195

194196
def test_compress():
@@ -338,3 +340,84 @@ def rest_func() -> Response:
338340
assert result["headers"]["Content-Type"] == "header-content-type-wins"
339341
assert result["headers"]["custom"] == "value"
340342
assert result["body"] == "Not found"
343+
344+
345+
def test_preflight_cors():
346+
# GIVEN
347+
app = ApiGatewayResolver()
348+
preflight_event = {"path": "/cors", "httpMethod": "OPTIONS"}
349+
350+
@app.get("/cors", cors=True)
351+
def get_with_cors():
352+
...
353+
354+
@app.post("/cors", cors=True)
355+
def post_with_cors():
356+
...
357+
358+
@app.delete("/cors")
359+
def delete_no_cors():
360+
...
361+
362+
def handler(event, context):
363+
return app.resolve(event, context)
364+
365+
# WHEN calling the event handler
366+
# AND the httpMethod is OPTIONS
367+
result = handler(preflight_event, None)
368+
369+
# THEN return the preflight response
370+
# AND No Content it returned
371+
assert result["statusCode"] == 204
372+
assert "body" not in result
373+
assert "isBase64Encoded" not in result
374+
# AND no Content-Type is set
375+
headers = result["headers"]
376+
assert "headers" in result
377+
assert "Content-Type" not in headers
378+
# AND set the access control headers
379+
assert headers["Access-Control-Allow-Origin"] == "*"
380+
assert headers["Access-Control-Allow-Methods"] == "GET,OPTIONS,POST"
381+
assert headers["Access-Control-Allow-Credentials"] == "true"
382+
383+
384+
def test_custom_cors_config():
385+
# GIVEN a custom cors configuration
386+
app = ApiGatewayResolver()
387+
event = {"path": "/cors", "httpMethod": "GET"}
388+
allow_header = ["foo2"]
389+
cors_config = CORSConfig(
390+
allow_origin="https://foo1",
391+
expose_headers=["foo1"],
392+
allow_headers=allow_header,
393+
max_age=100,
394+
allow_credentials=False,
395+
)
396+
397+
@app.get("/cors", cors=cors_config)
398+
def get_with_cors():
399+
return {}
400+
401+
# NOTE: Currently only the first configuration is used for the OPTIONS preflight
402+
@app.get("/another-one", cors=True)
403+
def another_one():
404+
return {}
405+
406+
# WHEN calling the event handler
407+
result = app(event, None)
408+
409+
# THEN return the custom cors headers
410+
assert "headers" in result
411+
headers = result["headers"]
412+
assert headers["Content-Type"] == APPLICATION_JSON
413+
assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin
414+
expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))
415+
assert headers["Access-Control-Allow-Headers"] == expected_allows_headers
416+
assert headers["Access-Control-Expose-Headers"] == ",".join(cors_config.expose_headers)
417+
assert headers["Access-Control-Max-Age"] == str(cors_config.max_age)
418+
assert "Access-Control-Allow-Credentials" not in headers
419+
420+
# AND custom cors was set on the app
421+
assert isinstance(app._cors, CORSConfig)
422+
assert app._cors is cors_config
423+
assert app._cors_methods == {"GET"}

0 commit comments

Comments
 (0)