Skip to content

Commit 97bd70f

Browse files
committed
feat(event_handler): add Cookies first class citizen
1 parent 3b444ae commit 97bd70f

File tree

6 files changed

+135
-29
lines changed

6 files changed

+135
-29
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
16+
from aws_lambda_powertools.event_handler.cookies import Cookie
1617
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
1718
from aws_lambda_powertools.shared import constants
1819
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
@@ -147,7 +148,7 @@ def __init__(
147148
content_type: Optional[str],
148149
body: Union[str, bytes, None],
149150
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
150-
cookies: Optional[List[str]] = None,
151+
cookies: Optional[List[Cookie]] = None,
151152
):
152153
"""
153154
@@ -162,7 +163,7 @@ def __init__(
162163
Optionally set the response body. Note: bytes body will be automatically base64 encoded
163164
headers: dict[str, Union[str, List[str]]]
164165
Optionally set specific http headers. Setting "Content-Type" here would override the `content_type` value.
165-
cookies: list[str]
166+
cookies: list[Cookie]
166167
Optionally set cookies.
167168
"""
168169
self.status_code = status_code
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from datetime import datetime
2+
from enum import Enum
3+
from io import StringIO
4+
from typing import List, Optional
5+
6+
7+
class SameSite(Enum):
8+
DEFAULT_MODE = ""
9+
LAX_MODE = "Lax"
10+
STRICT_MODE = "Strict"
11+
NONE_MODE = "None"
12+
13+
14+
class Cookie:
15+
def __init__(
16+
self,
17+
name: str,
18+
value: str,
19+
path: Optional[str] = None,
20+
domain: Optional[str] = None,
21+
expires: Optional[datetime] = None,
22+
max_age: Optional[int] = None,
23+
secure: Optional[bool] = None,
24+
http_only: Optional[bool] = None,
25+
same_site: Optional[SameSite] = None,
26+
custom_attributes: Optional[List[str]] = None,
27+
):
28+
self.name = name
29+
self.value = value
30+
self.path = path
31+
self.domain = domain
32+
self.expires = expires
33+
self.max_age = max_age
34+
self.secure = secure
35+
self.http_only = http_only
36+
self.same_site = same_site
37+
self.custom_attributes = custom_attributes
38+
39+
def __str__(self) -> str:
40+
payload = StringIO()
41+
payload.write(f"{self.name}=")
42+
43+
# Maintenance(rf): the value needs to be sanitized
44+
payload.write(self.value)
45+
46+
if self.path and len(self.path) > 0:
47+
# Maintenance(rf): the value of path needs to be sanitized
48+
payload.write(f"; Path={self.path}")
49+
50+
if self.domain and len(self.domain) > 0:
51+
payload.write(f"; Domain={self.domain}")
52+
53+
if self.expires:
54+
# Maintenance(rf) this format is wrong
55+
payload.write(f"; Expires={self.expires.strftime('YYYY-MM-dd')}")
56+
57+
if self.max_age:
58+
if self.max_age > 0:
59+
payload.write(f"; MaxAge={self.max_age}")
60+
if self.max_age < 0:
61+
payload.write("; MaxAge=0")
62+
63+
if self.http_only:
64+
payload.write("; HttpOnly")
65+
66+
if self.secure:
67+
payload.write("; Secure")
68+
69+
if self.same_site:
70+
payload.write(f"; SameSite={self.same_site.value}")
71+
72+
if self.custom_attributes:
73+
for attr in self.custom_attributes:
74+
# Maintenance(rf): the value needs to be sanitized
75+
payload.write(f"; {attr}")
76+
77+
return payload.getvalue()

aws_lambda_powertools/shared/headers_serializer.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
from collections import defaultdict
33
from typing import Any, Dict, List, Union
44

5+
from aws_lambda_powertools.event_handler.cookies import Cookie
6+
57

68
class BaseHeadersSerializer:
79
"""
810
Helper class to correctly serialize headers and cookies for Amazon API Gateway,
911
ALB and Lambda Function URL response payload.
1012
"""
1113

12-
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str]) -> Dict[str, Any]:
14+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
1315
"""
1416
Serializes headers and cookies according to the request type.
1517
Returns a dict that can be merged with the response payload.
@@ -25,7 +27,7 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str
2527

2628

2729
class HttpApiHeadersSerializer(BaseHeadersSerializer):
28-
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str]) -> Dict[str, Any]:
30+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
2931
"""
3032
When using HTTP APIs or LambdaFunctionURLs, everything is taken care automatically for us.
3133
We can directly assign a list of cookies and a dict of headers to the response payload, and the
@@ -44,11 +46,11 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str
4446
else:
4547
combined_headers[key] = ", ".join(values)
4648

47-
return {"headers": combined_headers, "cookies": cookies}
49+
return {"headers": combined_headers, "cookies": list(map(str, cookies))}
4850

4951

