Skip to content

Commit 306ee73

Browse files
author
Michael Brewer
committed
feat(event-handler): apigw compress and base64encode
1 parent 1d6ea4d commit 306ee73

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+37-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import base64
12
import re
3+
import zlib
24
from enum import Enum
3-
from typing import Any, Callable, Dict, List, Optional, Tuple
5+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
46

57
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
68
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
@@ -15,11 +17,12 @@ class ProxyEventType(Enum):
1517

1618

1719
class RouteEntry:
18-
def __init__(self, method: str, rule: Any, func: Callable, cors: bool):
20+
def __init__(self, method: str, rule: Any, func: Callable, cors: bool, compress: bool):
1921
self.method = method.upper()
2022
self.rule = rule
2123
self.func = func
2224
self.cors = cors
25+
self.compress = compress
2326

2427

2528
class ApiGatewayResolver:
@@ -30,21 +33,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1):
3033
self._proxy_type = proxy_type
3134
self._routes: List[RouteEntry] = []
3235

33-
def get(self, rule: str, cors: bool = False):
34-
return self.route(rule, "GET", cors)
36+
def get(self, rule: str, cors: bool = False, compress: bool = False):
37+
return self.route(rule, "GET", cors, compress)
3538

36-
def post(self, rule: str, cors: bool = False):
37-
return self.route(rule, "POST", cors)
39+
def post(self, rule: str, cors: bool = False, compress: bool = False):
40+
return self.route(rule, "POST", cors, compress)
3841

39-
def put(self, rule: str, cors: bool = False):
40-
return self.route(rule, "PUT", cors)
42+
def put(self, rule: str, cors: bool = False, compress: bool = False):
43+
return self.route(rule, "PUT", cors, compress)
4144

42-
def delete(self, rule: str, cors: bool = False):
43-
return self.route(rule, "DELETE", cors)
45+
def delete(self, rule: str, cors: bool = False, compress: bool = False):
46+
return self.route(rule, "DELETE", cors, compress)
4447

45-
def route(self, rule: str, method: str, cors: bool = False):
48+
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False):
4649
def register_resolver(func: Callable):
47-
self._register(func, rule, method, cors)
50+
self._register(func, rule, method, cors, compress)
4851
return func
4952

5053
return register_resolver
@@ -54,19 +57,32 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict:
5457
self.lambda_context = context
5558

5659
route, args = self._find_route(self.current_event.http_method, self.current_event.path)
57-
5860
result = route.func(**args)
59-
6061
headers = {"Content-Type": result[1]}
6162
if route.cors:
6263
headers["Access-Control-Allow-Origin"] = "*"
6364
headers["Access-Control-Allow-Methods"] = route.method
6465
headers["Access-Control-Allow-Credentials"] = "true"
6566

66-
return {"statusCode": result[0], "headers": headers, "body": result[2]}
67+
body: Union[str, bytes] = result[2]
68+
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
69+
gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
70+
if isinstance(body, str):
71+
body = bytes(body, "utf-8")
72+
body = gzip_compress.compress(body) + gzip_compress.flush()
73+
74+
response = {"statusCode": result[0], "headers": headers}
75+
76+
if isinstance(body, bytes):
77+
response["isBase64Encoded"] = True
78+
body = base64.b64encode(body).decode()
79+
80+
response["body"] = body
81+
82+
return response
6783

68-
def _register(self, func: Callable, rule: str, method: str, cors: bool):
69-
self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors))
84+
def _register(self, func: Callable, rule: str, method: str, cors: bool, compress: bool):
85+
self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress))
7086

7187
@staticmethod
7288
def _build_rule_pattern(rule: str):
@@ -82,12 +98,12 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent:
8298

8399
def _find_route(self, method: str, path: str) -> Tuple[RouteEntry, Dict]:
84100
method = method.upper()
85-
for resolver in self._routes:
86-
if method != resolver.method:
101+
for route in self._routes:
102+
if method != route.method:
87103
continue
88-
match: Optional[re.Match] = resolver.rule.match(path)
104+
match: Optional[re.Match] = route.rule.match(path)
89105
if match:
90-
return resolver, match.groupdict()
106+
return route, match.groupdict()
91107

92108
raise ValueError(f"No route found for '{method}.{path}'")
93109

tests/functional/event_handler/test_api_gateway.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import base64
12
import json
23
import os
4+
import zlib
35

46
import pytest
57

@@ -130,3 +132,25 @@ def handler(event, context):
130132
assert headers["Access-Control-Allow-Origin"] == "*"
131133
assert headers["Access-Control-Allow-Methods"] == "GET"
132134
assert headers["Access-Control-Allow-Credentials"] == "true"
135+
136+
137+
def test_compress():
138+
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
139+
expected_value = '{"test": "value"}'
140+
141+
app = ApiGatewayResolver()
142+
143+
@app.get("/my/request", compress=True)
144+
def with_compression():
145+
return 200, "application/json", expected_value
146+
147+
def handler(event, context):
148+
return app.resolve(event, context)
149+
150+
result = handler(mock_event, None)
151+
152+
assert result["isBase64Encoded"] is True
153+
body = result["body"]
154+
assert isinstance(body, str)
155+
decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8")
156+
assert decompress == expected_value

0 commit comments

Comments
 (0)