Skip to content

refactor(event-handler): Add ResponseBuilder and more docs #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 1, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 62 additions & 42 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,28 +91,47 @@ def __init__(
if content_type:
self.headers.setdefault("Content-Type", content_type)

def add_cors(self, cors: CORSConfig):
self.headers.update(cors.to_dict())

def add_cache_control(self, cache_control: str):
self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache"
class ResponseBuilder:
def __init__(self, response: Response, route: Route = None):
self.response = response
self.route = route

def compress(self):
self.headers["Content-Encoding"] = "gzip"
if isinstance(self.body, str):
self.body = bytes(self.body, "utf-8")
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
self.body = gzip.compress(self.body) + gzip.flush()
def _add_cors(self, cors: CORSConfig):
self.response.headers.update(cors.to_dict())

def _add_cache_control(self, cache_control: str):
self.response.headers["Cache-Control"] = cache_control if self.response.status_code == 200 else "no-cache"

def to_dict(self) -> Dict[str, Any]:
if isinstance(self.body, bytes):
self.base64_encoded = True
self.body = base64.b64encode(self.body).decode()
def _compress(self):
self.response.headers["Content-Encoding"] = "gzip"
if isinstance(self.response.body, str):
self.response.body = bytes(self.response.body, "utf-8")
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
self.response.body = gzip.compress(self.response.body) + gzip.flush()

def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
"""Optionally handle any of the route's configure response handling"""
if self.route is None:
return
if self.route.cors:
self._add_cors(cors or CORSConfig())
if self.route.cache_control:
self._add_cache_control(self.route.cache_control)
if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""):
self._compress()

def build(self, event: BaseProxyEvent, cors: CORSConfig = None) -> Dict[str, Any]:
self._route(event, cors)

if isinstance(self.response.body, bytes):
self.response.base64_encoded = True
self.response.body = base64.b64encode(self.response.body).decode()
return {
"statusCode": self.status_code,
"headers": self.headers,
"body": self.body,
"isBase64Encoded": self.base64_encoded,
"statusCode": self.response.status_code,
"headers": self.response.headers,
"body": self.response.body,
"isBase64Encoded": self.response.base64_encoded,
}


Expand Down Expand Up @@ -153,58 +172,62 @@ def register_resolver(func: Callable):
def resolve(self, event, context) -> Dict[str, Any]:
self.current_event = self._to_data_class(event)
self.lambda_context = context
route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path)
if route is None: # No matching route was found
return response.to_dict()

if route.cors:
response.add_cors(self._cors or CORSConfig())
if route.cache_control:
response.add_cache_control(route.cache_control)
if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""):
response.compress()
return self._resolve_response().build(self.current_event, self._cors)

return response.to_dict()
def __call__(self, event, context) -> Any:
return self.resolve(event, context)

@staticmethod
def _compile_regex(rule: str):
"""Precompile regex pattern"""
rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule)
return re.compile("^{}$".format(rule_regex))

def _to_data_class(self, event: Dict) -> BaseProxyEvent:
"""Convert the event dict to the corresponding data class"""
if self._proxy_type == ProxyEventType.http_api_v1:
return APIGatewayProxyEvent(event)
if self._proxy_type == ProxyEventType.http_api_v2:
return APIGatewayProxyEventV2(event)
return ALBEvent(event)

def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]:
def _resolve_response(self) -> ResponseBuilder:
"""Resolve the response or return the not found response"""
method = self.current_event.http_method.upper()
path = self.current_event.path
for route in self._routes:
if method != route.method:
continue
match: Optional[re.Match] = route.rule.match(path)
if match:
return self._call_route(route, match.groupdict())

return self._not_found(method, path)

def _not_found(self, method: str, path: str) -> ResponseBuilder:
"""No matching route was found, includes support for the cors preflight response"""
headers = {}
if self._cors:
headers.update(self._cors.to_dict())
if method == "OPTIONS": # Preflight
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return None, Response(status_code=204, content_type=None, body=None, headers=headers)

return None, Response(
status_code=404,
content_type="application/json",
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
headers=headers,
return ResponseBuilder(Response(status_code=204, content_type=None, body=None, headers=headers))
return ResponseBuilder(
Response(
status_code=404,
content_type="application/json",
body=json.dumps({"message": f"No route found for '{method}.{path}'"}),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: Pull the entire code to check whether we're not making customers vulnerable to data enumeration attacks via HTTP status code

headers=headers,
)
)

def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]:
return route, self._to_response(route.func(**args))
def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
return ResponseBuilder(self._to_response(route.func(**args)), route)

@staticmethod
def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response:
"""Convert the route result to a Response"""
if isinstance(result, Response):
return result
elif isinstance(result, dict):
Expand All @@ -215,6 +238,3 @@ def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Respons
)
else: # Tuple[int, str, Union[bytes, str]]
return Response(*result)

def __call__(self, event, context) -> Any:
return self.resolve(event, context)
22 changes: 15 additions & 7 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from pathlib import Path
from typing import Dict, Tuple

from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response
from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
CORSConfig,
ProxyEventType,
Response,
ResponseBuilder,
)
from aws_lambda_powertools.shared.json_encoder import Encoder
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
from tests.functional.utils import load_event
Expand Down Expand Up @@ -106,14 +112,14 @@ def test_include_rule_matching():
@app.get("/<name>/<my_id>")
def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]:
assert name == "my"
return 200, "plain/html", my_id
return 200, TEXT_HTML, my_id

# WHEN calling the event handler
result = app(LOAD_GW_EVENT, {})

# THEN
assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == "plain/html"
assert result["headers"]["Content-Type"] == TEXT_HTML
assert result["body"] == "path"


Expand Down Expand Up @@ -389,14 +395,16 @@ def another_one():
def test_no_content_response():
# GIVEN a response with no content-type or body
response = Response(status_code=204, content_type=None, body=None, headers=None)
response_builder = ResponseBuilder(response)

# WHEN calling to_dict
result = response.to_dict()
result = response_builder.build(APIGatewayProxyEvent(LOAD_GW_EVENT))

# THEN return an None body and no Content-Type header
assert result["statusCode"] == response.status_code
assert result["body"] is None
assert result["statusCode"] == 204
assert "Content-Type" not in result["headers"]
headers = result["headers"]
assert "Content-Type" not in headers


def test_no_matches_with_cors():
Expand All @@ -413,7 +421,7 @@ def test_no_matches_with_cors():
assert "Access-Control-Allow-Origin" in result["headers"]


def test_preflight():
def test_cors_preflight():
# GIVEN an event for an OPTIONS call that does not match any of the given routes
# AND cors is enabled
app = ApiGatewayResolver(cors=CORSConfig())
Expand Down