Skip to content

Commit daaf137

Browse files
author
Michael Brewer
committed
feat(event-handler): apigwy cache_control option
1 parent 306ee73 commit daaf137

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ class ProxyEventType(Enum):
1717

1818

1919
class RouteEntry:
20-
def __init__(self, method: str, rule: Any, func: Callable, cors: bool, compress: bool):
20+
def __init__(
21+
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
22+
):
2123
self.method = method.upper()
2224
self.rule = rule
2325
self.func = func
2426
self.cors = cors
2527
self.compress = compress
28+
self.cache_control = cache_control
2629

2730

2831
class ApiGatewayResolver:
@@ -33,21 +36,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1):
3336
self._proxy_type = proxy_type
3437
self._routes: List[RouteEntry] = []
3538

36-
def get(self, rule: str, cors: bool = False, compress: bool = False):
37-
return self.route(rule, "GET", cors, compress)
39+
def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
40+
return self.route(rule, "GET", cors, compress, cache_control)
3841

39-
def post(self, rule: str, cors: bool = False, compress: bool = False):
40-
return self.route(rule, "POST", cors, compress)
42+
def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
43+
return self.route(rule, "POST", cors, compress, cache_control)
4144

42-
def put(self, rule: str, cors: bool = False, compress: bool = False):
43-
return self.route(rule, "PUT", cors, compress)
45+
def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
46+
return self.route(rule, "PUT", cors, compress, cache_control)
4447

45-
def delete(self, rule: str, cors: bool = False, compress: bool = False):
46-
return self.route(rule, "DELETE", cors, compress)
48+
def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
49+
return self.route(rule, "DELETE", cors, compress, cache_control)
4750

48-
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False):
51+
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None):
4952
def register_resolver(func: Callable):
50-
self._register(func, rule, method, cors, compress)
53+
self._append(func, rule, method, cors, compress, cache_control)
5154
return func
5255

5356
return register_resolver
@@ -58,31 +61,34 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict:
5861

5962
route, args = self._find_route(self.current_event.http_method, self.current_event.path)
6063
result = route.func(**args)
64+
65+
status: int = result[0]
66+
response: Dict[str, Any] = {"statusCode": status}
67+
6168
headers = {"Content-Type": result[1]}
6269
if route.cors:
6370
headers["Access-Control-Allow-Origin"] = "*"
6471
headers["Access-Control-Allow-Methods"] = route.method
6572
headers["Access-Control-Allow-Credentials"] = "true"
73+
if route.cache_control:
74+
headers["Cache-Control"] = route.cache_control if status == 200 else "no-cache"
75+
response["headers"] = headers
6676

6777
body: Union[str, bytes] = result[2]
6878
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
6979
gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
7080
if isinstance(body, str):
7181
body = bytes(body, "utf-8")
7282
body = gzip_compress.compress(body) + gzip_compress.flush()
73-
74-
response = {"statusCode": result[0], "headers": headers}
75-
7683
if isinstance(body, bytes):
7784
response["isBase64Encoded"] = True
7885
body = base64.b64encode(body).decode()
79-
8086
response["body"] = body
8187

8288
return response
8389

84-
def _register(self, func: Callable, rule: str, method: str, cors: bool, compress: bool):
85-
self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress))
90+
def _append(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]):
91+
self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress, cache_control))
8692

8793
@staticmethod
8894
def _build_rule_pattern(rule: str):

tests/functional/event_handler/test_api_gateway.py

+32
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,35 @@ def handler(event, context):
154154
assert isinstance(body, str)
155155
decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8")
156156
assert decompress == expected_value
157+
158+
159+
def test_cache_control_200():
160+
app = ApiGatewayResolver()
161+
162+
@app.get("/success", cache_control="max-age=600")
163+
def with_cache_control():
164+
return 200, "text/html", "has 200 response"
165+
166+
def handler(event, context):
167+
return app.resolve(event, context)
168+
169+
result = handler({"path": "/success", "httpMethod": "GET"}, None)
170+
171+
headers = result["headers"]
172+
assert headers["Cache-Control"] == "max-age=600"
173+
174+
175+
def test_cache_control_non_200():
176+
app = ApiGatewayResolver()
177+
178+
@app.delete("/fails", cache_control="max-age=600")
179+
def with_cache_control_has_500():
180+
return 503, "text/html", "has 503 response"
181+
182+
def handler(event, context):
183+
return app.resolve(event, context)
184+
185+
result = handler({"path": "/fails", "httpMethod": "DELETE"}, None)
186+
187+
headers = result["headers"]
188+
assert headers["Cache-Control"] == "no-cache"

0 commit comments

Comments
 (0)