Skip to content

Commit e7e8d59

Browse files
committed
feat(event-handler): Add Response class
This will allow for fine grained control of the returning headers
1 parent b5a057b commit e7e8d59

File tree

2 files changed

+117
-37
lines changed

2 files changed

+117
-37
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+46-24
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,42 @@ def __init__(
2929
self.cache_control = cache_control
3030

3131

32+
class Response:
33+
def __init__(self, status_code: int, content_type: str, body: Union[str, bytes], headers: Dict = None):
34+
self.status_code = status_code
35+
self.body = body
36+
self.base64_encoded = False
37+
self.headers: Dict = headers if headers is not None else {}
38+
if "Content-Type" not in self.headers:
39+
self.headers["Content-Type"] = content_type
40+
41+
def add_cors(self, method: str):
42+
self.headers["Access-Control-Allow-Origin"] = "*"
43+
self.headers["Access-Control-Allow-Methods"] = method
44+
self.headers["Access-Control-Allow-Credentials"] = "true"
45+
46+
def add_cache_control(self, cache_control: str):
47+
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
48+
49+
def compress(self):
50+
self.headers["Content-Encoding"] = "gzip"
51+
if isinstance(self.body, str):
52+
self.body = bytes(self.body, "utf-8")
53+
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
54+
self.body = gzip.compress(self.body) + gzip.flush()
55+
56+
def to_dict(self):
57+
if isinstance(self.body, bytes):
58+
self.base64_encoded = True
59+
self.body = base64.b64encode(self.body).decode()
60+
return {
61+
"statusCode": self.status_code,
62+
"headers": self.headers,
63+
"body": self.body,
64+
"isBase64Encoded": self.base64_encoded,
65+
}
66+
67+
3268
class ApiGatewayResolver:
3369
current_event: BaseProxyEvent
3470
lambda_context: LambdaContext
@@ -65,35 +101,21 @@ def resolve(self, event, context) -> Dict[str, Any]:
65101
route, args = self._find_route(self.current_event.http_method, self.current_event.path)
66102
result = route.func(**args)
67103

68-
if isinstance(result, dict):
69-
status_code = 200
70-
content_type = "application/json"
71-
body: Union[str, bytes] = json.dumps(result)
104+
if isinstance(result, Response):
105+
response = result
106+
elif isinstance(result, dict):
107+
response = Response(status_code=200, content_type="application/json", body=json.dumps(result))
72108
else:
73-
status_code, content_type, body = result
74-
headers = {"Content-Type": content_type}
109+
response = Response(*result)
75110

76111
if route.cors:
77-
headers["Access-Control-Allow-Origin"] = "*"
78-
headers["Access-Control-Allow-Methods"] = route.method
79-
headers["Access-Control-Allow-Credentials"] = "true"
80-
112+
response.add_cors(route.method)
81113
if route.cache_control:
82-
headers["Cache-Control"] = route.cache_control if status_code == 200 else "no-cache"
83-
114+
response.add_cache_control(route.cache_control)
84115
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
85-
headers["Content-Encoding"] = "gzip"
86-
if isinstance(body, str):
87-
body = bytes(body, "utf-8")
88-
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
89-
body = gzip.compress(body) + gzip.flush()
90-
91-
base64_encoded = False
92-
if isinstance(body, bytes):
93-
base64_encoded = True
94-
body = base64.b64encode(body).decode()
95-
96-
return {"statusCode": status_code, "headers": headers, "body": body, "isBase64Encoded": base64_encoded}
116+
response.compress()
117+
118+
return response.to_dict()
97119

98120
@staticmethod
99121
def _build_rule_pattern(rule: str):

tests/functional/event_handler/test_api_gateway.py

+71-13
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import json
33
import zlib
44
from pathlib import Path
5+
from typing import Dict, Tuple
56

67
import pytest
78

8-
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType
9+
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response
910
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
1011

1112

@@ -25,77 +26,96 @@ def read_media(file_name: str) -> bytes:
2526

2627

2728
def test_alb_event():
29+
# GIVEN a Application Load Balancer proxy type event
2830
app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event)
2931

