Skip to content

Commit 51c4c2d

Browse files
refactor(event-handler): Add ResponseBuilder and more docs (#412)
1 parent 48f86b3 commit 51c4c2d

File tree

4 files changed

+538
-61
lines changed

4 files changed

+538
-61
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+209-54
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,57 @@
1212

1313

1414
class ProxyEventType(Enum):
15+
"""An enumerations of the supported proxy event types.
16+
17+
**NOTE:** api_gateway is an alias of http_api_v1"""
18+
1519
http_api_v1 = "APIGatewayProxyEvent"
1620
http_api_v2 = "APIGatewayProxyEventV2"
1721
alb_event = "ALBEvent"
1822
api_gateway = http_api_v1
1923

2024

2125
class CORSConfig(object):
22-
"""CORS Config"""
26+
"""CORS Config
27+
28+
29+
Examples
30+
--------
31+
32+
Simple cors example using the default permissive cors, not this should only be used during early prototyping
33+
34+
>>> from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver
35+
>>>
36+
>>> app = ApiGatewayResolver()
37+
>>>
38+
>>> @app.get("/my/path", cors=True)
39+
>>> def with_cors():
40+
>>> return {"message": "Foo"}
41+
42+
Using a custom CORSConfig where `with_cors` used the custom provided CORSConfig and `without_cors`
43+
do not include any cors headers.
44+
45+
>>> from aws_lambda_powertools.event_handler.api_gateway import (
46+
>>> ApiGatewayResolver, CORSConfig
47+
>>> )
48+
>>>
49+
>>> cors_config = CORSConfig(
50+
>>> allow_origin="https://wwww.example.com/",
51+
>>> expose_headers=["x-exposed-response-header"],
52+
>>> allow_headers=["x-custom-request-header"],
53+
>>> max_age=100,
54+
>>> allow_credentials=True,
55+
>>> )
56+
>>> app = ApiGatewayResolver(cors=cors_config)
57+
>>>
58+
>>> @app.get("/my/path", cors=True)
59+
>>> def with_cors():
60+
>>> return {"message": "Foo"}
61+
>>>
62+
>>> @app.get("/another-one")
63+
>>> def without_cors():
64+
>>> return {"message": "Foo"}
65+
"""
2366

2467
_REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"]
2568

@@ -55,6 +98,7 @@ def __init__(
5598
self.allow_credentials = allow_credentials
5699

57100
def to_dict(self) -> Dict[str, str]:
101+
"""Builds the configured Access-Control http headers"""
58102
headers = {
59103
"Access-Control-Allow-Origin": self.allow_origin,
60104
"Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)),
@@ -68,7 +112,37 @@ def to_dict(self) -> Dict[str, str]:
68112
return headers
69113

70114

115+
class Response:
116+
"""Response data class that provides greater control over what is returned from the proxy event"""
117+
118+
def __init__(
119+
self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None
120+
):
121+
"""
122+
123+
Parameters
124+
----------
125+
status_code: int
126+
Http status code, example 200
127+
content_type: str
128+
Optionally set the Content-Type header, example "application/json". Note this will be merged into any
129+
provided http headers
130+
body: Union[str, bytes, None]
131+
Optionally set the response body. Note: bytes body will be automatically base64 encoded
132+
headers: dict
133+
Optionally set specific http headers. Setting "Content-Type" hear would override the `content_type` value.
134+
"""
135+
self.status_code = status_code
136+
self.body = body
137+
self.base64_encoded = False
138+
self.headers: Dict = headers or {}
139+
if content_type:
140+
self.headers.setdefault("Content-Type", content_type)
141+
142+
71143
class Route:
144+
"""Internally used Route Configuration"""
145+
72146
def __init__(
73147
self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str]
74148
):
@@ -80,68 +154,125 @@ def __init__(
80154
self.cache_control = cache_control
81155

82156

83-
class Response:
84-
def __init__(
85-
self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None
86-
):
87-
self.status_code = status_code
88-
self.body = body
89-
self.base64_encoded = False
90-
self.headers: Dict = headers or {}
91-
if content_type:
92-
self.headers.setdefault("Content-Type", content_type)
157+
class ResponseBuilder:
158+
"""Internally used Response builder"""
93159

94-
def add_cors(self, cors: CORSConfig):
95-
self.headers.update(cors.to_dict())
160+
def __init__(self, response: Response, route: Route = None):
161+
self.response = response
162+
self.route = route
96163

97-
def add_cache_control(self, cache_control: str):
98-
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
164+
def _add_cors(self, cors: CORSConfig):
165+
"""Update headers to include the configured Access-Control headers"""
166+
self.response.headers.update(cors.to_dict())
99167

100-
def compress(self):
101-
self.headers["Content-Encoding"] = "gzip"
102-
if isinstance(self.body, str):
103-
self.body = bytes(self.body, "utf-8")
104-
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
105-
self.body = gzip.compress(self.body) + gzip.flush()
168+
def _add_cache_control(self, cache_control: str):
169+
"""Set the specified cache control headers for 200 http responses. For non-200 `no-cache` is used."""
170+
self.response.headers["Cache-Control"] = cache_control if self.response.status_code == 200 else "no-cache"
106171

107-
def to_dict(self) -> Dict[str, Any]:
108-
if isinstance(self.body, bytes):
109-
self.base64_encoded = True
110-
self.body = base64.b64encode(self.body).decode()
172+
def _compress(self):
173+
"""Compress the response body, but only if `Accept-Encoding` headers includes gzip."""
174+
self.response.headers["Content-Encoding"] = "gzip"
175+
if isinstance(self.response.body, str):
176+
self.response.body = bytes(self.response.body, "utf-8")
177+
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
178+
self.response.body = gzip.compress(self.response.body) + gzip.flush()
179+
180+
def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
181+
"""Optionally handle any of the route's configure response handling"""
182+
if self.route is None:
183+
return
184+
if self.route.cors:
185+
self._add_cors(cors or CORSConfig())
186+
if self.route.cache_control:
187+
self._add_cache_control(self.route.cache_control)
188+
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
189+
self._compress()
190+
191+
def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any]:
192+
"""Build the full response dict to be returned by the lambda"""
193+
self._route(event, cors)
194+
195+
if isinstance(self.response.body, bytes):
196+
self.response.base64_encoded = True
197+
self.response.body = base64.b64encode(self.response.body).decode()
111198
return {
112-
"statusCode": self.status_code,
113-
"headers": self.headers,
114-
"body": self.body,
115-
"isBase64Encoded": self.base64_encoded,
199+
"statusCode": self.response.status_code,
200+
"headers": self.response.headers,
201+
"body": self.response.body,
202+
"isBase64Encoded": self.response.base64_encoded,
116203
}
117204

118205

119206
class ApiGatewayResolver:
207+
"""API Gateway and ALB proxy resolver
208+
209+
Examples
210+
--------
211+
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
212+
213+
>>> from aws_lambda_powertools import Tracer
214+
>>> from aws_lambda_powertools.event_handler.api_gateway import (
215+
>>> ApiGatewayResolver
216+
>>> )
217+
>>>
218+
>>> tracer = Tracer()
219+
>>> app = ApiGatewayResolver()
220+
>>>
221+
>>> @app.get("/get-call")
222+
>>> def simple_get():
223+
>>> return {"message": "Foo"}
224+
>>>
225+
>>> @app.post("/post-call")
226+
>>> def simple_post():
227+
>>> post_data: dict = app.current_event.json_body
228+
>>> return {"message": post_data["value"]}
229+
>>>
230+
>>> @tracer.capture_lambda_handler
231+
>>> def lambda_handler(event, context):
232+
>>> return app.resolve(event, context)
233+
234+
"""
235+
120236
current_event: BaseProxyEvent
121237
lambda_context: LambdaContext
122238

123239
def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None):
240+
"""
241+
Parameters
242+
----------
243+
proxy_type: ProxyEventType
244+
Proxy request type, defaults to API Gateway V1
245+
cors: CORSConfig
246+
Optionally configure and enabled CORS. Not each route will need to have to cors=True
247+
"""
124248
self._proxy_type = proxy_type
125249
self._routes: List[Route] = []
126250
self._cors = cors
127251
self._cors_methods: Set[str] = {"OPTIONS"}
128252

129253
def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
254+
"""Get route decorator with GET `method`"""
130255
return self.route(rule, "GET", cors, compress, cache_control)
131256

132257
def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
258+
"""Post route decorator with POST `method`"""
133259
return self.route(rule, "POST", cors, compress, cache_control)
134260

135261
def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
262+
"""Put route decorator with PUT `method`"""
136263
return self.route(rule, "PUT", cors, compress, cache_control)
137264

138265
def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
266+
"""Delete route decorator with DELETE `method`"""
139267
return self.route(rule, "DELETE", cors, compress, cache_control)
140268

141269
def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None):
270+
"""Patch route decorator with PATCH `method`"""
142271
return self.route(rule, "PATCH", cors, compress, cache_control)
143272

