1
+ import base64
1
2
import re
3
+ import zlib
2
4
from enum import Enum
3
- from typing import Any , Callable , Dict , List , Optional , Tuple
5
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4
6
5
7
from aws_lambda_powertools .utilities .data_classes import ALBEvent , APIGatewayProxyEvent , APIGatewayProxyEventV2
6
8
from aws_lambda_powertools .utilities .data_classes .common import BaseProxyEvent
@@ -15,11 +17,12 @@ class ProxyEventType(Enum):
15
17
16
18
17
19
class RouteEntry :
18
- def __init__ (self , method : str , rule : Any , func : Callable , cors : bool ):
20
+ def __init__ (self , method : str , rule : Any , func : Callable , cors : bool , compress : bool ):
19
21
self .method = method .upper ()
20
22
self .rule = rule
21
23
self .func = func
22
24
self .cors = cors
25
+ self .compress = compress
23
26
24
27
25
28
class ApiGatewayResolver :
@@ -30,21 +33,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1):
30
33
self ._proxy_type = proxy_type
31
34
self ._routes : List [RouteEntry ] = []
32
35
33
- def get (self , rule : str , cors : bool = False ):
34
- return self .route (rule , "GET" , cors )
36
+ def get (self , rule : str , cors : bool = False , compress : bool = False ):
37
+ return self .route (rule , "GET" , cors , compress )
35
38
36
- def post (self , rule : str , cors : bool = False ):
37
- return self .route (rule , "POST" , cors )
39
+ def post (self , rule : str , cors : bool = False , compress : bool = False ):
40
+ return self .route (rule , "POST" , cors , compress )
38
41
39
- def put (self , rule : str , cors : bool = False ):
40
- return self .route (rule , "PUT" , cors )
42
+ def put (self , rule : str , cors : bool = False , compress : bool = False ):
43
+ return self .route (rule , "PUT" , cors , compress )
41
44
42
- def delete (self , rule : str , cors : bool = False ):
43
- return self .route (rule , "DELETE" , cors )
45
+ def delete (self , rule : str , cors : bool = False , compress : bool = False ):
46
+ return self .route (rule , "DELETE" , cors , compress )
44
47
45
- def route (self , rule : str , method : str , cors : bool = False ):
48
+ def route (self , rule : str , method : str , cors : bool = False , compress : bool = False ):
46
49
def register_resolver (func : Callable ):
47
- self ._register (func , rule , method , cors )
50
+ self ._register (func , rule , method , cors , compress )
48
51
return func
49
52
50
53
return register_resolver
@@ -54,19 +57,32 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict:
54
57
self .lambda_context = context
55
58
56
59
route , args = self ._find_route (self .current_event .http_method , self .current_event .path )
57
-
58
60
result = route .func (** args )
59
-
60
61
headers = {"Content-Type" : result [1 ]}
61
62
if route .cors :
62
63
headers ["Access-Control-Allow-Origin" ] = "*"
63
64
headers ["Access-Control-Allow-Methods" ] = route .method
64
65
headers ["Access-Control-Allow-Credentials" ] = "true"
65
66
66
- return {"statusCode" : result [0 ], "headers" : headers , "body" : result [2 ]}
67
+ body : Union [str , bytes ] = result [2 ]
68
+ if route .compress and "gzip" in (self .current_event .get_header_value ("accept-encoding" ) or "" ):
69
+ gzip_compress = zlib .compressobj (9 , zlib .DEFLATED , zlib .MAX_WBITS | 16 )
70
+ if isinstance (body , str ):
71
+ body = bytes (body , "utf-8" )
72
+ body = gzip_compress .compress (body ) + gzip_compress .flush ()
73
+
74
+ response = {"statusCode" : result [0 ], "headers" : headers }
75
+
76
+ if isinstance (body , bytes ):
77
+ response ["isBase64Encoded" ] = True
78
+ body = base64 .b64encode (body ).decode ()
79
+
80
+ response ["body" ] = body
81
+
82
+ return response
67
83
68
- def _register (self , func : Callable , rule : str , method : str , cors : bool ):
69
- self ._routes .append (RouteEntry (method , self ._build_rule_pattern (rule ), func , cors ))
84
+ def _register (self , func : Callable , rule : str , method : str , cors : bool , compress : bool ):
85
+ self ._routes .append (RouteEntry (method , self ._build_rule_pattern (rule ), func , cors , compress ))
70
86
71
87
@staticmethod
72
88
def _build_rule_pattern (rule : str ):
@@ -82,12 +98,12 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent:
82
98
83
99
def _find_route (self , method : str , path : str ) -> Tuple [RouteEntry , Dict ]:
84
100
method = method .upper ()
85
- for resolver in self ._routes :
86
- if method != resolver .method :
101
+ for route in self ._routes :
102
+ if method != route .method :
87
103
continue
88
- match : Optional [re .Match ] = resolver .rule .match (path )
104
+ match : Optional [re .Match ] = route .rule .match (path )
89
105
if match :
90
- return resolver , match .groupdict ()
106
+ return route , match .groupdict ()
91
107
92
108
raise ValueError (f"No route found for '{ method } .{ path } '" )
93
109
0 commit comments