3
3
import re
4
4
import zlib
5
5
from enum import Enum
6
- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
6
+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
7
7
8
8
from aws_lambda_powertools .shared .json_encoder import Encoder
9
9
from aws_lambda_powertools .utilities .data_classes import ALBEvent , APIGatewayProxyEvent , APIGatewayProxyEventV2
@@ -18,30 +18,74 @@ class ProxyEventType(Enum):
18
18
api_gateway = http_api_v1
19
19
20
20
21
+ class CORSConfig (object ):
22
+ _REQUIRED_HEADERS = ["Content-Type" , "X-Amz-Date" , "Authorization" , "X-Api-Key" , "X-Amz-Security-Token" ]
23
+
24
+ def __init__ (
25
+ self ,
26
+ allow_origin : str = "*" ,
27
+ allow_headers : List [str ] = None ,
28
+ expose_headers : List [str ] = None ,
29
+ max_age : int = None ,
30
+ allow_credentials : bool = True ,
31
+ ):
32
+ self .allow_origin = allow_origin
33
+ self .allow_headers = set ((allow_headers or []) + self ._REQUIRED_HEADERS )
34
+ self .expose_headers = expose_headers or []
35
+ self .max_age = max_age
36
+ self .allow_credentials = allow_credentials
37
+
38
+ def to_dict (self ) -> Dict [str , str ]:
39
+ headers = {
40
+ "Access-Control-Allow-Origin" : self .allow_origin ,
41
+ "Access-Control-Allow-Headers" : "," .join (sorted (self .allow_headers )),
42
+ }
43
+ if self .expose_headers :
44
+ headers ["Access-Control-Expose-Headers" ] = "," .join (self .expose_headers )
45
+ if self .max_age is not None :
46
+ headers ["Access-Control-Max-Age" ] = str (self .max_age )
47
+ if self .allow_credentials is True :
48
+ headers ["Access-Control-Allow-Credentials" ] = "true"
49
+ return headers
50
+
51
+
21
52
class Route :
22
53
def __init__ (
23
- self , method : str , rule : Any , func : Callable , cors : bool , compress : bool , cache_control : Optional [str ]
54
+ self ,
55
+ method : str ,
56
+ rule : Any ,
57
+ func : Callable ,
58
+ cors : Union [bool , CORSConfig ],
59
+ compress : bool ,
60
+ cache_control : Optional [str ],
24
61
):
25
62
self .method = method .upper ()
26
63
self .rule = rule
27
64
self .func = func
28
- self .cors = cors
65
+ self .cors : Optional [CORSConfig ]
66
+ if cors is True :
67
+ self .cors = CORSConfig ()
68
+ elif isinstance (cors , CORSConfig ):
69
+ self .cors = cors
70
+ else :
71
+ self .cors = None
29
72
self .compress = compress
30
73
self .cache_control = cache_control
31
74
32
75
33
76
class Response :
34
- def __init__ (self , status_code : int , content_type : str , body : Union [str , bytes ], headers : Dict = None ):
77
+ def __init__ (
78
+ self , status_code : int , content_type : Optional [str ], body : Union [str , bytes , None ], headers : Dict = None
79
+ ):
35
80
self .status_code = status_code
36
81
self .body = body
37
82
self .base64_encoded = False
38
83
self .headers : Dict = headers or {}
39
- self .headers .setdefault ("Content-Type" , content_type )
84
+ if content_type :
85
+ self .headers .setdefault ("Content-Type" , content_type )
40
86
41
- def add_cors (self , method : str ):
42
- self .headers ["Access-Control-Allow-Origin" ] = "*"
43
- self .headers ["Access-Control-Allow-Methods" ] = method
44
- self .headers ["Access-Control-Allow-Credentials" ] = "true"
87
+ def add_cors (self , cors : CORSConfig ):
88
+ self .headers .update (cors .to_dict ())
45
89
46
90
def add_cache_control (self , cache_control : str ):
47
91
self .headers ["Cache-Control" ] = cache_control if self .status_code == 200 else "no-cache"
@@ -54,15 +98,14 @@ def compress(self):
54
98
self .body = gzip .compress (self .body ) + gzip .flush ()
55
99
56
100
def to_dict (self ) -> Dict [str , Any ]:
101
+ result = {"statusCode" : self .status_code , "headers" : self .headers }
57
102
if isinstance (self .body , bytes ):
58
103
self .base64_encoded = True
59
104
self .body = base64 .b64encode (self .body ).decode ()
60
- return {
61
- "statusCode" : self .status_code ,
62
- "headers" : self .headers ,
63
- "body" : self .body ,
64
- "isBase64Encoded" : self .base64_encoded ,
65
- }
105
+ if self .body :
106
+ result ["isBase64Encoded" ] = self .base64_encoded
107
+ result ["body" ] = self .body
108
+ return result
66
109
67
110
68
111
class ApiGatewayResolver :
@@ -72,25 +115,43 @@ class ApiGatewayResolver:
72
115
def __init__ (self , proxy_type : Enum = ProxyEventType .http_api_v1 ):
73
116
self ._proxy_type = proxy_type
74
117
self ._routes : List [Route ] = []
118
+ self ._cors : Optional [CORSConfig ] = None
119
+ self ._cors_methods : Set [str ] = set ()
75
120
76
- def get (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
121
+ def get (self , rule : str , cors : Union [ bool , CORSConfig ] = False , compress : bool = False , cache_control : str = None ):
77
122
return self .route (rule , "GET" , cors , compress , cache_control )
78
123
79
- def post (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
124
+ def post (self , rule : str , cors : Union [ bool , CORSConfig ] = False , compress : bool = False , cache_control : str = None ):
80
125
return self .route (rule , "POST" , cors , compress , cache_control )
81
126
82
- def put (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
127
+ def put (self , rule : str , cors : Union [ bool , CORSConfig ] = False , compress : bool = False , cache_control : str = None ):
83
128
return self .route (rule , "PUT" , cors , compress , cache_control )
84
129
85
- def delete (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
130
+ def delete (
131
+ self , rule : str , cors : Union [bool , CORSConfig ] = False , compress : bool = False , cache_control : str = None
132
+ ):
86
133
return self .route (rule , "DELETE" , cors , compress , cache_control )
87
134
88
- def patch (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
135
+ def patch (
136
+ self , rule : str , cors : Union [bool , CORSConfig ] = False , compress : bool = False , cache_control : str = None
137
+ ):
89
138
return self .route (rule , "PATCH" , cors , compress , cache_control )
90
139
91
- def route (self , rule : str , method : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
140
+ def route (
141
+ self ,
142
+ rule : str ,
143
+ method : str ,
144
+ cors : Union [bool , CORSConfig ] = False ,
145
+ compress : bool = False ,
146
+ cache_control : str = None ,
147
+ ):
92
148
def register_resolver (func : Callable ):
93
- self ._routes .append (Route (method , self ._compile_regex (rule ), func , cors , compress , cache_control ))
149
+ route = Route (method , self ._compile_regex (rule ), func , cors , compress , cache_control )
150
+ self ._routes .append (route )
151
+ if route .cors :
152
+ if self ._cors is None :
153
+ self ._cors = route .cors
154
+ self ._cors_methods .add (route .method )
94
155
return func
95
156
96
157
return register_resolver
@@ -102,7 +163,7 @@ def resolve(self, event, context) -> Dict[str, Any]:
102
163
response = self .to_response (route .func (** args ))
103
164
104
165
if route .cors :
105
- response .add_cors (route .method )
166
+ response .add_cors (route .cors )
106
167
if route .cache_control :
107
168
response .add_cache_control (route .cache_control )
108
169
if route .compress and "gzip" in (self .current_event .get_header_value ("accept-encoding" ) or "" ):
@@ -135,6 +196,12 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent:
135
196
return APIGatewayProxyEventV2 (event )
136
197
return ALBEvent (event )
137
198
199
+ @staticmethod
200
+ def _preflight (allowed_methods : Set ):
201
+ allowed_methods .add ("OPTIONS" )
202
+ headers = {"Access-Control-Allow-Methods" : "," .join (sorted (allowed_methods ))}
203
+ return Response (204 , None , None , headers )
204
+
138
205
def _find_route (self , method : str , path : str ) -> Tuple [Route , Dict ]:
139
206
for route in self ._routes :
140
207
if method != route .method :
@@ -143,6 +210,13 @@ def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]:
143
210
if match :
144
211
return route , match .groupdict ()
145
212
213
+ if method == "OPTIONS" and self ._cors is not None :
214
+ # Most be the preflight options call
215
+ return (
216
+ Route ("OPTIONS" , None , self ._preflight , self ._cors , False , None ),
217
+ {"allowed_methods" : self ._cors_methods },
218
+ )
219
+
146
220
raise ValueError (f"No route found for '{ method } .{ path } '" )
147
221
148
222
def __call__ (self , event , context ) -> Any :
0 commit comments