144273
def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None):
274+
"""Route decorator includes parameter `method`"""
275+
145276
def register_resolver(func: Callable):
146277
self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control))
147278
if cors:
@@ -151,60 +282,87 @@ def register_resolver(func: Callable):
151282
return register_resolver
152283

153284
def resolve(self, event, context) -> Dict[str, Any]:
154-
self.current_event = self._to_data_class(event)
155-
self.lambda_context = context
156-
route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path)
157-
if route is None: # No matching route was found
158-
return response.to_dict()
285+
"""Resolves the response based on the provide event and decorator routes
159286
160-
if route.cors:
161-
response.add_cors(self._cors or CORSConfig())
162-
if route.cache_control:
163-
response.add_cache_control(route.cache_control)
164-
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
165-
response.compress()
287+
Parameters
288+
----------
289+
event: Dict[str, Any]
290+
Event
291+
context: LambdaContext
292+
Lambda context
293+
Returns
294+
-------
295+
dict
296+
Returns the dict response
297+
"""
298+
self.current_event = self._to_proxy_event(event)
299+
self.lambda_context = context
300+
return self._resolve().build(self.current_event, self._cors)
166301

167-
return response.to_dict()
302+
def __call__(self, event, context) -> Any:
303+
return self.resolve(event, context)
168304

