@@ -17,12 +17,15 @@ class ProxyEventType(Enum):
17
17
18
18
19
19
class RouteEntry :
20
- def __init__ (self , method : str , rule : Any , func : Callable , cors : bool , compress : bool ):
20
+ def __init__ (
21
+ self , method : str , rule : Any , func : Callable , cors : bool , compress : bool , cache_control : Optional [str ]
22
+ ):
21
23
self .method = method .upper ()
22
24
self .rule = rule
23
25
self .func = func
24
26
self .cors = cors
25
27
self .compress = compress
28
+ self .cache_control = cache_control
26
29
27
30
28
31
class ApiGatewayResolver :
@@ -33,21 +36,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1):
33
36
self ._proxy_type = proxy_type
34
37
self ._routes : List [RouteEntry ] = []
35
38
36
- def get (self , rule : str , cors : bool = False , compress : bool = False ):
37
- return self .route (rule , "GET" , cors , compress )
39
+ def get (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
40
+ return self .route (rule , "GET" , cors , compress , cache_control )
38
41
39
- def post (self , rule : str , cors : bool = False , compress : bool = False ):
40
- return self .route (rule , "POST" , cors , compress )
42
+ def post (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
43
+ return self .route (rule , "POST" , cors , compress , cache_control )
41
44
42
- def put (self , rule : str , cors : bool = False , compress : bool = False ):
43
- return self .route (rule , "PUT" , cors , compress )
45
+ def put (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
46
+ return self .route (rule , "PUT" , cors , compress , cache_control )
44
47
45
- def delete (self , rule : str , cors : bool = False , compress : bool = False ):
46
- return self .route (rule , "DELETE" , cors , compress )
48
+ def delete (self , rule : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
49
+ return self .route (rule , "DELETE" , cors , compress , cache_control )
47
50
48
- def route (self , rule : str , method : str , cors : bool = False , compress : bool = False ):
51
+ def route (self , rule : str , method : str , cors : bool = False , compress : bool = False , cache_control : str = None ):
49
52
def register_resolver (func : Callable ):
50
- self ._register (func , rule , method , cors , compress )
53
+ self ._append (func , rule , method , cors , compress , cache_control )
51
54
return func
52
55
53
56
return register_resolver
@@ -58,31 +61,34 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict:
58
61
59
62
route , args = self ._find_route (self .current_event .http_method , self .current_event .path )
60
63
result = route .func (** args )
64
+
65
+ status : int = result [0 ]
66
+ response : Dict [str , Any ] = {"statusCode" : status }
67
+
61
68
headers = {"Content-Type" : result [1 ]}
62
69
if route .cors :
63
70
headers ["Access-Control-Allow-Origin" ] = "*"
64
71
headers ["Access-Control-Allow-Methods" ] = route .method
65
72
headers ["Access-Control-Allow-Credentials" ] = "true"
73
+ if route .cache_control :
74
+ headers ["Cache-Control" ] = route .cache_control if status == 200 else "no-cache"
75
+ response ["headers" ] = headers
66
76
67
77
body : Union [str , bytes ] = result [2 ]
68
78
if route .compress and "gzip" in (self .current_event .get_header_value ("accept-encoding" ) or "" ):
69
79
gzip_compress = zlib .compressobj (9 , zlib .DEFLATED , zlib .MAX_WBITS | 16 )
70
80
if isinstance (body , str ):
71
81
body = bytes (body , "utf-8" )
72
82
body = gzip_compress .compress (body ) + gzip_compress .flush ()
73
-
74
- response = {"statusCode" : result [0 ], "headers" : headers }
75
-
76
83
if isinstance (body , bytes ):
77
84
response ["isBase64Encoded" ] = True
78
85
body = base64 .b64encode (body ).decode ()
79
-
80
86
response ["body" ] = body
81
87
82
88
return response
83
89
84
- def _register (self , func : Callable , rule : str , method : str , cors : bool , compress : bool ):
85
- self ._routes .append (RouteEntry (method , self ._build_rule_pattern (rule ), func , cors , compress ))
90
+ def _append (self , func : Callable , rule : str , method : str , cors : bool , compress : bool , cache_control : Optional [ str ] ):
91
+ self ._routes .append (RouteEntry (method , self ._build_rule_pattern (rule ), func , cors , compress , cache_control ))
86
92
87
93
@staticmethod
88
94
def _build_rule_pattern (rule : str ):
0 commit comments