3032
@app.get("/lambda")
31-
def foo():
33+
def foo() -> Tuple[int, str, str]:
3234
assert isinstance(app.current_event, ALBEvent)
3335
assert app.lambda_context == {}
3436
return 200, TEXT_HTML, "foo"
3537

38+
# WHEN
3639
result = app(load_event("albEvent.json"), {})
3740

41+
# THEN process event correctly
42+
# AND set the current_event type as ALBEvent
3843
assert result["statusCode"] == 200
3944
assert result["headers"]["Content-Type"] == TEXT_HTML
4045
assert result["body"] == "foo"
4146

4247

4348
def test_api_gateway_v1():
49+
# GIVEN a Http API V1 proxy type event
4450
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
4551

4652
@app.get("/my/path")
47-
def get_lambda():
53+
def get_lambda() -> Tuple[int, str, str]:
4854
assert isinstance(app.current_event, APIGatewayProxyEvent)
4955
assert app.lambda_context == {}
5056
return 200, APPLICATION_JSON, json.dumps({"foo": "value"})
5157

58+
# WHEN
5259
result = app(LOAD_GW_EVENT, {})
5360

61+
# THEN process event correctly
62+
# AND set the current_event type as APIGatewayProxyEvent
5463
assert result["statusCode"] == 200
5564
assert result["headers"]["Content-Type"] == APPLICATION_JSON
5665

5766

5867
def test_api_gateway():
68+
# GIVEN a Rest API Gateway proxy type event
5969
app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway)
6070

6171
@app.get("/my/path")
62-
def get_lambda():
72+
def get_lambda() -> Tuple[int, str, str]:
6373
assert isinstance(app.current_event, APIGatewayProxyEvent)
6474
return 200, TEXT_HTML, "foo"
6575

76+
# WHEN
6677
result = app(LOAD_GW_EVENT, {})
6778

79+
# THEN process event correctly
80+
# AND set the current_event type as APIGatewayProxyEvent
6881
assert result["statusCode"] == 200
6982
assert result["headers"]["Content-Type"] == TEXT_HTML
7083
assert result["body"] == "foo"
7184

7285

7386
def test_api_gateway_v2():
87+
# GIVEN a Http API V2 proxy type event
7488
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2)
7589

7690
@app.post("/my/path")
77-
def my_path():
91+
def my_path() -> Tuple[int, str, str]:
7892
assert isinstance(app.current_event, APIGatewayProxyEventV2)
7993
post_data = app.current_event.json_body
8094
return 200, "plain/text", post_data["username"]
8195

96+
# WHEN
8297
result = app(load_event("apiGatewayProxyV2Event.json"), {})
8398

99+
# THEN process event correctly
100+
# AND set the current_event type as APIGatewayProxyEventV2
84101
assert result["statusCode"] == 200
85102
assert result["headers"]["Content-Type"] == "plain/text"
86103
assert result["body"] == "tom"
87104

88105

89106
def test_include_rule_matching():
107+
# GIVEN
90108
app = ApiGatewayResolver()
91109

92110
@app.get("/<name>/<my_id>")
93-
def get_lambda(my_id: str, name: str):
111+
def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]:
94112
assert name == "my"
95113
return 200, "plain/html", my_id
96114

115+
# WHEN
97116
result = app(LOAD_GW_EVENT, {})
98117

118+
# THEN
99119
assert result["statusCode"] == 200
100120
assert result["headers"]["Content-Type"] == "plain/html"
101121
assert result["body"] == "path"
@@ -153,7 +173,7 @@ def test_cors():
153173
app = ApiGatewayResolver()
154174

155175
@app.get("/my/path", cors=True)
156-
def with_cors():
176+
def with_cors() -> Tuple[int, str, str]:
157177
return 200, TEXT_HTML, "test"
158178

159179
def handler(event, context):
@@ -176,7 +196,7 @@ def test_compress():
176196
app = ApiGatewayResolver()
177197

178198
@app.get("/my/request", compress=True)
179-
def with_compression():
199+
def with_compression() -> Tuple[int, str, str]:
180200
return 200, APPLICATION_JSON, expected_value
181201

