Skip to content

Commit 74cd99d

Browse files
feat(event_handler): support to enable or disable compression in custom responses (#2544)
* feature: adding Response compress parameter * feature: addressing Heitor's feedback * feature: addressing Heitor's feedback * refactor(event_handler): make _has_compression_enabled standalone --------- Co-authored-by: Heitor Lessa <[email protected]>
1 parent a9df77c commit 74cd99d

File tree

6 files changed

+133
-4
lines changed

6 files changed

+133
-4
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def __init__(
177177
body: Union[str, bytes, None] = None,
178178
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
179179
cookies: Optional[List[Cookie]] = None,
180+
compress: Optional[bool] = None,
180181
):
181182
"""
182183
@@ -199,6 +200,7 @@ def __init__(
199200
self.base64_encoded = False
200201
self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {}
201202
self.cookies = cookies or []
203+
self.compress = compress
202204
if content_type:
203205
self.headers.setdefault("Content-Type", content_type)
204206

@@ -233,6 +235,38 @@ def _add_cache_control(self, cache_control: str):
233235
cache_control = cache_control if self.response.status_code == 200 else "no-cache"
234236
self.response.headers["Cache-Control"] = cache_control
235237

238+
@staticmethod
239+
def _has_compression_enabled(
240+
route_compression: bool, response_compression: Optional[bool], event: BaseProxyEvent
241+
) -> bool:
242+
"""
243+
Checks if compression is enabled.
244+
245+
NOTE: Response compression takes precedence.
246+
247+
Parameters
248+
----------
249+
route_compression: bool, optional
250+
A boolean indicating whether compression is enabled or not in the route setting.
251+
response_compression: bool, optional
252+
A boolean indicating whether compression is enabled or not in the response setting.
253+
event: BaseProxyEvent
254+
The event object containing the request details.
255+
256+
Returns
257+
-------
258+
bool
259+
True if compression is enabled and the "gzip" encoding is accepted, False otherwise.
260+
"""
261+
encoding: str = event.get_header_value(name="accept-encoding", default_value="", case_sensitive=False) # type: ignore[assignment] # noqa: E501
262+
if "gzip" in encoding:
263+
if response_compression is not None:
264+
return response_compression # e.g., Response(compress=False/True))
265+
if route_compression:
266+
return True # e.g., @app.get(compress=True)
267+
268+
return False
269+
236270
def _compress(self):
237271
"""Compress the response body, but only if `Accept-Encoding` headers includes gzip."""
238272
self.response.headers["Content-Encoding"] = "gzip"
@@ -250,7 +284,9 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
250284
self._add_cors(event, cors or CORSConfig())
251285
if self.route.cache_control:
252286
self._add_cache_control(self.route.cache_control)
253-
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
287+
if self._has_compression_enabled(
288+
route_compression=self.route.compress, response_compression=self.response.compress, event=event
289+
):
254290
self._compress()
255291

256292
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:

aws_lambda_powertools/utilities/data_classes/common.py

+1
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
154154
query_string_parameters=self.query_string_parameters, name=name, default_value=default_value
155155
)
156156

157+
# Maintenance: missing @overload to ensure return type is a str when default_value is set
157158
def get_header_value(
158159
self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False
159160
) -> Optional[str]:

docs/core/event_handler/api_gateway.md

+12-3
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,24 @@ You can use the `Response` class to have full control over the response. For exa
360360

361361
### Compress
362362

363-
You can compress with gzip and base64 encode your responses via `compress` parameter.
363+
You can compress with gzip and base64 encode your responses via `compress` parameter. You have the option to pass the `compress` parameter when working with a specific route or using the Response object.
364+
365+
???+ info
366+
The `compress` parameter used in the Response object takes precedence over the one used in the route.
364367

365368
???+ warning
366369
The client must send the `Accept-Encoding` header, otherwise a normal response will be sent.
367370

368-
=== "compressing_responses.py"
371+
=== "compressing_responses_using_route.py"
369372

370373
```python hl_lines="17 27"
371-
--8<-- "examples/event_handler_rest/src/compressing_responses.py"
374+
--8<-- "examples/event_handler_rest/src/compressing_responses_using_route.py"
375+
```
376+
377+
=== "compressing_responses_using_response.py"
378+
379+
```python hl_lines="24"
380+
--8<-- "examples/event_handler_rest/src/compressing_responses_using_response.py"
372381
```
373382

374383
=== "compressing_responses.json"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import requests
2+
3+
from aws_lambda_powertools import Logger, Tracer
4+
from aws_lambda_powertools.event_handler import (
5+
APIGatewayRestResolver,
6+
Response,
7+
content_types,
8+
)
9+
from aws_lambda_powertools.logging import correlation_paths
10+
from aws_lambda_powertools.utilities.typing import LambdaContext
11+
12+
tracer = Tracer()
13+
logger = Logger()
14+
app = APIGatewayRestResolver()
15+
16+
17+
@app.get("/todos")
18+
@tracer.capture_method
19+
def get_todos():
20+
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
21+
todos.raise_for_status()
22+
23+
# for brevity, we'll limit to the first 10 only
24+
return Response(status_code=200, content_type=content_types.APPLICATION_JSON, body=todos.json()[:10], compress=True)
25+
26+
27+
# You can continue to use other utilities just as before
28+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST)
29+
@tracer.capture_lambda_handler
30+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
31+
return app.resolve(event, context)

tests/functional/event_handler/test_api_gateway.py

+52
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,58 @@ def test_cors_preflight_body_is_empty_not_null():
366366
assert result["body"] == ""
367367

368368

369+
def test_override_route_compress_parameter():
370+
# GIVEN a function that has compress=True
371+
# AND an event with a "Accept-Encoding" that include gzip
372+
# AND the Response object with compress=False
373+
app = ApiGatewayResolver()
374+
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
375+
expected_value = '{"test": "value"}'
376+
377+
@app.get("/my/request", compress=True)
378+
def with_compression() -> Response:
379+
return Response(200, content_types.APPLICATION_JSON, expected_value, compress=False)
380+
381+
def handler(event, context):
382+
return app.resolve(event, context)
383+
384+
# WHEN calling the event handler
385+
result = handler(mock_event, None)
386+
387+
# THEN then the response is not compressed
388+
assert result["isBase64Encoded"] is False
389+
assert result["body"] == expected_value
390+
assert result["multiValueHeaders"].get("Content-Encoding") is None
391+
392+
393+
def test_response_with_compress_enabled():
394+
# GIVEN a function
395+
# AND an event with a "Accept-Encoding" that include gzip
396+
# AND the Response object with compress=True
397+
app = ApiGatewayResolver()
398+
mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}}
399+
expected_value = '{"test": "value"}'
400+
401+
@app.get("/my/request")
402+
def route_without_compression() -> Response:
403+
return Response(200, content_types.APPLICATION_JSON, expected_value, compress=True)
404+
405+
def handler(event, context):
406+
return app.resolve(event, context)
407+
408+
# WHEN calling the event handler
409+
result = handler(mock_event, None)
410+
411+
# THEN then gzip the response and base64 encode as a string
412+
assert result["isBase64Encoded"] is True
413+
body = result["body"]
414+
assert isinstance(body, str)
415+
decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8")
416+
assert decompress == expected_value
417+
headers = result["multiValueHeaders"]
418+
assert headers["Content-Encoding"] == ["gzip"]
419+
420+
369421
def test_compress():
370422
# GIVEN a function that has compress=True
371423
# AND an event with a "Accept-Encoding" that include gzip

0 commit comments

Comments
 (0)