169305
@staticmethod
170306
def _compile_regex(rule: str):
307+
"""Precompile regex pattern"""
171308
rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule)
172309
return re.compile("^{}$".format(rule_regex))
173310

174-
def _to_data_class(self, event: Dict) -> BaseProxyEvent:
311+
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
312+
"""Convert the event dict to the corresponding data class"""
175313
if self._proxy_type == ProxyEventType.http_api_v1:
176314
return APIGatewayProxyEvent(event)
177315
if self._proxy_type == ProxyEventType.http_api_v2:
178316
return APIGatewayProxyEventV2(event)
179317
return ALBEvent(event)
180318

181-
def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]:
319+
def _resolve(self) -> ResponseBuilder:
320+
"""Resolves the response or return the not found response"""
321+
method = self.current_event.http_method.upper()
322+
path = self.current_event.path
182323
for route in self._routes:
183324
if method != route.method:
184325
continue
185326
match: Optional[re.Match] = route.rule.match(path)
186327
if match:
187328
return self._call_route(route, match.groupdict())
188329

330+
return self._not_found(method, path)
331+
332+
def _not_found(self, method: str, path: str) -> ResponseBuilder:
333+
"""Called when no matching route was found and includes support for the cors preflight response"""
189334
headers = {}
190335
if self._cors:
191336
headers.update(self._cors.to_dict())
337+
192338
if method == "OPTIONS": # Preflight
193339
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
194-
return None, Response(status_code=204, content_type=None, body=None, headers=headers)
340+
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
195341

196-
return None, Response(
197-
status_code=404,
198-
content_type="application/json",
199-
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
200-
headers=headers,
342+
return ResponseBuilder(
343+
Response(
344+
status_code=404,
345+
content_type="application/json",
346+
headers=headers,
347+
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
348+
)
201349
)
202350

203-
def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]:
204-
return route, self._to_response(route.func(**args))
351+
def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
352+
"""Actually call the matching route with any provided keyword arguments."""
353+
return ResponseBuilder(self._to_response(route.func(**args)), route)
205354

206355
@staticmethod
207356
def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response:
357+
"""Convert the route's result to a Response
358+
359+
3 main result types are supported:
360+
361+
- Tuple[int, str, bytes] and Tuple[int, str, str]: status code, content-type and body (str|bytes)
362+
- Dict[str, Any]: Rest api response with just the Dict to json stringify and content-type is set to
363+
application/json
364+
- Response: returned as is, and allows for more flexibility
365+
"""
208366
if isinstance(result, Response):
209367
return result
210368
elif isinstance(result, dict):
@@ -215,6 +373,3 @@ def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Respons
215373
)
216374
else: # Tuple[int, str, Union[bytes, str]]
217375
return Response(*result)
218-
219-
def __call__(self, event, context) -> Any:
220-
return self.resolve(event, context)

0 commit comments

Comments
 (0)