Skip to content

Commit e24a985

Browse files
author
Michael Brewer
authored
feat(event-handler): add http ProxyEvent handler (#369)
1 parent 776569a commit e24a985

21 files changed

+742
-61
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+220
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import base64
2+
import json
3+
import re
4+
import zlib
5+
from enum import Enum
6+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
7+
8+
from aws_lambda_powertools.shared.json_encoder import Encoder
9+
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
10+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
11+
from aws_lambda_powertools.utilities.typing import LambdaContext
12+
13+
14+
class ProxyEventType(Enum):
15+
http_api_v1 = "APIGatewayProxyEvent"
16+
http_api_v2 = "APIGatewayProxyEventV2"
17+
alb_event = "ALBEvent"
18+
api_gateway = http_api_v1
19+
20+
21+
class CORSConfig(object):
22+
"""CORS Config"""
23+
24+
_REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"]
25+
26+
def __init__(
27+
self,
28+
allow_origin: str = "*",
29+
allow_headers: List[str] = None,
30+
expose_headers: List[str] = None,
31+
max_age: int = None,
32+
allow_credentials: bool = False,
33+
):
34+
"""
35+
Parameters
36+
----------
37+
allow_origin: str
38+
The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should
39+
only be used during development.
40+
allow_headers: str
41+
The list of additional allowed headers. This list is added to list of
42+
built in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`,
43+
`X-Api-Key`, `X-Amz-Security-Token`.
44+
expose_headers: str
45+
A list of values to return for the Access-Control-Expose-Headers
46+
max_age: int
47+
The value for the `Access-Control-Max-Age`
48+
allow_credentials: bool
49+
A boolean value that sets the value of `Access-Control-Allow-Credentials`
50+
"""
51+
self.allow_origin = allow_origin
52+
self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or []))
53+
self.expose_headers = expose_headers or []
54+
self.max_age = max_age
55+
self.allow_credentials = allow_credentials
56+
57+
def to_dict(self) -> Dict[str, str]:
58+
headers = {
59+
"Access-Control-Allow-Origin": self.allow_origin,
60+
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
61+
}
62+
if self.expose_headers:
63+
headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers)
64+
if self.max_age is not None:
65+
headers["Access-Control-Max-Age"] = str(self.max_age)
66+
if self.allow_credentials is True:
67+
headers["Access-Control-Allow-Credentials"] = "true"
68+
return headers
69+
70+
71+
class Route:
72+
def __init__(
73+
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
74+
):
75+
self.method = method.upper()
76+
self.rule = rule
77+
self.func = func
78+
self.cors = cors
79+
self.compress = compress
80+
self.cache_control = cache_control
81+
82+
83+
class Response:
84+
def __init__(
85+
self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None
86+
):
87+
self.status_code = status_code
88+
self.body = body
89+
self.base64_encoded = False
90+
self.headers: Dict = headers or {}
91+
if content_type:
92+
self.headers.setdefault("Content-Type", content_type)
93+
94+
def add_cors(self, cors: CORSConfig):
95+
self.headers.update(cors.to_dict())
96+
97+
def add_cache_control(self, cache_control: str):
98+
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
99+
100+
def compress(self):
101+
self.headers["Content-Encoding"] = "gzip"
102+
if isinstance(self.body, str):
103+
self.body = bytes(self.body, "utf-8")
104+
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
105+
self.body = gzip.compress(self.body) + gzip.flush()
106+
107+
def to_dict(self) -> Dict[str, Any]:
108+
if isinstance(self.body, bytes):
109+
self.base64_encoded = True
110+
self.body = base64.b64encode(self.body).decode()
111+
return {
112+
"statusCode": self.status_code,
113+
"headers": self.headers,
114+
"body": self.body,
115+
"isBase64Encoded": self.base64_encoded,
116+
}
117+
118+
119+
class ApiGatewayResolver:
120+
current_event: BaseProxyEvent
121+
lambda_context: LambdaContext
122+
123+
def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None):
124+
self._proxy_type = proxy_type
125+
self._routes: List[Route] = []
126+
self._cors = cors
127+
self._cors_methods: Set[str] = {"OPTIONS"}
128+
129+
def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
130+
return self.route(rule, "GET", cors, compress, cache_control)
131+
132+
def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
133+
return self.route(rule, "POST", cors, compress, cache_control)
134+
135+
def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
136+
return self.route(rule, "PUT", cors, compress, cache_control)
137+
138+
def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
139+
return self.route(rule, "DELETE", cors, compress, cache_control)
140+
141+
def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
142+
return self.route(rule, "PATCH", cors, compress, cache_control)
143+
144+
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None):
145+
def register_resolver(func: Callable):
146+
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
147+
if cors:
148+
self._cors_methods.add(method.upper())
149+
return func
150+
151+
return register_resolver
152+
153+
def resolve(self, event, context) -> Dict[str, Any]:
154+
self.current_event = self._to_data_class(event)
155+
self.lambda_context = context
156+
route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path)
157+
if route is None: # No matching route was found
158+
return response.to_dict()
159+
160+
if route.cors:
161+
response.add_cors(self._cors or CORSConfig())
162+
if route.cache_control:
163+
response.add_cache_control(route.cache_control)
164+
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
165+
response.compress()
166+
167+
return response.to_dict()
168+
169+
@staticmethod
170+
def _compile_regex(rule: str):
171+
rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule)
172+
return re.compile("^{}$".format(rule_regex))
173+
174+
def _to_data_class(self, event: Dict) -> BaseProxyEvent:
175+
if self._proxy_type == ProxyEventType.http_api_v1:
176+
return APIGatewayProxyEvent(event)
177+
if self._proxy_type == ProxyEventType.http_api_v2:
178+
return APIGatewayProxyEventV2(event)
179+
return ALBEvent(event)
180+
181+
def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]:
182+
for route in self._routes:
183+
if method != route.method:
184+
continue
185+
match: Optional[re.Match] = route.rule.match(path)
186+
if match:
187+
return self._call_route(route, match.groupdict())
188+
189+
headers = {}
190+
if self._cors:
191+
headers.update(self._cors.to_dict())
192+
if method == "OPTIONS": # Preflight
193+
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
194+
return None, Response(status_code=204, content_type=None, body=None, headers=headers)
195+
196+
return None, Response(
197+
status_code=404,
198+
content_type="application/json",
199+
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
200+
headers=headers,
201+
)
202+
203+
def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]:
204+
return route, self._to_response(route.func(**args))
205+
206+
@staticmethod
207+
def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response:
208+
if isinstance(result, Response):
209+
return result
210+
elif isinstance(result, dict):
211+
return Response(
212+
status_code=200,
213+
content_type="application/json",
214+
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
215+
)
216+
else: # Tuple[int, str, Union[bytes, str]]
217+
return Response(*result)
218+
219+
def __call__(self, event, context) -> Any:
220+
return self.resolve(event, context)