5052
class MultiValueHeadersSerializer(BaseHeadersSerializer):
51-
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str]) -> Dict[str, Any]:
53+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
5254
"""
5355
When using REST APIs, headers can be encoded using the `multiValueHeaders` key on the response.
5456
This is also the case when using an ALB integration with the `multiValueHeaders` option enabled.
@@ -69,13 +71,13 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str
6971
if cookies:
7072
payload.setdefault("Set-Cookie", [])
7173
for cookie in cookies:
72-
payload["Set-Cookie"].append(cookie)
74+
payload["Set-Cookie"].append(str(cookie))
7375

7476
return {"multiValueHeaders": payload}
7577

7678

7779
class SingleValueHeadersSerializer(BaseHeadersSerializer):
78-
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str]) -> Dict[str, Any]:
80+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
7981
"""
8082
The ALB integration has `multiValueHeaders` disabled by default.
8183
If we try to set multiple headers with the same key, or more than one cookie, print a warning.
@@ -93,7 +95,7 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[str
9395
)
9496

9597
# We can only send one cookie, send the last one
96-
payload["headers"]["Set-Cookie"] = cookies[-1]
98+
payload["headers"]["Set-Cookie"] = str(cookies[-1])
9799

98100
for key, values in headers.items():
99101
if isinstance(values, str):

tests/e2e/event_handler/handlers/alb_handler.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
from typing import Dict
2+
13
from aws_lambda_powertools.event_handler import ALBResolver, Response, content_types
24

35
app = ALBResolver()
46

57

6-
@app.get("/todos")
8+
@app.post("/todos")
79
def hello():
10+
payload: Dict = app.current_event.json_body
11+
812
return Response(
913
status_code=200,
1014
content_type=content_types.TEXT_PLAIN,
1115
body="Hello world",
12-
cookies=["CookieMonster", "MonsterCookie"],
13-
headers={"Foo": ["bar", "zbr"]},
16+
cookies=payload["cookies"],
17+
headers=payload["headers"],
1418
)
1519

1620

tests/e2e/event_handler/handlers/lambda_function_url_handler.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33
app = LambdaFunctionUrlResolver()
44

55

6-
@app.get("/todos")
6+
@app.post("/todos")
77
def hello():
8+
payload = app.current_event.json_body
9+
10+
body = payload.get("body", "Hello World")
11+
status_code = payload.get("status_code", 200)
12+
headers = payload.get("headers", {})
13+
cookies = payload.get("cookies", [])
14+
815
return Response(
9-
status_code=200,
10-
content_type=content_types.TEXT_PLAIN,
11-
body="Hello world",
12-
cookies=["CookieMonster", "MonsterCookie"],
13-
headers={"Foo": ["bar", "zbr"]},
16+
status_code=status_code,
17+
content_type=headers.get("Content-Type", content_types.TEXT_PLAIN),
18+
body=body,
19+
cookies=cookies,
20+
headers=headers,
1421
)
1522

1623

tests/e2e/event_handler/test_header_serializer.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from uuid import uuid4
2+
13
import pytest
24
from requests import Request
35

6+
from aws_lambda_powertools.event_handler.cookies import Cookie
47
from tests.e2e.utils import data_fetcher
58

69

@@ -122,20 +125,32 @@ def test_api_gateway_http_headers_serializer(apigw_http_endpoint):
122125
def test_lambda_function_url_headers_serializer(lambda_function_url_endpoint):
123126
# GIVEN
124127
url = f"{lambda_function_url_endpoint}todos" # the function url endpoint already has the trailing /
128+
body = "Hello World"
129+
status_code = 200
130+
headers = {"Content-Type": "text/plain", "Vary": ["Accept-Encoding", "User-Agent"]}
131+
cookies = [
132+
Cookie(name="session_id", value=str(uuid4()), secure=True, http_only=True),
133+
Cookie(name="ab_experiment", value="3"),
134+
]
125135

126136
# WHEN
127-
response = data_fetcher.get_http_response(Request(method="GET", url=url))
137+
response = data_fetcher.get_http_response(
138+
Request(
139+
method="POST",
140+
url=url,
141+
json={"body": body, "status_code": status_code, "headers": headers, "cookies": list(map(str, cookies))},
142+
)
143+
)
128144

129145
# THEN
130-
assert response.status_code == 200
131-
assert response.content == b"Hello world"
132-
assert response.headers["content-type"] == "text/plain"
146+
assert response.status_code == status_code
147+
assert response.content.decode("ascii") == body
133148

134-
# Only the last header for key "Foo" should be set
135-
assert "Foo" in response.headers
136-
foo_headers = [x.strip() for x in response.headers["Foo"].split(",")]
137-
assert sorted(foo_headers) == ["bar", "zbr"]
149+
for key, value in headers.items():
150+
assert key in response.headers
151+
value = value if isinstance(value, str) else ", ".join(value)
152+
assert response.headers[key] == value
138153

139-
# Only the last cookie should be set
140-
assert "MonsterCookie" in response.cookies.keys()
141-
assert "CookieMonster" in response.cookies.keys()
154+
for cookie in cookies:
155+
assert cookie.name in response.cookies
156+
assert response.cookies.get(cookie.name) == cookie.value

0 commit comments

Comments
 (0)