182202
def handler(event, context):
@@ -197,7 +217,7 @@ def test_base64_encode():
197217
app = ApiGatewayResolver()
198218

199219
@app.get("/my/path", compress=True)
200-
def read_image():
220+
def read_image() -> Tuple[int, str, bytes]:
201221
return 200, "image/png", read_media("idempotent_sequence_exception.png")
202222

203223
mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
@@ -211,62 +231,100 @@ def read_image():
211231

212232

213233
def test_compress_no_accept_encoding():
234+
# GIVEN a function with compress=True
235+
# AND the request has no "Accept-Encoding" set to include gzip
214236
app = ApiGatewayResolver()
215237
expected_value = "Foo"
216238

217239
@app.get("/my/path", compress=True)
218-
def return_text():
240+
def return_text() -> Tuple[int, str, str]:
219241
return 200, "text/plain", expected_value
220242

243+
# WHEN calling the event handler
221244
result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)
222245

246+
# THEN don't perform any gzip compression
223247
assert result["isBase64Encoded"] is False
224248
assert result["body"] == expected_value
225249

226250

227251
def test_cache_control_200():
252+
# GIVEN a function with cache_control set
228253
app = ApiGatewayResolver()
229254

230255
@app.get("/success", cache_control="max-age=600")
231-
def with_cache_control():
256+
def with_cache_control() -> Tuple[int, str, str]:
232257
return 200, TEXT_HTML, "has 200 response"
233258

234259
def handler(event, context):
235260
return app.resolve(event, context)
236261

262+
# WHEN calling the event handler
263+
# AND the function returns a 200 status code
237264
result = handler({"path": "/success", "httpMethod": "GET"}, None)
238265

266+
# THEN return the set Cache-Control
239267
headers = result["headers"]
240268
assert headers["Content-Type"] == TEXT_HTML
241269
assert headers["Cache-Control"] == "max-age=600"
242270

243271

244272
def test_cache_control_non_200():
273+
# GIVEN a function with cache_control set
245274
app = ApiGatewayResolver()
246275

247276
@app.delete("/fails", cache_control="max-age=600")
248-
def with_cache_control_has_500():
277+
def with_cache_control_has_500() -> Tuple[int, str, str]:
249278
return 503, TEXT_HTML, "has 503 response"
250279

251280
def handler(event, context):
252281
return app.resolve(event, context)
253282

283+
# WHEN calling the event handler
284+
# AND the function returns a 503 status code
254285
result = handler({"path": "/fails", "httpMethod": "DELETE"}, None)
255286

287+
# THEN return a Cache-Control of "no-cache"
256288
headers = result["headers"]
257289
assert headers["Content-Type"] == TEXT_HTML
258290
assert headers["Cache-Control"] == "no-cache"
259291

260292

261293
def test_rest_api():
294+
# GIVEN a function that returns a Dict
262295
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
263296

264297
@app.get("/my/path")
265-
def rest_func():
298+
def rest_func() -> Dict:
266299
return {"foo": "value"}
267300

301+
# WHEN calling the event handler
268302
result = app(LOAD_GW_EVENT, {})
269303

304+
# THEN automatically process this as a json rest api response
270305
assert result["statusCode"] == 200
271306
assert result["headers"]["Content-Type"] == APPLICATION_JSON
272307
assert result["body"] == json.dumps({"foo": "value"})
308+
309+
310+
def test_handling_response_type():
311+
# GIVEN a function that returns Response
312+
app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1)
313+
314+
@app.get("/my/path")
315+
def rest_func() -> Response:
316+
return Response(
317+
status_code=404,
318+
content_type="used-if-not-set-in-header",
319+
body="Not found",
320+
headers={"Content-Type": "header-content-type-wins", "custom": "value"},
321+
)
322+
323+
# WHEN calling the event handler
324+
result = app(LOAD_GW_EVENT, {})
325+
326+
# THEN the result can include some additional field control like overriding http headers
327+
assert result["statusCode"] == 404
328+
assert result["headers"]["Content-Type"] == "header-content-type-wins"
329+
assert result["headers"]["custom"] == "value"
330+
assert result["body"] == "Not found"

0 commit comments

Comments
 (0)