Diff for: aws_lambda_powertools/utilities/data_classes/alb_event.py

-8
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ class ALBEvent(BaseProxyEvent):
2121
def request_context(self) -> ALBEventRequestContext:
2222
return ALBEventRequestContext(self._data)
2323

24-
@property
25-
def http_method(self) -> str:
26-
return self["httpMethod"]
27-
28-
@property
29-
def path(self) -> str:
30-
return self["path"]
31-
3224
@property
3325
def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]:
3426
return self.get("multiValueQueryStringParameters")

Diff for: aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,6 @@ def version(self) -> str:
217217
def resource(self) -> str:
218218
return self["resource"]
219219

220-
@property
221-
def path(self) -> str:
222-
return self["path"]
223-
224-
@property
225-
def http_method(self) -> str:
226-
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
227-
return self["httpMethod"]
228-
229220
@property
230221
def multi_value_headers(self) -> Dict[str, List[str]]:
231222
return self["multiValueHeaders"]
@@ -446,3 +437,12 @@ def path_parameters(self) -> Optional[Dict[str, str]]:
446437
@property
447438
def stage_variables(self) -> Optional[Dict[str, str]]:
448439
return self.get("stageVariables")
440+
441+
@property
442+
def path(self) -> str:
443+
return self.raw_path
444+
445+
@property
446+
def http_method(self) -> str:
447+
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
448+
return self.request_context.http.method

Diff for: aws_lambda_powertools/utilities/data_classes/common.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, Dict, Optional
23

34

@@ -57,8 +58,23 @@ def is_base64_encoded(self) -> Optional[bool]:
5758

5859
@property
5960
def body(self) -> Optional[str]:
61+
"""Submitted body of the request as a string"""
6062
return self.get("body")
6163

64+
@property
65+
def json_body(self) -> Any:
66+
"""Parses the submitted body as json"""
67+
return json.loads(self["body"])
68+
69+
@property
70+
def path(self) -> str:
71+
return self["path"]
72+
73+
@property
74+
def http_method(self) -> str:
75+
"""The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT."""
76+
return self["httpMethod"]
77+
6278
def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]:
6379
"""Get query string value by name
6480

0 commit comments

Comments
 (0)