From 1695911e9095bac17de897c8bafc9414a2988dcb Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 27 Mar 2021 21:12:08 -0700 Subject: [PATCH 01/36] feat(event-handler): Add http ProxyEvent handler --- .../utilities/data_classes/alb_event.py | 8 -- .../data_classes/api_gateway_proxy_event.py | 18 +-- .../utilities/data_classes/common.py | 9 ++ .../utilities/event_handler/__init__.py | 0 .../utilities/event_handler/api_gateway.py | 103 ++++++++++++++++++ tests/functional/event_handler/__init__.py | 0 .../event_handler/test_api_gateway.py | 96 ++++++++++++++++ 7 files changed, 217 insertions(+), 17 deletions(-) create mode 100644 aws_lambda_powertools/utilities/event_handler/__init__.py create mode 100644 aws_lambda_powertools/utilities/event_handler/api_gateway.py create mode 100644 tests/functional/event_handler/__init__.py create mode 100644 tests/functional/event_handler/test_api_gateway.py diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 6c7cb9e60c3..73e064d0f26 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -21,14 +21,6 @@ class ALBEvent(BaseProxyEvent): def request_context(self) -> ALBEventRequestContext: return ALBEventRequestContext(self._data) - @property - def http_method(self) -> str: - return self["httpMethod"] - - @property - def path(self) -> str: - return self["path"] - @property def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 756842ad347..1f747ac7a1f 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -212,15 +212,6 @@ def version(self) -> str: def resource(self) -> str: return self["resource"] - @property - def path(self) -> str: - return self["path"] - - @property - def http_method(self) -> str: - """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" - return self["httpMethod"] - @property def multi_value_headers(self) -> Dict[str, List[str]]: return self["multiValueHeaders"] @@ -441,3 +432,12 @@ def path_parameters(self) -> Optional[Dict[str, str]]: @property def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") + + @property + def path(self) -> str: + return self.raw_path + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self.request_context.http.method diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 6f393cccb60..59224f233c7 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -59,6 +59,15 @@ def is_base64_encoded(self) -> Optional[bool]: def body(self) -> Optional[str]: return self.get("body") + @property + def path(self) -> str: + return self["path"] + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self["httpMethod"] + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name diff --git a/aws_lambda_powertools/utilities/event_handler/__init__.py b/aws_lambda_powertools/utilities/event_handler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py new file mode 100644 index 00000000000..44a9bb3ea2a --- /dev/null +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -0,0 +1,103 @@ +from enum import Enum +from typing import Any, Callable, Dict, List, Tuple, Union + +from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent +from aws_lambda_powertools.utilities.typing import LambdaContext + + +class ProxyEventType(Enum): + http_api_v1 = "APIGatewayProxyEvent" + http_api_v2 = "APIGatewayProxyEventV2" + alb_event = "ALBEvent" + + +class ApiGatewayResolver: + def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): + self._proxy_type: Enum = proxy_type + self._resolvers: List[Dict] = [] + + def get(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("GET", uri, include_event, include_context, **kwargs) + + def post(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("POST", uri, include_event, include_context, **kwargs) + + def put(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("PUT", uri, include_event, include_context, **kwargs) + + def delete(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("DELETE", uri, include_event, include_context, **kwargs) + + def route( + self, + method: str, + uri: str, + include_event: bool = False, + include_context: bool = False, + **kwargs, + ): + def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): + self._register(func, method.upper(), uri, include_event, include_context, kwargs) + return func + + return register_resolver + + def resolve(self, event: Dict, context: LambdaContext) -> Dict: + proxy_event: BaseProxyEvent = self._as_proxy_event(event) + resolver: Callable[[Any], Tuple[int, str, str]] + config: Dict + resolver, config = self._find_resolver(proxy_event.http_method.upper(), proxy_event.path) + kwargs = self._kwargs(proxy_event, context, config) + result = resolver(**kwargs) + return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} + + def _register( + self, + func: Callable[[Any, Any], Tuple[int, str, str]], + http_method: str, + uri_starts_with: str, + include_event: bool, + include_context: bool, + kwargs: Dict, + ): + kwargs["include_event"] = include_event + kwargs["include_context"] = include_context + self._resolvers.append( + { + "http_method": http_method, + "uri_starts_with": uri_starts_with, + "func": func, + "config": kwargs, + } + ) + + def _as_proxy_event(self, event: Dict) -> Union[ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2]: + if self._proxy_type == ProxyEventType.http_api_v1: + return APIGatewayProxyEvent(event) + if self._proxy_type == ProxyEventType.http_api_v2: + return APIGatewayProxyEventV2(event) + return ALBEvent(event) + + def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict]: + for resolver in self._resolvers: + expected_method = resolver["http_method"] + if http_method != expected_method: + continue + path_starts_with = resolver["uri_starts_with"] + if path.startswith(path_starts_with): + return resolver["func"], resolver["config"] + + raise ValueError(f"No resolver found for '{http_method}.{path}'") + + @staticmethod + def _kwargs(event: BaseProxyEvent, context: LambdaContext, config: Dict) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + if config.get("include_event", False): + kwargs["event"] = event + if config.get("include_context", False): + kwargs["context"] = context + return kwargs + + def __call__(self, event, context) -> Any: + return self.resolve(event, context) diff --git a/tests/functional/event_handler/__init__.py b/tests/functional/event_handler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py new file mode 100644 index 00000000000..517c167ce5c --- /dev/null +++ b/tests/functional/event_handler/test_api_gateway.py @@ -0,0 +1,96 @@ +import json +import os + +import pytest + +from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 +from aws_lambda_powertools.utilities.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType +from aws_lambda_powertools.utilities.typing import LambdaContext + + +def load_event(file_name: str) -> dict: + full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name + with open(full_file_name) as fp: + return json.load(fp) + + +def test_alb_event(): + app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) + + @app.get("/lambda", include_event=True, include_context=True) + def foo(event: ALBEvent, context: LambdaContext): + assert isinstance(event, ALBEvent) + assert context == {} + return 200, "text/html", "foo" + + result = app(load_event("albEvent.json"), {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "text/html" + assert result["body"] == "foo" + + +def test_api_gateway_v1(): + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + + @app.get("/my/path", include_event=True, include_context=True) + def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext): + assert isinstance(event, APIGatewayProxyEvent) + assert context == {} + return 200, "application/json", json.dumps({"foo": "value"}) + + result = app(load_event("apiGatewayProxyEvent.json"), {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "application/json" + + +def test_include_event_false(): + app = ApiGatewayResolver() + + @app.get("/my/path") + def get_lambda(): + return 200, "plain/html", "foo" + + result = app(load_event("apiGatewayProxyEvent.json"), {}) + + assert result["statusCode"] == 200 + assert result["body"] == "foo" + + +def test_api_gateway_v2(): + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2) + + @app.post("/my/path", include_event=True) + def my_path(event: APIGatewayProxyEventV2): + assert isinstance(event, APIGatewayProxyEventV2) + post_data = json.loads(event.body or "{}") + return 200, "plain/text", post_data["username"] + + result = app(load_event("apiGatewayProxyV2Event.json"), {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/text" + assert result["body"] == "tom" + + +def test_no_matching(): + app = ApiGatewayResolver() + + @app.get("/not_matching_get") + def no_get_matching(): + raise RuntimeError() + + @app.put("/no_matching") + def no_put_matching(): + raise RuntimeError() + + @app.delete("/no_matching") + def no_delete_matching(): + raise RuntimeError() + + def handler(event, context): + app.resolve(event, context) + + with pytest.raises(ValueError): + handler(load_event("apiGatewayProxyEvent.json"), None) From d0f21f17512005c7ae82c78532d48855f6e2a8aa Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 27 Mar 2021 23:52:33 -0700 Subject: [PATCH 02/36] feat(event-handler): Lightweight rule matching --- .../utilities/event_handler/api_gateway.py | 52 +++++++++++-------- .../event_handler/test_api_gateway.py | 16 +++++- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 44a9bb3ea2a..751f2abc5d3 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -1,5 +1,7 @@ +import re from enum import Enum -from typing import Any, Callable, Dict, List, Tuple, Union +from re import Match, Pattern +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -17,28 +19,28 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type: Enum = proxy_type self._resolvers: List[Dict] = [] - def get(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("GET", uri, include_event, include_context, **kwargs) + def get(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("GET", rule, include_event, include_context, **kwargs) - def post(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("POST", uri, include_event, include_context, **kwargs) + def post(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("POST", rule, include_event, include_context, **kwargs) - def put(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("PUT", uri, include_event, include_context, **kwargs) + def put(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("PUT", rule, include_event, include_context, **kwargs) - def delete(self, uri: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("DELETE", uri, include_event, include_context, **kwargs) + def delete(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): + return self.route("DELETE", rule, include_event, include_context, **kwargs) def route( self, method: str, - uri: str, + rule: str, include_event: bool = False, include_context: bool = False, **kwargs, ): def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): - self._register(func, method.upper(), uri, include_event, include_context, kwargs) + self._register(func, method.upper(), rule, include_event, include_context, kwargs) return func return register_resolver @@ -47,8 +49,8 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: proxy_event: BaseProxyEvent = self._as_proxy_event(event) resolver: Callable[[Any], Tuple[int, str, str]] config: Dict - resolver, config = self._find_resolver(proxy_event.http_method.upper(), proxy_event.path) - kwargs = self._kwargs(proxy_event, context, config) + resolver, config, args = self._find_resolver(proxy_event.http_method.upper(), proxy_event.path) + kwargs = self._kwargs(proxy_event, context, config, args) result = resolver(**kwargs) return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} @@ -56,22 +58,29 @@ def _register( self, func: Callable[[Any, Any], Tuple[int, str, str]], http_method: str, - uri_starts_with: str, + rule: str, include_event: bool, include_context: bool, kwargs: Dict, ): kwargs["include_event"] = include_event kwargs["include_context"] = include_context + rule_pattern = self._build_rule_pattern(rule) + self._resolvers.append( { "http_method": http_method, - "uri_starts_with": uri_starts_with, + "rule_pattern": rule_pattern, "func": func, "config": kwargs, } ) + @staticmethod + def _build_rule_pattern(rule: str) -> Pattern: + rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) + return re.compile("^{}$".format(rule_regex)) + def _as_proxy_event(self, event: Dict) -> Union[ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2]: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) @@ -79,20 +88,21 @@ def _as_proxy_event(self, event: Dict) -> Union[ALBEvent, APIGatewayProxyEvent, return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict]: + def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict, Dict]: for resolver in self._resolvers: expected_method = resolver["http_method"] if http_method != expected_method: continue - path_starts_with = resolver["uri_starts_with"] - if path.startswith(path_starts_with): - return resolver["func"], resolver["config"] + + match: Optional[Match] = resolver["rule_pattern"].match(path) + if match: + return resolver["func"], resolver["config"], match.groupdict() raise ValueError(f"No resolver found for '{http_method}.{path}'") @staticmethod - def _kwargs(event: BaseProxyEvent, context: LambdaContext, config: Dict) -> Dict[str, Any]: - kwargs: Dict[str, Any] = {} + def _kwargs(event: BaseProxyEvent, context: LambdaContext, config: Dict, args: Dict) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {**args} if config.get("include_event", False): kwargs["event"] = event if config.get("include_context", False): diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 517c167ce5c..e68e8b412a1 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -45,6 +45,20 @@ def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext): assert result["headers"]["Content-Type"] == "application/json" +def test_include_rule_matching(): + app = ApiGatewayResolver() + + @app.get("//") + def get_lambda(my_id: str, name: str): + assert name == "my" + return 200, "plain/html", my_id + + result = app(load_event("apiGatewayProxyEvent.json"), {}) + + assert result["statusCode"] == 200 + assert result["body"] == "path" + + def test_include_event_false(): app = ApiGatewayResolver() @@ -74,7 +88,7 @@ def my_path(event: APIGatewayProxyEventV2): assert result["body"] == "tom" -def test_no_matching(): +def test_no_matches(): app = ApiGatewayResolver() @app.get("/not_matching_get") From 670f87263e3f64a2f19a5c94d1eeb61a57b28d43 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 28 Mar 2021 00:01:17 -0700 Subject: [PATCH 03/36] fix(event-handler): Python 3.6 support --- .../utilities/event_handler/api_gateway.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 751f2abc5d3..3f2c3dfbcba 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -1,7 +1,6 @@ import re from enum import Enum -from re import Match, Pattern -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -77,7 +76,7 @@ def _register( ) @staticmethod - def _build_rule_pattern(rule: str) -> Pattern: + def _build_rule_pattern(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) @@ -94,7 +93,7 @@ def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict, D if http_method != expected_method: continue - match: Optional[Match] = resolver["rule_pattern"].match(path) + match = resolver["rule_pattern"].match(path) if match: return resolver["func"], resolver["config"], match.groupdict() From 633bff11ecc5cfd86d4a1a4b11c8b6a0ebc6c514 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 28 Mar 2021 12:18:51 -0700 Subject: [PATCH 04/36] refactor(event-handler): Add lambda_context and current_request to app --- .../utilities/data_classes/common.py | 15 +++- .../utilities/event_handler/api_gateway.py | 68 ++++++------------- .../event_handler/test_api_gateway.py | 25 ++++--- 3 files changed, 46 insertions(+), 62 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 59224f233c7..bad13fd9093 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, Optional @@ -26,7 +27,10 @@ def raw_event(self) -> Dict[str, Any]: def get_header_value( - headers: Dict[str, str], name: str, default_value: Optional[str], case_sensitive: Optional[bool] + headers: Dict[str, str], + name: str, + default_value: Optional[str], + case_sensitive: Optional[bool], ) -> Optional[str]: """Get header value by name""" if case_sensitive: @@ -59,6 +63,10 @@ def is_base64_encoded(self) -> Optional[bool]: def body(self) -> Optional[str]: return self.get("body") + @property + def json_body(self) -> Dict: + return json.loads(self["body"]) + @property def path(self) -> str: return self["path"] @@ -86,7 +94,10 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) return default_value if params is None else params.get(name, default_value) def get_header_value( - self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False + self, + name: str, + default_value: Optional[str] = None, + case_sensitive: Optional[bool] = False, ) -> Optional[str]: """Get header value by name diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 3f2c3dfbcba..6ad0911ae65 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -15,42 +15,38 @@ class ProxyEventType(Enum): class ApiGatewayResolver: def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): + self.current_request: BaseProxyEvent + self.lambda_context: LambdaContext self._proxy_type: Enum = proxy_type self._resolvers: List[Dict] = [] - def get(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("GET", rule, include_event, include_context, **kwargs) + def get(self, rule: str): + return self.route("GET", rule) - def post(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("POST", rule, include_event, include_context, **kwargs) + def post(self, rule: str): + return self.route("POST", rule) - def put(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("PUT", rule, include_event, include_context, **kwargs) + def put(self, rule: str): + return self.route("PUT", rule) - def delete(self, rule: str, include_event: bool = False, include_context: bool = False, **kwargs): - return self.route("DELETE", rule, include_event, include_context, **kwargs) + def delete(self, rule: str): + return self.route("DELETE", rule) - def route( - self, - method: str, - rule: str, - include_event: bool = False, - include_context: bool = False, - **kwargs, - ): + def route(self, method: str, rule: str): def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): - self._register(func, method.upper(), rule, include_event, include_context, kwargs) + self._register(func, method.upper(), rule) return func return register_resolver def resolve(self, event: Dict, context: LambdaContext) -> Dict: - proxy_event: BaseProxyEvent = self._as_proxy_event(event) + # NOTE: We are doing a late initialization of current_request and lambda_context + self.current_request = self._as_proxy_event(event) + self.lambda_context = context + resolver: Callable[[Any], Tuple[int, str, str]] - config: Dict - resolver, config, args = self._find_resolver(proxy_event.http_method.upper(), proxy_event.path) - kwargs = self._kwargs(proxy_event, context, config, args) - result = resolver(**kwargs) + resolver, args = self._find_resolver(self.current_request.http_method.upper(), self.current_request.path) + result = resolver(**args) return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} def _register( @@ -58,21 +54,9 @@ def _register( func: Callable[[Any, Any], Tuple[int, str, str]], http_method: str, rule: str, - include_event: bool, - include_context: bool, - kwargs: Dict, ): - kwargs["include_event"] = include_event - kwargs["include_context"] = include_context - rule_pattern = self._build_rule_pattern(rule) - self._resolvers.append( - { - "http_method": http_method, - "rule_pattern": rule_pattern, - "func": func, - "config": kwargs, - } + {"http_method": http_method, "rule_pattern": self._build_rule_pattern(rule), "func": func} ) @staticmethod @@ -87,26 +71,16 @@ def _as_proxy_event(self, event: Dict) -> Union[ALBEvent, APIGatewayProxyEvent, return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict, Dict]: + def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict]: for resolver in self._resolvers: expected_method = resolver["http_method"] if http_method != expected_method: continue - match = resolver["rule_pattern"].match(path) if match: - return resolver["func"], resolver["config"], match.groupdict() + return resolver["func"], match.groupdict() raise ValueError(f"No resolver found for '{http_method}.{path}'") - @staticmethod - def _kwargs(event: BaseProxyEvent, context: LambdaContext, config: Dict, args: Dict) -> Dict[str, Any]: - kwargs: Dict[str, Any] = {**args} - if config.get("include_event", False): - kwargs["event"] = event - if config.get("include_context", False): - kwargs["context"] = context - return kwargs - def __call__(self, event, context) -> Any: return self.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index e68e8b412a1..48fd05f2762 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -5,7 +5,6 @@ from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType -from aws_lambda_powertools.utilities.typing import LambdaContext def load_event(file_name: str) -> dict: @@ -17,10 +16,10 @@ def load_event(file_name: str) -> dict: def test_alb_event(): app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) - @app.get("/lambda", include_event=True, include_context=True) - def foo(event: ALBEvent, context: LambdaContext): - assert isinstance(event, ALBEvent) - assert context == {} + @app.get("/lambda") + def foo(): + assert isinstance(app.current_request, ALBEvent) + assert app.lambda_context == {} return 200, "text/html", "foo" result = app(load_event("albEvent.json"), {}) @@ -33,10 +32,10 @@ def foo(event: ALBEvent, context: LambdaContext): def test_api_gateway_v1(): app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) - @app.get("/my/path", include_event=True, include_context=True) - def get_lambda(event: APIGatewayProxyEvent, context: LambdaContext): - assert isinstance(event, APIGatewayProxyEvent) - assert context == {} + @app.get("/my/path") + def get_lambda(): + assert isinstance(app.current_request, APIGatewayProxyEvent) + assert app.lambda_context == {} return 200, "application/json", json.dumps({"foo": "value"}) result = app(load_event("apiGatewayProxyEvent.json"), {}) @@ -75,10 +74,10 @@ def get_lambda(): def test_api_gateway_v2(): app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2) - @app.post("/my/path", include_event=True) - def my_path(event: APIGatewayProxyEventV2): - assert isinstance(event, APIGatewayProxyEventV2) - post_data = json.loads(event.body or "{}") + @app.post("/my/path") + def my_path(): + assert isinstance(app.current_request, APIGatewayProxyEventV2) + post_data = app.current_request.json_body return 200, "plain/text", post_data["username"] result = app(load_event("apiGatewayProxyV2Event.json"), {}) From ad88b0bf08b6ca588070db7d7e5ed206a72dcd86 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 28 Mar 2021 13:48:32 -0700 Subject: [PATCH 05/36] refactor(event-handler): Resolv Pycharm warnings --- .../utilities/event_handler/api_gateway.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 6ad0911ae65..fd89869205e 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -1,6 +1,6 @@ import re from enum import Enum -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -14,10 +14,11 @@ class ProxyEventType(Enum): class ApiGatewayResolver: + current_request: BaseProxyEvent + lambda_context: LambdaContext + def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): - self.current_request: BaseProxyEvent - self.lambda_context: LambdaContext - self._proxy_type: Enum = proxy_type + self._proxy_type = proxy_type self._resolvers: List[Dict] = [] def get(self, rule: str): @@ -40,21 +41,14 @@ def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): return register_resolver def resolve(self, event: Dict, context: LambdaContext) -> Dict: - # NOTE: We are doing a late initialization of current_request and lambda_context self.current_request = self._as_proxy_event(event) self.lambda_context = context - resolver: Callable[[Any], Tuple[int, str, str]] resolver, args = self._find_resolver(self.current_request.http_method.upper(), self.current_request.path) result = resolver(**args) return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} - def _register( - self, - func: Callable[[Any, Any], Tuple[int, str, str]], - http_method: str, - rule: str, - ): + def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], http_method: str, rule: str): self._resolvers.append( {"http_method": http_method, "rule_pattern": self._build_rule_pattern(rule), "func": func} ) @@ -64,7 +58,7 @@ def _build_rule_pattern(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) - def _as_proxy_event(self, event: Dict) -> Union[ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2]: + def _as_proxy_event(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) if self._proxy_type == ProxyEventType.http_api_v2: From 8b72bd262e05f263450d9af414291f751e57fe2a Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 28 Mar 2021 18:19:11 -0700 Subject: [PATCH 06/36] chore(event-handler): Refactoring --- .../utilities/event_handler/api_gateway.py | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index fd89869205e..5b530832de3 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -16,42 +16,32 @@ class ProxyEventType(Enum): class ApiGatewayResolver: current_request: BaseProxyEvent lambda_context: LambdaContext + _routes: List[Dict] = [] def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type - self._resolvers: List[Dict] = [] def get(self, rule: str): - return self.route("GET", rule) + return self.route(rule, "GET") def post(self, rule: str): - return self.route("POST", rule) + return self.route(rule, "POST") def put(self, rule: str): - return self.route("PUT", rule) + return self.route(rule, "PUT") def delete(self, rule: str): - return self.route("DELETE", rule) + return self.route(rule, "DELETE") - def route(self, method: str, rule: str): + def route(self, rule: str, method: str): def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): - self._register(func, method.upper(), rule) + self._register(func, method, rule) return func return register_resolver - def resolve(self, event: Dict, context: LambdaContext) -> Dict: - self.current_request = self._as_proxy_event(event) - self.lambda_context = context - resolver: Callable[[Any], Tuple[int, str, str]] - resolver, args = self._find_resolver(self.current_request.http_method.upper(), self.current_request.path) - result = resolver(**args) - return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} - - def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], http_method: str, rule: str): - self._resolvers.append( - {"http_method": http_method, "rule_pattern": self._build_rule_pattern(rule), "func": func} - ) + def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], method: str, rule: str): + self._routes.append({"method": method.upper(), "rule_pattern": self._build_rule_pattern(rule), "func": func}) @staticmethod def _build_rule_pattern(rule: str): @@ -65,16 +55,27 @@ def _as_proxy_event(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_resolver(self, http_method: str, path: str) -> Tuple[Callable, Dict]: - for resolver in self._resolvers: - expected_method = resolver["http_method"] - if http_method != expected_method: + def _find_route(self, method: str, path: str) -> Tuple[Callable, Dict]: + method = method.upper() + for resolver in self._routes: + if method != resolver["method"]: continue match = resolver["rule_pattern"].match(path) if match: return resolver["func"], match.groupdict() - raise ValueError(f"No resolver found for '{http_method}.{path}'") + raise ValueError(f"No route found for '{method}.{path}'") + + def resolve(self, event: Dict, context: LambdaContext) -> Dict: + self.current_request = self._as_proxy_event(event) + self.lambda_context = context + + resolver: Callable[[Any], Tuple[int, str, str]] + resolver, args = self._find_route(self.current_request.http_method, self.current_request.path) + + result = resolver(**args) + + return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} def __call__(self, event, context) -> Any: return self.resolve(event, context) From d66f3802589abece06ba4c484e2fb8e170b99221 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 28 Mar 2021 18:37:37 -0700 Subject: [PATCH 07/36] feat(event-handler): Ensure we reset routes in __init__ --- .../utilities/event_handler/api_gateway.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 5b530832de3..47ebebe1fab 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -16,10 +16,10 @@ class ProxyEventType(Enum): class ApiGatewayResolver: current_request: BaseProxyEvent lambda_context: LambdaContext - _routes: List[Dict] = [] def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type + self._routes: List[Dict] = [] def get(self, rule: str): return self.route(rule, "GET") @@ -35,13 +35,13 @@ def delete(self, rule: str): def route(self, rule: str, method: str): def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): - self._register(func, method, rule) + self._register(func, rule, method) return func return register_resolver - def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], method: str, rule: str): - self._routes.append({"method": method.upper(), "rule_pattern": self._build_rule_pattern(rule), "func": func}) + def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], rule: str, method: str): + self._routes.append({"rule_pattern": self._build_rule_pattern(rule), "method": method.upper(), "func": func}) @staticmethod def _build_rule_pattern(rule: str): From 001e8a99a14f2788d5a5358a161e3ad110647261 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 29 Mar 2021 16:09:13 -0700 Subject: [PATCH 08/36] refactor(event-handler): Rename to recent_event --- .../utilities/event_handler/api_gateway.py | 31 ++++++++++--------- .../event_handler/test_api_gateway.py | 8 ++--- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/utilities/event_handler/api_gateway.py index 47ebebe1fab..ab966172f92 100644 --- a/aws_lambda_powertools/utilities/event_handler/api_gateway.py +++ b/aws_lambda_powertools/utilities/event_handler/api_gateway.py @@ -1,6 +1,6 @@ import re from enum import Enum -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -14,7 +14,7 @@ class ProxyEventType(Enum): class ApiGatewayResolver: - current_request: BaseProxyEvent + current_event: BaseProxyEvent lambda_context: LambdaContext def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): @@ -40,6 +40,17 @@ def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): return register_resolver + def resolve(self, event: Dict, context: LambdaContext) -> Dict: + self.current_event = self._as_current_event(event) + self.lambda_context = context + + resolver: Callable[[Any], Tuple[int, str, str]] + resolver, args = self._find_route(self.current_event.http_method, self.current_event.path) + + result = resolver(**args) + + return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} + def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], rule: str, method: str): self._routes.append({"rule_pattern": self._build_rule_pattern(rule), "method": method.upper(), "func": func}) @@ -48,7 +59,7 @@ def _build_rule_pattern(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) - def _as_proxy_event(self, event: Dict) -> BaseProxyEvent: + def _as_current_event(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) if self._proxy_type == ProxyEventType.http_api_v2: @@ -57,25 +68,15 @@ def _as_proxy_event(self, event: Dict) -> BaseProxyEvent: def _find_route(self, method: str, path: str) -> Tuple[Callable, Dict]: method = method.upper() + for resolver in self._routes: if method != resolver["method"]: continue - match = resolver["rule_pattern"].match(path) + match: Optional[re.Match] = resolver["rule_pattern"].match(path) if match: return resolver["func"], match.groupdict() raise ValueError(f"No route found for '{method}.{path}'") - def resolve(self, event: Dict, context: LambdaContext) -> Dict: - self.current_request = self._as_proxy_event(event) - self.lambda_context = context - - resolver: Callable[[Any], Tuple[int, str, str]] - resolver, args = self._find_route(self.current_request.http_method, self.current_request.path) - - result = resolver(**args) - - return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} - def __call__(self, event, context) -> Any: return self.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 48fd05f2762..927537740e9 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -18,7 +18,7 @@ def test_alb_event(): @app.get("/lambda") def foo(): - assert isinstance(app.current_request, ALBEvent) + assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} return 200, "text/html", "foo" @@ -34,7 +34,7 @@ def test_api_gateway_v1(): @app.get("/my/path") def get_lambda(): - assert isinstance(app.current_request, APIGatewayProxyEvent) + assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} return 200, "application/json", json.dumps({"foo": "value"}) @@ -76,8 +76,8 @@ def test_api_gateway_v2(): @app.post("/my/path") def my_path(): - assert isinstance(app.current_request, APIGatewayProxyEventV2) - post_data = app.current_request.json_body + assert isinstance(app.current_event, APIGatewayProxyEventV2) + post_data = app.current_event.json_body return 200, "plain/text", post_data["username"] result = app(load_event("apiGatewayProxyV2Event.json"), {}) From 8c35ce7c197c057a091fa9f9691969d2d83d7364 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 31 Mar 2021 00:27:03 -0700 Subject: [PATCH 09/36] chore(event-handler): Refactor name --- .../{utilities => }/event_handler/api_gateway.py | 0 aws_lambda_powertools/utilities/event_handler/__init__.py | 0 tests/functional/event_handler/test_api_gateway.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename aws_lambda_powertools/{utilities => }/event_handler/api_gateway.py (100%) delete mode 100644 aws_lambda_powertools/utilities/event_handler/__init__.py diff --git a/aws_lambda_powertools/utilities/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py similarity index 100% rename from aws_lambda_powertools/utilities/event_handler/api_gateway.py rename to aws_lambda_powertools/event_handler/api_gateway.py diff --git a/aws_lambda_powertools/utilities/event_handler/__init__.py b/aws_lambda_powertools/utilities/event_handler/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 927537740e9..5ee9d3b8984 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -3,8 +3,8 @@ import pytest +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 -from aws_lambda_powertools.utilities.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType def load_event(file_name: str) -> dict: From 4a30ddc3283a8e1c2373479c84014d3eebe3ebd1 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 31 Mar 2021 00:30:14 -0700 Subject: [PATCH 10/36] chore: Refactor --- aws_lambda_powertools/utilities/data_classes/common.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index bad13fd9093..93f416f5283 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -27,10 +27,7 @@ def raw_event(self) -> Dict[str, Any]: def get_header_value( - headers: Dict[str, str], - name: str, - default_value: Optional[str], - case_sensitive: Optional[bool], + headers: Dict[str, str], name: str, default_value: Optional[str], case_sensitive: Optional[bool] ) -> Optional[str]: """Get header value by name""" if case_sensitive: @@ -94,10 +91,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) return default_value if params is None else params.get(name, default_value) def get_header_value( - self, - name: str, - default_value: Optional[str] = None, - case_sensitive: Optional[bool] = False, + self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False ) -> Optional[str]: """Get header value by name From 60303402aec9826371fb16e2291020e15a9dc902 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 3 Apr 2021 02:10:31 -0700 Subject: [PATCH 11/36] feat(event-handler): Add mapping for api_gateway --- aws_lambda_powertools/event_handler/api_gateway.py | 9 +++++---- tests/functional/event_handler/test_api_gateway.py | 7 +++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ab966172f92..9ef79bf297a 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -11,6 +11,7 @@ class ProxyEventType(Enum): http_api_v1 = "APIGatewayProxyEvent" http_api_v2 = "APIGatewayProxyEventV2" alb_event = "ALBEvent" + api_gateway = http_api_v1 class ApiGatewayResolver: @@ -41,7 +42,7 @@ def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): return register_resolver def resolve(self, event: Dict, context: LambdaContext) -> Dict: - self.current_event = self._as_current_event(event) + self.current_event = self._as_data_class(event) self.lambda_context = context resolver: Callable[[Any], Tuple[int, str, str]] @@ -52,14 +53,14 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], rule: str, method: str): - self._routes.append({"rule_pattern": self._build_rule_pattern(rule), "method": method.upper(), "func": func}) + self._routes.append({"method": method.upper(), "rule": self._build_rule_pattern(rule), "func": func}) @staticmethod def _build_rule_pattern(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) - def _as_current_event(self, event: Dict) -> BaseProxyEvent: + def _as_data_class(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) if self._proxy_type == ProxyEventType.http_api_v2: @@ -72,7 +73,7 @@ def _find_route(self, method: str, path: str) -> Tuple[Callable, Dict]: for resolver in self._routes: if method != resolver["method"]: continue - match: Optional[re.Match] = resolver["rule_pattern"].match(path) + match: Optional[re.Match] = resolver["rule"].match(path) if match: return resolver["func"], match.groupdict() diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 5ee9d3b8984..498566667c8 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -55,19 +55,22 @@ def get_lambda(my_id: str, name: str): result = app(load_event("apiGatewayProxyEvent.json"), {}) assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/html" assert result["body"] == "path" -def test_include_event_false(): - app = ApiGatewayResolver() +def test_api_gateway(): + app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway) @app.get("/my/path") def get_lambda(): + assert isinstance(app.current_event, APIGatewayProxyEvent) return 200, "plain/html", "foo" result = app(load_event("apiGatewayProxyEvent.json"), {}) assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/html" assert result["body"] == "foo" From 1d6ea4d8d333021bd1de524873be22de07b30757 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 9 Apr 2021 19:20:28 -0700 Subject: [PATCH 12/36] feat(event-handler): Add cors support to apigw handler --- .../event_handler/api_gateway.py | 58 +++++++++++-------- .../event_handler/test_api_gateway.py | 20 +++++++ 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 9ef79bf297a..f9a4dede4c5 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -14,29 +14,37 @@ class ProxyEventType(Enum): api_gateway = http_api_v1 +class RouteEntry: + def __init__(self, method: str, rule: Any, func: Callable, cors: bool): + self.method = method.upper() + self.rule = rule + self.func = func + self.cors = cors + + class ApiGatewayResolver: current_event: BaseProxyEvent lambda_context: LambdaContext def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type - self._routes: List[Dict] = [] + self._routes: List[RouteEntry] = [] - def get(self, rule: str): - return self.route(rule, "GET") + def get(self, rule: str, cors: bool = False): + return self.route(rule, "GET", cors) - def post(self, rule: str): - return self.route(rule, "POST") + def post(self, rule: str, cors: bool = False): + return self.route(rule, "POST", cors) - def put(self, rule: str): - return self.route(rule, "PUT") + def put(self, rule: str, cors: bool = False): + return self.route(rule, "PUT", cors) - def delete(self, rule: str): - return self.route(rule, "DELETE") + def delete(self, rule: str, cors: bool = False): + return self.route(rule, "DELETE", cors) - def route(self, rule: str, method: str): - def register_resolver(func: Callable[[Any, Any], Tuple[int, str, str]]): - self._register(func, rule, method) + def route(self, rule: str, method: str, cors: bool = False): + def register_resolver(func: Callable): + self._register(func, rule, method, cors) return func return register_resolver @@ -45,15 +53,20 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: self.current_event = self._as_data_class(event) self.lambda_context = context - resolver: Callable[[Any], Tuple[int, str, str]] - resolver, args = self._find_route(self.current_event.http_method, self.current_event.path) + route, args = self._find_route(self.current_event.http_method, self.current_event.path) - result = resolver(**args) + result = route.func(**args) - return {"statusCode": result[0], "headers": {"Content-Type": result[1]}, "body": result[2]} + headers = {"Content-Type": result[1]} + if route.cors: + headers["Access-Control-Allow-Origin"] = "*" + headers["Access-Control-Allow-Methods"] = route.method + headers["Access-Control-Allow-Credentials"] = "true" - def _register(self, func: Callable[[Any, Any], Tuple[int, str, str]], rule: str, method: str): - self._routes.append({"method": method.upper(), "rule": self._build_rule_pattern(rule), "func": func}) + return {"statusCode": result[0], "headers": headers, "body": result[2]} + + def _register(self, func: Callable, rule: str, method: str, cors: bool): + self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors)) @staticmethod def _build_rule_pattern(rule: str): @@ -67,15 +80,14 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_route(self, method: str, path: str) -> Tuple[Callable, Dict]: + def _find_route(self, method: str, path: str) -> Tuple[RouteEntry, Dict]: method = method.upper() - for resolver in self._routes: - if method != resolver["method"]: + if method != resolver.method: continue - match: Optional[re.Match] = resolver["rule"].match(path) + match: Optional[re.Match] = resolver.rule.match(path) if match: - return resolver["func"], match.groupdict() + return resolver, match.groupdict() raise ValueError(f"No route found for '{method}.{path}'") diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 498566667c8..5a8b989a325 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -110,3 +110,23 @@ def handler(event, context): with pytest.raises(ValueError): handler(load_event("apiGatewayProxyEvent.json"), None) + + +def test_cors(): + app = ApiGatewayResolver() + + @app.get("/my/path", cors=True) + def with_cors(): + return 200, "text/html", "test" + + def handler(event, context): + return app.resolve(event, context) + + result = handler(load_event("apiGatewayProxyEvent.json"), None) + + assert "headers" in result + headers = result["headers"] + assert headers["Content-Type"] == "text/html" + assert headers["Access-Control-Allow-Origin"] == "*" + assert headers["Access-Control-Allow-Methods"] == "GET" + assert headers["Access-Control-Allow-Credentials"] == "true" From 306ee73b91359bd2798753ab927e3871c24fe16a Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 9 Apr 2021 23:05:21 -0700 Subject: [PATCH 13/36] feat(event-handler): apigw compress and base64encode --- .../event_handler/api_gateway.py | 58 ++++++++++++------- .../event_handler/test_api_gateway.py | 24 ++++++++ 2 files changed, 61 insertions(+), 21 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f9a4dede4c5..9783ad3fd6c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,6 +1,8 @@ +import base64 import re +import zlib from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -15,11 +17,12 @@ class ProxyEventType(Enum): class RouteEntry: - def __init__(self, method: str, rule: Any, func: Callable, cors: bool): + def __init__(self, method: str, rule: Any, func: Callable, cors: bool, compress: bool): self.method = method.upper() self.rule = rule self.func = func self.cors = cors + self.compress = compress class ApiGatewayResolver: @@ -30,21 +33,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type self._routes: List[RouteEntry] = [] - def get(self, rule: str, cors: bool = False): - return self.route(rule, "GET", cors) + def get(self, rule: str, cors: bool = False, compress: bool = False): + return self.route(rule, "GET", cors, compress) - def post(self, rule: str, cors: bool = False): - return self.route(rule, "POST", cors) + def post(self, rule: str, cors: bool = False, compress: bool = False): + return self.route(rule, "POST", cors, compress) - def put(self, rule: str, cors: bool = False): - return self.route(rule, "PUT", cors) + def put(self, rule: str, cors: bool = False, compress: bool = False): + return self.route(rule, "PUT", cors, compress) - def delete(self, rule: str, cors: bool = False): - return self.route(rule, "DELETE", cors) + def delete(self, rule: str, cors: bool = False, compress: bool = False): + return self.route(rule, "DELETE", cors, compress) - def route(self, rule: str, method: str, cors: bool = False): + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False): def register_resolver(func: Callable): - self._register(func, rule, method, cors) + self._register(func, rule, method, cors, compress) return func return register_resolver @@ -54,19 +57,32 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: self.lambda_context = context route, args = self._find_route(self.current_event.http_method, self.current_event.path) - result = route.func(**args) - headers = {"Content-Type": result[1]} if route.cors: headers["Access-Control-Allow-Origin"] = "*" headers["Access-Control-Allow-Methods"] = route.method headers["Access-Control-Allow-Credentials"] = "true" - return {"statusCode": result[0], "headers": headers, "body": result[2]} + body: Union[str, bytes] = result[2] + if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): + gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) + if isinstance(body, str): + body = bytes(body, "utf-8") + body = gzip_compress.compress(body) + gzip_compress.flush() + + response = {"statusCode": result[0], "headers": headers} + + if isinstance(body, bytes): + response["isBase64Encoded"] = True + body = base64.b64encode(body).decode() + + response["body"] = body + + return response - def _register(self, func: Callable, rule: str, method: str, cors: bool): - self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors)) + def _register(self, func: Callable, rule: str, method: str, cors: bool, compress: bool): + self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress)) @staticmethod def _build_rule_pattern(rule: str): @@ -82,12 +98,12 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent: def _find_route(self, method: str, path: str) -> Tuple[RouteEntry, Dict]: method = method.upper() - for resolver in self._routes: - if method != resolver.method: + for route in self._routes: + if method != route.method: continue - match: Optional[re.Match] = resolver.rule.match(path) + match: Optional[re.Match] = route.rule.match(path) if match: - return resolver, match.groupdict() + return route, match.groupdict() raise ValueError(f"No route found for '{method}.{path}'") diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 5a8b989a325..f4b00d71b3c 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1,5 +1,7 @@ +import base64 import json import os +import zlib import pytest @@ -130,3 +132,25 @@ def handler(event, context): assert headers["Access-Control-Allow-Origin"] == "*" assert headers["Access-Control-Allow-Methods"] == "GET" assert headers["Access-Control-Allow-Credentials"] == "true" + + +def test_compress(): + mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + expected_value = '{"test": "value"}' + + app = ApiGatewayResolver() + + @app.get("/my/request", compress=True) + def with_compression(): + return 200, "application/json", expected_value + + def handler(event, context): + return app.resolve(event, context) + + result = handler(mock_event, None) + + assert result["isBase64Encoded"] is True + body = result["body"] + assert isinstance(body, str) + decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8") + assert decompress == expected_value From daaf13789092f446808d77f5cf5da8339a822c94 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Fri, 9 Apr 2021 23:44:23 -0700 Subject: [PATCH 14/36] feat(event-handler): apigwy cache_control option --- .../event_handler/api_gateway.py | 40 +++++++++++-------- .../event_handler/test_api_gateway.py | 32 +++++++++++++++ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 9783ad3fd6c..0335f4e93db 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -17,12 +17,15 @@ class ProxyEventType(Enum): class RouteEntry: - def __init__(self, method: str, rule: Any, func: Callable, cors: bool, compress: bool): + def __init__( + self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] + ): self.method = method.upper() self.rule = rule self.func = func self.cors = cors self.compress = compress + self.cache_control = cache_control class ApiGatewayResolver: @@ -33,21 +36,21 @@ def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type self._routes: List[RouteEntry] = [] - def get(self, rule: str, cors: bool = False, compress: bool = False): - return self.route(rule, "GET", cors, compress) + def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "GET", cors, compress, cache_control) - def post(self, rule: str, cors: bool = False, compress: bool = False): - return self.route(rule, "POST", cors, compress) + def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "POST", cors, compress, cache_control) - def put(self, rule: str, cors: bool = False, compress: bool = False): - return self.route(rule, "PUT", cors, compress) + def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "PUT", cors, compress, cache_control) - def delete(self, rule: str, cors: bool = False, compress: bool = False): - return self.route(rule, "DELETE", cors, compress) + def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "DELETE", cors, compress, cache_control) - def route(self, rule: str, method: str, cors: bool = False, compress: bool = False): + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): - self._register(func, rule, method, cors, compress) + self._append(func, rule, method, cors, compress, cache_control) return func return register_resolver @@ -58,11 +61,18 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: route, args = self._find_route(self.current_event.http_method, self.current_event.path) result = route.func(**args) + + status: int = result[0] + response: Dict[str, Any] = {"statusCode": status} + headers = {"Content-Type": result[1]} if route.cors: headers["Access-Control-Allow-Origin"] = "*" headers["Access-Control-Allow-Methods"] = route.method headers["Access-Control-Allow-Credentials"] = "true" + if route.cache_control: + headers["Cache-Control"] = route.cache_control if status == 200 else "no-cache" + response["headers"] = headers body: Union[str, bytes] = result[2] if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): @@ -70,19 +80,15 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: if isinstance(body, str): body = bytes(body, "utf-8") body = gzip_compress.compress(body) + gzip_compress.flush() - - response = {"statusCode": result[0], "headers": headers} - if isinstance(body, bytes): response["isBase64Encoded"] = True body = base64.b64encode(body).decode() - response["body"] = body return response - def _register(self, func: Callable, rule: str, method: str, cors: bool, compress: bool): - self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress)) + def _append(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]): + self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) @staticmethod def _build_rule_pattern(rule: str): diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f4b00d71b3c..c14e643bc39 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -154,3 +154,35 @@ def handler(event, context): assert isinstance(body, str) decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8") assert decompress == expected_value + + +def test_cache_control_200(): + app = ApiGatewayResolver() + + @app.get("/success", cache_control="max-age=600") + def with_cache_control(): + return 200, "text/html", "has 200 response" + + def handler(event, context): + return app.resolve(event, context) + + result = handler({"path": "/success", "httpMethod": "GET"}, None) + + headers = result["headers"] + assert headers["Cache-Control"] == "max-age=600" + + +def test_cache_control_non_200(): + app = ApiGatewayResolver() + + @app.delete("/fails", cache_control="max-age=600") + def with_cache_control_has_500(): + return 503, "text/html", "has 503 response" + + def handler(event, context): + return app.resolve(event, context) + + result = handler({"path": "/fails", "httpMethod": "DELETE"}, None) + + headers = result["headers"] + assert headers["Cache-Control"] == "no-cache" From 6f6a55c8999ebe7dc8e825b4ec5cfe04ecdc8f34 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 10 Apr 2021 08:35:10 -0700 Subject: [PATCH 15/36] refactor(event-handler): Code cleanup --- .../event_handler/api_gateway.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 0335f4e93db..ba14281c90b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -61,31 +61,31 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: route, args = self._find_route(self.current_event.http_method, self.current_event.path) result = route.func(**args) + status_code: int = result[0] + content_type: str = result[1] + body: Union[str, bytes] = result[2] + headers = {"Content-Type": content_type} - status: int = result[0] - response: Dict[str, Any] = {"statusCode": status} - - headers = {"Content-Type": result[1]} if route.cors: headers["Access-Control-Allow-Origin"] = "*" headers["Access-Control-Allow-Methods"] = route.method headers["Access-Control-Allow-Credentials"] = "true" + if route.cache_control: - headers["Cache-Control"] = route.cache_control if status == 200 else "no-cache" - response["headers"] = headers + headers["Cache-Control"] = route.cache_control if status_code == 200 else "no-cache" - body: Union[str, bytes] = result[2] if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): - gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) if isinstance(body, str): body = bytes(body, "utf-8") + gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) body = gzip_compress.compress(body) + gzip_compress.flush() + + base64_encoded = False if isinstance(body, bytes): - response["isBase64Encoded"] = True + base64_encoded = True body = base64.b64encode(body).decode() - response["body"] = body - return response + return {"statusCode": status_code, "headers": headers, "body": body, "isBase64Encoded": base64_encoded} def _append(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]): self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) From 1484ac9fcc75901d7f39c78e618182c4ece2dbed Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 10 Apr 2021 08:35:43 -0700 Subject: [PATCH 16/36] tests(event-handler): Add missing binary handling --- .../event_handler/test_api_gateway.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index c14e643bc39..6fe953be3c0 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1,7 +1,7 @@ import base64 import json -import os import zlib +from pathlib import Path import pytest @@ -10,9 +10,13 @@ def load_event(file_name: str) -> dict: - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name - with open(full_file_name) as fp: - return json.load(fp) + path = Path(str(Path(__file__).parent.parent.parent) + "/events/" + file_name) + return json.loads(path.read_text()) + + +def read_media(file_name: str) -> bytes: + path = Path(str(Path(__file__).parent.parent.parent.parent) + "/docs/media/" + file_name) + return path.read_bytes() def test_alb_event(): @@ -156,6 +160,21 @@ def handler(event, context): assert decompress == expected_value +def test_base64_encode(): + app = ApiGatewayResolver() + + @app.get("/my/path", compress=True) + def read_image(): + return 200, "image/png", read_media("idempotent_sequence_exception.png") + + mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + result = app(mock_event, None) + + assert result["isBase64Encoded"] is True + body = result["body"] + assert isinstance(body, str) + + def test_cache_control_200(): app = ApiGatewayResolver() @@ -169,6 +188,7 @@ def handler(event, context): result = handler({"path": "/success", "httpMethod": "GET"}, None) headers = result["headers"] + assert headers["Content-Type"] == "text/html" assert headers["Cache-Control"] == "max-age=600" @@ -185,4 +205,5 @@ def handler(event, context): result = handler({"path": "/fails", "httpMethod": "DELETE"}, None) headers = result["headers"] + assert headers["Content-Type"] == "text/html" assert headers["Cache-Control"] == "no-cache" From 9ee7702a85d24f3ea3b51655bb1febf046c9000d Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 10 Apr 2021 12:00:35 -0700 Subject: [PATCH 17/36] fix(event-handler): Set Content-Encoding header for compress --- .../event_handler/api_gateway.py | 5 +++-- .../event_handler/test_api_gateway.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ba14281c90b..cae77a5d71c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -75,10 +75,11 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: headers["Cache-Control"] = route.cache_control if status_code == 200 else "no-cache" if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): + headers["Content-Encoding"] = "gzip" if isinstance(body, str): body = bytes(body, "utf-8") - gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) - body = gzip_compress.compress(body) + gzip_compress.flush() + gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) + body = gzip.compress(body) + gzip.flush() base64_encoded = False if isinstance(body, bytes): diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 6fe953be3c0..178984a6fa6 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -158,6 +158,8 @@ def handler(event, context): assert isinstance(body, str) decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8") assert decompress == expected_value + headers = result["headers"] + assert headers["Content-Encoding"] == "gzip" def test_base64_encode(): @@ -173,6 +175,22 @@ def read_image(): assert result["isBase64Encoded"] is True body = result["body"] assert isinstance(body, str) + headers = result["headers"] + assert headers["Content-Encoding"] == "gzip" + + +def test_compress_no_accept_encoding(): + app = ApiGatewayResolver() + expected_value = "Foo" + + @app.get("/my/path", compress=True) + def return_text(): + return 200, "text/plain", expected_value + + result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) + + assert result["isBase64Encoded"] is False + assert result["body"] == expected_value def test_cache_control_200(): From 85b5ff8317e89b48632f53c1212be3f71c1fc622 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 10 Apr 2021 16:27:16 -0700 Subject: [PATCH 18/36] feat(event-handler): Add PATCH decorator --- .../event_handler/api_gateway.py | 29 +++++----- .../event_handler/test_api_gateway.py | 56 +++++++++++++------ 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index cae77a5d71c..ee44f30a3f7 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -16,7 +16,7 @@ class ProxyEventType(Enum): api_gateway = http_api_v1 -class RouteEntry: +class Route: def __init__( self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] ): @@ -34,7 +34,7 @@ class ApiGatewayResolver: def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type - self._routes: List[RouteEntry] = [] + self._routes: List[Route] = [] def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "GET", cors, compress, cache_control) @@ -48,14 +48,25 @@ def put(self, rule: str, cors: bool = False, compress: bool = False, cache_contr def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "DELETE", cors, compress, cache_control) + def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "PATCH", cors, compress, cache_control) + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): - self._append(func, rule, method, cors, compress, cache_control) + self._add(func, rule, method, cors, compress, cache_control) return func return register_resolver - def resolve(self, event: Dict, context: LambdaContext) -> Dict: + def _add(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]): + self._routes.append(Route(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) + + @staticmethod + def _build_rule_pattern(rule: str): + rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) + return re.compile("^{}$".format(rule_regex)) + + def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._as_data_class(event) self.lambda_context = context @@ -88,14 +99,6 @@ def resolve(self, event: Dict, context: LambdaContext) -> Dict: return {"statusCode": status_code, "headers": headers, "body": body, "isBase64Encoded": base64_encoded} - def _append(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]): - self._routes.append(RouteEntry(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) - - @staticmethod - def _build_rule_pattern(rule: str): - rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) - return re.compile("^{}$".format(rule_regex)) - def _as_data_class(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) @@ -103,7 +106,7 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _find_route(self, method: str, path: str) -> Tuple[RouteEntry, Dict]: + def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]: method = method.upper() for route in self._routes: if method != route.method: diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 178984a6fa6..dc8efca6d43 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -50,21 +50,6 @@ def get_lambda(): assert result["headers"]["Content-Type"] == "application/json" -def test_include_rule_matching(): - app = ApiGatewayResolver() - - @app.get("//") - def get_lambda(my_id: str, name: str): - assert name == "my" - return 200, "plain/html", my_id - - result = app(load_event("apiGatewayProxyEvent.json"), {}) - - assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "plain/html" - assert result["body"] == "path" - - def test_api_gateway(): app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway) @@ -96,6 +81,21 @@ def my_path(): assert result["body"] == "tom" +def test_include_rule_matching(): + app = ApiGatewayResolver() + + @app.get("//") + def get_lambda(my_id: str, name: str): + assert name == "my" + return 200, "plain/html", my_id + + result = app(load_event("apiGatewayProxyEvent.json"), {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/html" + assert result["body"] == "path" + + def test_no_matches(): app = ApiGatewayResolver() @@ -103,17 +103,39 @@ def test_no_matches(): def no_get_matching(): raise RuntimeError() - @app.put("/no_matching") + @app.post("/no_matching_post") + def no_post_matching(): + raise RuntimeError() + + @app.put("/no_matching_put") def no_put_matching(): raise RuntimeError() - @app.delete("/no_matching") + @app.delete("/no_matching_delete") def no_delete_matching(): raise RuntimeError() + @app.patch("/no_matching_patch") + def no_patch_matching(): + raise RuntimeError() + def handler(event, context): app.resolve(event, context) + routes = app._routes + assert len(routes) == 5 + for route in routes: + if route.func == no_get_matching: + assert route.method == "GET" + if route.func == no_post_matching: + assert route.method == "POST" + if route.func == no_put_matching: + assert route.method == "PUT" + if route.func == no_delete_matching: + assert route.method == "DELETE" + if route.func == no_patch_matching: + assert route.method == "PATCH" + with pytest.raises(ValueError): handler(load_event("apiGatewayProxyEvent.json"), None) From 0cf5366f3c2e1b2eb008ac492ff20a2da5bb7722 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 18 Apr 2021 10:04:11 -0700 Subject: [PATCH 19/36] docs(event-handler): Add some docs to tests --- .../event_handler/api_gateway.py | 15 +++++------- .../event_handler/test_api_gateway.py | 24 +++++++++++-------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index ee44f30a3f7..6e641771e79 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -53,19 +53,11 @@ def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_con def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): - self._add(func, rule, method, cors, compress, cache_control) + self._routes.append(Route(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) return func return register_resolver - def _add(self, func: Callable, rule: str, method: str, cors: bool, compress: bool, cache_control: Optional[str]): - self._routes.append(Route(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) - - @staticmethod - def _build_rule_pattern(rule: str): - rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) - return re.compile("^{}$".format(rule_regex)) - def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._as_data_class(event) self.lambda_context = context @@ -99,6 +91,11 @@ def resolve(self, event, context) -> Dict[str, Any]: return {"statusCode": status_code, "headers": headers, "body": body, "isBase64Encoded": base64_encoded} + @staticmethod + def _build_rule_pattern(rule: str): + rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) + return re.compile("^{}$".format(rule_regex)) + def _as_data_class(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index dc8efca6d43..97ca5467aa9 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -97,45 +97,49 @@ def get_lambda(my_id: str, name: str): def test_no_matches(): + # GIVEN an event that does not match any of the given routes app = ApiGatewayResolver() @app.get("/not_matching_get") - def no_get_matching(): + def get_func(): raise RuntimeError() @app.post("/no_matching_post") - def no_post_matching(): + def post_func(): raise RuntimeError() @app.put("/no_matching_put") - def no_put_matching(): + def put_func(): raise RuntimeError() @app.delete("/no_matching_delete") - def no_delete_matching(): + def delete_func(): raise RuntimeError() @app.patch("/no_matching_patch") - def no_patch_matching(): + def patch_func(): raise RuntimeError() def handler(event, context): app.resolve(event, context) + # Also check check the route configurations routes = app._routes assert len(routes) == 5 for route in routes: - if route.func == no_get_matching: + if route.func == get_func: assert route.method == "GET" - if route.func == no_post_matching: + elif route.func == post_func: assert route.method == "POST" - if route.func == no_put_matching: + elif route.func == put_func: assert route.method == "PUT" - if route.func == no_delete_matching: + elif route.func == delete_func: assert route.method == "DELETE" - if route.func == no_patch_matching: + elif route.func == patch_func: assert route.method == "PATCH" + # WHEN calling the handler + # THEN raise a ValueError with pytest.raises(ValueError): handler(load_event("apiGatewayProxyEvent.json"), None) From b5a057b98e38e12b9eda563f0a963fcd02a06562 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 18 Apr 2021 13:54:54 -0700 Subject: [PATCH 20/36] feat(event-handler): Rest API simplification with function returns a Dict --- .../event_handler/api_gateway.py | 12 ++-- .../event_handler/test_api_gateway.py | 55 +++++++++++++------ 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 6e641771e79..74c49054f13 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -1,4 +1,5 @@ import base64 +import json import re import zlib from enum import Enum @@ -61,12 +62,15 @@ def register_resolver(func: Callable): def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._as_data_class(event) self.lambda_context = context - route, args = self._find_route(self.current_event.http_method, self.current_event.path) result = route.func(**args) - status_code: int = result[0] - content_type: str = result[1] - body: Union[str, bytes] = result[2] + + if isinstance(result, dict): + status_code = 200 + content_type = "application/json" + body: Union[str, bytes] = json.dumps(result) + else: + status_code, content_type, body = result headers = {"Content-Type": content_type} if route.cors: diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 97ca5467aa9..bff61e7f1f8 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -19,6 +19,11 @@ def read_media(file_name: str) -> bytes: return path.read_bytes() +LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +TEXT_HTML = "text/html" +APPLICATION_JSON = "application/json" + + def test_alb_event(): app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) @@ -26,12 +31,12 @@ def test_alb_event(): def foo(): assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} - return 200, "text/html", "foo" + return 200, TEXT_HTML, "foo" result = app(load_event("albEvent.json"), {}) assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "text/html" + assert result["headers"]["Content-Type"] == TEXT_HTML assert result["body"] == "foo" @@ -42,12 +47,12 @@ def test_api_gateway_v1(): def get_lambda(): assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} - return 200, "application/json", json.dumps({"foo": "value"}) + return 200, APPLICATION_JSON, json.dumps({"foo": "value"}) - result = app(load_event("apiGatewayProxyEvent.json"), {}) + result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "application/json" + assert result["headers"]["Content-Type"] == APPLICATION_JSON def test_api_gateway(): @@ -56,12 +61,12 @@ def test_api_gateway(): @app.get("/my/path") def get_lambda(): assert isinstance(app.current_event, APIGatewayProxyEvent) - return 200, "plain/html", "foo" + return 200, TEXT_HTML, "foo" - result = app(load_event("apiGatewayProxyEvent.json"), {}) + result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 - assert result["headers"]["Content-Type"] == "plain/html" + assert result["headers"]["Content-Type"] == TEXT_HTML assert result["body"] == "foo" @@ -89,7 +94,7 @@ def get_lambda(my_id: str, name: str): assert name == "my" return 200, "plain/html", my_id - result = app(load_event("apiGatewayProxyEvent.json"), {}) + result = app(LOAD_GW_EVENT, {}) assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == "plain/html" @@ -141,7 +146,7 @@ def handler(event, context): # WHEN calling the handler # THEN raise a ValueError with pytest.raises(ValueError): - handler(load_event("apiGatewayProxyEvent.json"), None) + handler(LOAD_GW_EVENT, None) def test_cors(): @@ -149,16 +154,16 @@ def test_cors(): @app.get("/my/path", cors=True) def with_cors(): - return 200, "text/html", "test" + return 200, TEXT_HTML, "test" def handler(event, context): return app.resolve(event, context) - result = handler(load_event("apiGatewayProxyEvent.json"), None) + result = handler(LOAD_GW_EVENT, None) assert "headers" in result headers = result["headers"] - assert headers["Content-Type"] == "text/html" + assert headers["Content-Type"] == TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" assert headers["Access-Control-Allow-Methods"] == "GET" assert headers["Access-Control-Allow-Credentials"] == "true" @@ -172,7 +177,7 @@ def test_compress(): @app.get("/my/request", compress=True) def with_compression(): - return 200, "application/json", expected_value + return 200, APPLICATION_JSON, expected_value def handler(event, context): return app.resolve(event, context) @@ -224,7 +229,7 @@ def test_cache_control_200(): @app.get("/success", cache_control="max-age=600") def with_cache_control(): - return 200, "text/html", "has 200 response" + return 200, TEXT_HTML, "has 200 response" def handler(event, context): return app.resolve(event, context) @@ -232,7 +237,7 @@ def handler(event, context): result = handler({"path": "/success", "httpMethod": "GET"}, None) headers = result["headers"] - assert headers["Content-Type"] == "text/html" + assert headers["Content-Type"] == TEXT_HTML assert headers["Cache-Control"] == "max-age=600" @@ -241,7 +246,7 @@ def test_cache_control_non_200(): @app.delete("/fails", cache_control="max-age=600") def with_cache_control_has_500(): - return 503, "text/html", "has 503 response" + return 503, TEXT_HTML, "has 503 response" def handler(event, context): return app.resolve(event, context) @@ -249,5 +254,19 @@ def handler(event, context): result = handler({"path": "/fails", "httpMethod": "DELETE"}, None) headers = result["headers"] - assert headers["Content-Type"] == "text/html" + assert headers["Content-Type"] == TEXT_HTML assert headers["Cache-Control"] == "no-cache" + + +def test_rest_api(): + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + + @app.get("/my/path") + def rest_func(): + return {"foo": "value"} + + result = app(LOAD_GW_EVENT, {}) + + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == APPLICATION_JSON + assert result["body"] == json.dumps({"foo": "value"}) From e7e8d5965d84f54b5d61455f3c4604e6399b811b Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 19 Apr 2021 10:20:40 -0700 Subject: [PATCH 21/36] feat(event-handler): Add Response class This will allow for fine grained control of the returning headers --- .../event_handler/api_gateway.py | 70 ++++++++++------ .../event_handler/test_api_gateway.py | 84 ++++++++++++++++--- 2 files changed, 117 insertions(+), 37 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 74c49054f13..c9f9c5c01a8 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -29,6 +29,42 @@ def __init__( self.cache_control = cache_control +class Response: + def __init__(self, status_code: int, content_type: str, body: Union[str, bytes], headers: Dict = None): + self.status_code = status_code + self.body = body + self.base64_encoded = False + self.headers: Dict = headers if headers is not None else {} + if "Content-Type" not in self.headers: + self.headers["Content-Type"] = content_type + + def add_cors(self, method: str): + self.headers["Access-Control-Allow-Origin"] = "*" + self.headers["Access-Control-Allow-Methods"] = method + self.headers["Access-Control-Allow-Credentials"] = "true" + + def add_cache_control(self, cache_control: str): + self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache" + + def compress(self): + self.headers["Content-Encoding"] = "gzip" + if isinstance(self.body, str): + self.body = bytes(self.body, "utf-8") + gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) + self.body = gzip.compress(self.body) + gzip.flush() + + def to_dict(self): + if isinstance(self.body, bytes): + self.base64_encoded = True + self.body = base64.b64encode(self.body).decode() + return { + "statusCode": self.status_code, + "headers": self.headers, + "body": self.body, + "isBase64Encoded": self.base64_encoded, + } + + class ApiGatewayResolver: current_event: BaseProxyEvent lambda_context: LambdaContext @@ -65,35 +101,21 @@ def resolve(self, event, context) -> Dict[str, Any]: route, args = self._find_route(self.current_event.http_method, self.current_event.path) result = route.func(**args) - if isinstance(result, dict): - status_code = 200 - content_type = "application/json" - body: Union[str, bytes] = json.dumps(result) + if isinstance(result, Response): + response = result + elif isinstance(result, dict): + response = Response(status_code=200, content_type="application/json", body=json.dumps(result)) else: - status_code, content_type, body = result - headers = {"Content-Type": content_type} + response = Response(*result) if route.cors: - headers["Access-Control-Allow-Origin"] = "*" - headers["Access-Control-Allow-Methods"] = route.method - headers["Access-Control-Allow-Credentials"] = "true" - + response.add_cors(route.method) if route.cache_control: - headers["Cache-Control"] = route.cache_control if status_code == 200 else "no-cache" - + response.add_cache_control(route.cache_control) if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): - headers["Content-Encoding"] = "gzip" - if isinstance(body, str): - body = bytes(body, "utf-8") - gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) - body = gzip.compress(body) + gzip.flush() - - base64_encoded = False - if isinstance(body, bytes): - base64_encoded = True - body = base64.b64encode(body).decode() - - return {"statusCode": status_code, "headers": headers, "body": body, "isBase64Encoded": base64_encoded} + response.compress() + + return response.to_dict() @staticmethod def _build_rule_pattern(rule: str): diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index bff61e7f1f8..c34d67e5165 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -2,10 +2,11 @@ import json import zlib from pathlib import Path +from typing import Dict, Tuple import pytest -from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -25,77 +26,96 @@ def read_media(file_name: str) -> bytes: def test_alb_event(): + # GIVEN a Application Load Balancer proxy type event app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) @app.get("/lambda") - def foo(): + def foo() -> Tuple[int, str, str]: assert isinstance(app.current_event, ALBEvent) assert app.lambda_context == {} return 200, TEXT_HTML, "foo" + # WHEN result = app(load_event("albEvent.json"), {}) + # THEN process event correctly + # AND set the current_event type as ALBEvent assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == TEXT_HTML assert result["body"] == "foo" def test_api_gateway_v1(): + # GIVEN a Http API V1 proxy type event app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) @app.get("/my/path") - def get_lambda(): + def get_lambda() -> Tuple[int, str, str]: assert isinstance(app.current_event, APIGatewayProxyEvent) assert app.lambda_context == {} return 200, APPLICATION_JSON, json.dumps({"foo": "value"}) + # WHEN result = app(LOAD_GW_EVENT, {}) + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == APPLICATION_JSON def test_api_gateway(): + # GIVEN a Rest API Gateway proxy type event app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway) @app.get("/my/path") - def get_lambda(): + def get_lambda() -> Tuple[int, str, str]: assert isinstance(app.current_event, APIGatewayProxyEvent) return 200, TEXT_HTML, "foo" + # WHEN result = app(LOAD_GW_EVENT, {}) + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == TEXT_HTML assert result["body"] == "foo" def test_api_gateway_v2(): + # GIVEN a Http API V2 proxy type event app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2) @app.post("/my/path") - def my_path(): + def my_path() -> Tuple[int, str, str]: assert isinstance(app.current_event, APIGatewayProxyEventV2) post_data = app.current_event.json_body return 200, "plain/text", post_data["username"] + # WHEN result = app(load_event("apiGatewayProxyV2Event.json"), {}) + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEventV2 assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == "plain/text" assert result["body"] == "tom" def test_include_rule_matching(): + # GIVEN app = ApiGatewayResolver() @app.get("//") - def get_lambda(my_id: str, name: str): + def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]: assert name == "my" return 200, "plain/html", my_id + # WHEN result = app(LOAD_GW_EVENT, {}) + # THEN assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == "plain/html" assert result["body"] == "path" @@ -153,7 +173,7 @@ def test_cors(): app = ApiGatewayResolver() @app.get("/my/path", cors=True) - def with_cors(): + def with_cors() -> Tuple[int, str, str]: return 200, TEXT_HTML, "test" def handler(event, context): @@ -176,7 +196,7 @@ def test_compress(): app = ApiGatewayResolver() @app.get("/my/request", compress=True) - def with_compression(): + def with_compression() -> Tuple[int, str, str]: return 200, APPLICATION_JSON, expected_value def handler(event, context): @@ -197,7 +217,7 @@ def test_base64_encode(): app = ApiGatewayResolver() @app.get("/my/path", compress=True) - def read_image(): + def read_image() -> Tuple[int, str, bytes]: return 200, "image/png", read_media("idempotent_sequence_exception.png") mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} @@ -211,62 +231,100 @@ def read_image(): def test_compress_no_accept_encoding(): + # GIVEN a function with compress=True + # AND the request has no "Accept-Encoding" set to include gzip app = ApiGatewayResolver() expected_value = "Foo" @app.get("/my/path", compress=True) - def return_text(): + def return_text() -> Tuple[int, str, str]: return 200, "text/plain", expected_value + # WHEN calling the event handler result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) + # THEN don't perform any gzip compression assert result["isBase64Encoded"] is False assert result["body"] == expected_value def test_cache_control_200(): + # GIVEN a function with cache_control set app = ApiGatewayResolver() @app.get("/success", cache_control="max-age=600") - def with_cache_control(): + def with_cache_control() -> Tuple[int, str, str]: return 200, TEXT_HTML, "has 200 response" def handler(event, context): return app.resolve(event, context) + # WHEN calling the event handler + # AND the function returns a 200 status code result = handler({"path": "/success", "httpMethod": "GET"}, None) + # THEN return the set Cache-Control headers = result["headers"] assert headers["Content-Type"] == TEXT_HTML assert headers["Cache-Control"] == "max-age=600" def test_cache_control_non_200(): + # GIVEN a function with cache_control set app = ApiGatewayResolver() @app.delete("/fails", cache_control="max-age=600") - def with_cache_control_has_500(): + def with_cache_control_has_500() -> Tuple[int, str, str]: return 503, TEXT_HTML, "has 503 response" def handler(event, context): return app.resolve(event, context) + # WHEN calling the event handler + # AND the function returns a 503 status code result = handler({"path": "/fails", "httpMethod": "DELETE"}, None) + # THEN return a Cache-Control of "no-cache" headers = result["headers"] assert headers["Content-Type"] == TEXT_HTML assert headers["Cache-Control"] == "no-cache" def test_rest_api(): + # GIVEN a function that returns a Dict app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) @app.get("/my/path") - def rest_func(): + def rest_func() -> Dict: return {"foo": "value"} + # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) + # THEN automatically process this as a json rest api response assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == APPLICATION_JSON assert result["body"] == json.dumps({"foo": "value"}) + + +def test_handling_response_type(): + # GIVEN a function that returns Response + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + + @app.get("/my/path") + def rest_func() -> Response: + return Response( + status_code=404, + content_type="used-if-not-set-in-header", + body="Not found", + headers={"Content-Type": "header-content-type-wins", "custom": "value"}, + ) + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN the result can include some additional field control like overriding http headers + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == "header-content-type-wins" + assert result["headers"]["custom"] == "value" + assert result["body"] == "Not found" From e57b59e7ae36dfb8988cba7d0012de27e11d7e01 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 20 Apr 2021 20:11:29 -0700 Subject: [PATCH 22/36] feat(event-handler): Use shared json Encoder Also use more concise seperators --- aws_lambda_powertools/event_handler/api_gateway.py | 7 ++++++- tests/functional/event_handler/test_api_gateway.py | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index c9f9c5c01a8..1018f4267e1 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -5,6 +5,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent from aws_lambda_powertools.utilities.typing import LambdaContext @@ -104,7 +105,11 @@ def resolve(self, event, context) -> Dict[str, Any]: if isinstance(result, Response): response = result elif isinstance(result, dict): - response = Response(status_code=200, content_type="application/json", body=json.dumps(result)) + response = Response( + status_code=200, + content_type="application/json", + body=json.dumps(result, separators=(",", ":"), cls=Encoder), + ) else: response = Response(*result) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index c34d67e5165..b0378ba8912 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1,12 +1,14 @@ import base64 import json import zlib +from decimal import Decimal from pathlib import Path from typing import Dict, Tuple import pytest from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response +from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -293,10 +295,11 @@ def handler(event, context): def test_rest_api(): # GIVEN a function that returns a Dict app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + expected_dict = {"foo": "value", "second": Decimal("100.01")} @app.get("/my/path") def rest_func() -> Dict: - return {"foo": "value"} + return expected_dict # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) @@ -304,7 +307,8 @@ def rest_func() -> Dict: # THEN automatically process this as a json rest api response assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == APPLICATION_JSON - assert result["body"] == json.dumps({"foo": "value"}) + expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder) + assert result["body"] == expected_str def test_handling_response_type(): From 569cdbdd1448aa85566fb95591f0fc066d90e133 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 20 Apr 2021 21:07:50 -0700 Subject: [PATCH 23/36] fix(data-classes): Correct typing for json_body --- aws_lambda_powertools/utilities/data_classes/common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 93f416f5283..a6b975c6072 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -58,10 +58,12 @@ def is_base64_encoded(self) -> Optional[bool]: @property def body(self) -> Optional[str]: + """Submitted body of the request as a string""" return self.get("body") @property - def json_body(self) -> Dict: + def json_body(self) -> Any: + """Parses the submitted body as json""" return json.loads(self["body"]) @property From c1ea9b1068a0bbe3675182fc7d802c0b2bbe98c7 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 20 Apr 2021 21:26:42 -0700 Subject: [PATCH 24/36] tests: Add shared test utils.load_event --- tests/functional/event_handler/test_api_gateway.py | 6 +----- tests/functional/event_handler/test_appsync.py | 8 +------- tests/functional/idempotency/conftest.py | 8 ++------ tests/functional/parser/test_alb.py | 2 +- tests/functional/parser/test_cloudwatch.py | 2 +- tests/functional/parser/test_dynamodb.py | 2 +- tests/functional/parser/test_eventbridge.py | 2 +- tests/functional/parser/test_kinesis.py | 2 +- tests/functional/parser/test_s3 object_event.py | 2 +- tests/functional/parser/test_s3.py | 2 +- tests/functional/parser/test_ses.py | 2 +- tests/functional/parser/test_sns.py | 2 +- tests/functional/parser/test_sqs.py | 2 +- tests/functional/parser/utils.py | 13 ------------- tests/functional/test_data_classes.py | 8 +------- tests/functional/utils.py | 8 ++++++++ 16 files changed, 23 insertions(+), 48 deletions(-) delete mode 100644 tests/functional/parser/utils.py create mode 100644 tests/functional/utils.py diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index b0378ba8912..a2811ca6124 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -10,11 +10,7 @@ from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 - - -def load_event(file_name: str) -> dict: - path = Path(str(Path(__file__).parent.parent.parent) + "/events/" + file_name) - return json.loads(path.read_text()) +from tests.functional.utils import load_event def read_media(file_name: str) -> bytes: diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index c72331c32f1..e260fef89ab 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -1,18 +1,12 @@ import asyncio -import json import sys -from pathlib import Path import pytest from aws_lambda_powertools.event_handler import AppSyncResolver from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext - - -def load_event(file_name: str) -> dict: - path = Path(str(Path(__file__).parent.parent.parent) + "/events/" + file_name) - return json.loads(path.read_text()) +from tests.functional.utils import load_event def test_direct_resolver(): diff --git a/tests/functional/idempotency/conftest.py b/tests/functional/idempotency/conftest.py index d34d5da7d12..e100957dee7 100644 --- a/tests/functional/idempotency/conftest.py +++ b/tests/functional/idempotency/conftest.py @@ -1,7 +1,6 @@ import datetime import hashlib import json -import os from collections import namedtuple from decimal import Decimal from unittest import mock @@ -17,6 +16,7 @@ from aws_lambda_powertools.utilities.idempotency.idempotency import IdempotencyConfig from aws_lambda_powertools.utilities.validation import envelopes from aws_lambda_powertools.utilities.validation.base import unwrap_event_from_envelope +from tests.functional.utils import load_event TABLE_NAME = "TEST_TABLE" @@ -28,11 +28,7 @@ def config() -> Config: @pytest.fixture(scope="module") def lambda_apigw_event(): - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + "apiGatewayProxyV2Event.json" - with open(full_file_name) as fp: - event = json.load(fp) - - return event + return load_event("apiGatewayProxyV2Event.json") @pytest.fixture diff --git a/tests/functional/parser/test_alb.py b/tests/functional/parser/test_alb.py index 88631c7194c..d48e39f1bab 100644 --- a/tests/functional/parser/test_alb.py +++ b/tests/functional/parser/test_alb.py @@ -3,7 +3,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, event_parser from aws_lambda_powertools.utilities.parser.models import AlbModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=AlbModel) diff --git a/tests/functional/parser/test_cloudwatch.py b/tests/functional/parser/test_cloudwatch.py index 9a61f339140..7290d0bffcb 100644 --- a/tests/functional/parser/test_cloudwatch.py +++ b/tests/functional/parser/test_cloudwatch.py @@ -9,7 +9,7 @@ from aws_lambda_powertools.utilities.parser.models import CloudWatchLogsLogEvent, CloudWatchLogsModel from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyCloudWatchBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyCloudWatchBusiness, envelope=envelopes.CloudWatchLogsEnvelope) diff --git a/tests/functional/parser/test_dynamodb.py b/tests/functional/parser/test_dynamodb.py index bd7e0795f42..9917fac234b 100644 --- a/tests/functional/parser/test_dynamodb.py +++ b/tests/functional/parser/test_dynamodb.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedDynamoBusiness, MyDynamoBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyDynamoBusiness, envelope=envelopes.DynamoDBStreamEnvelope) diff --git a/tests/functional/parser/test_eventbridge.py b/tests/functional/parser/test_eventbridge.py index 7a3066d7b04..6242403ab35 100644 --- a/tests/functional/parser/test_eventbridge.py +++ b/tests/functional/parser/test_eventbridge.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedEventbridgeBusiness, MyEventbridgeBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyEventbridgeBusiness, envelope=envelopes.EventBridgeEnvelope) diff --git a/tests/functional/parser/test_kinesis.py b/tests/functional/parser/test_kinesis.py index 5a7a94e0dac..632a7463805 100644 --- a/tests/functional/parser/test_kinesis.py +++ b/tests/functional/parser/test_kinesis.py @@ -6,7 +6,7 @@ from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamModel, KinesisDataStreamRecordPayload from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyKinesisBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyKinesisBusiness, envelope=envelopes.KinesisDataStreamEnvelope) diff --git a/tests/functional/parser/test_s3 object_event.py b/tests/functional/parser/test_s3 object_event.py index da015338cf4..90c2555360d 100644 --- a/tests/functional/parser/test_s3 object_event.py +++ b/tests/functional/parser/test_s3 object_event.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser from aws_lambda_powertools.utilities.parser.models import S3ObjectLambdaEvent from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=S3ObjectLambdaEvent) diff --git a/tests/functional/parser/test_s3.py b/tests/functional/parser/test_s3.py index a9c325f3a97..71a5dc6afe3 100644 --- a/tests/functional/parser/test_s3.py +++ b/tests/functional/parser/test_s3.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser, parse from aws_lambda_powertools.utilities.parser.models import S3Model, S3RecordModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=S3Model) diff --git a/tests/functional/parser/test_ses.py b/tests/functional/parser/test_ses.py index f96da7bad66..d434e2350f8 100644 --- a/tests/functional/parser/test_ses.py +++ b/tests/functional/parser/test_ses.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser from aws_lambda_powertools.utilities.parser.models import SesModel, SesRecordModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=SesModel) diff --git a/tests/functional/parser/test_sns.py b/tests/functional/parser/test_sns.py index 015af3693fa..81158a4419e 100644 --- a/tests/functional/parser/test_sns.py +++ b/tests/functional/parser/test_sns.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedSnsBusiness, MySnsBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event from tests.functional.validator.conftest import sns_event # noqa: F401 diff --git a/tests/functional/parser/test_sqs.py b/tests/functional/parser/test_sqs.py index 0cea8246b50..7ca883616f2 100644 --- a/tests/functional/parser/test_sqs.py +++ b/tests/functional/parser/test_sqs.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedSqsBusiness, MySqsBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event from tests.functional.validator.conftest import sqs_event # noqa: F401 diff --git a/tests/functional/parser/utils.py b/tests/functional/parser/utils.py deleted file mode 100644 index 7cb949b1289..00000000000 --- a/tests/functional/parser/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -import os -from typing import Any - - -def get_event_file_path(file_name: str) -> str: - return os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name - - -def load_event(file_name: str) -> Any: - full_file_name = get_event_file_path(file_name) - with open(full_file_name) as fp: - return json.load(fp) diff --git a/tests/functional/test_data_classes.py b/tests/functional/test_data_classes.py index 0221acc6853..d346eca480a 100644 --- a/tests/functional/test_data_classes.py +++ b/tests/functional/test_data_classes.py @@ -1,7 +1,6 @@ import base64 import datetime import json -import os from secrets import compare_digest from urllib.parse import quote_plus @@ -58,12 +57,7 @@ StreamViewType, ) from aws_lambda_powertools.utilities.data_classes.s3_object_event import S3ObjectLambdaEvent - - -def load_event(file_name: str) -> dict: - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../events/" + file_name - with open(full_file_name) as fp: - return json.load(fp) +from tests.functional.utils import load_event def test_dict_wrapper_equals(): diff --git a/tests/functional/utils.py b/tests/functional/utils.py new file mode 100644 index 00000000000..a58d27f3526 --- /dev/null +++ b/tests/functional/utils.py @@ -0,0 +1,8 @@ +import json +from pathlib import Path +from typing import Any + + +def load_event(file_name: str) -> Any: + path = Path(str(Path(__file__).parent.parent) + "/events/" + file_name) + return json.loads(path.read_text()) From 7940c46145906ac5f1ecd0e94569dfc32e07962e Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 20 Apr 2021 22:17:33 -0700 Subject: [PATCH 25/36] docs(tests): Add more docs to tests --- .../event_handler/test_api_gateway.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index a2811ca6124..f2fa84c61f5 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -33,7 +33,7 @@ def foo() -> Tuple[int, str, str]: assert app.lambda_context == {} return 200, TEXT_HTML, "foo" - # WHEN + # WHEN calling the event handler result = app(load_event("albEvent.json"), {}) # THEN process event correctly @@ -53,7 +53,7 @@ def get_lambda() -> Tuple[int, str, str]: assert app.lambda_context == {} return 200, APPLICATION_JSON, json.dumps({"foo": "value"}) - # WHEN + # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN process event correctly @@ -71,7 +71,7 @@ def get_lambda() -> Tuple[int, str, str]: assert isinstance(app.current_event, APIGatewayProxyEvent) return 200, TEXT_HTML, "foo" - # WHEN + # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN process event correctly @@ -91,7 +91,7 @@ def my_path() -> Tuple[int, str, str]: post_data = app.current_event.json_body return 200, "plain/text", post_data["username"] - # WHEN + # WHEN calling the event handler result = app(load_event("apiGatewayProxyV2Event.json"), {}) # THEN process event correctly @@ -110,7 +110,7 @@ def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]: assert name == "my" return 200, "plain/html", my_id - # WHEN + # WHEN calling the event handler result = app(LOAD_GW_EVENT, {}) # THEN @@ -168,6 +168,8 @@ def handler(event, context): def test_cors(): + # GIVEN a function with cors=True + # AND http method set to GET app = ApiGatewayResolver() @app.get("/my/path", cors=True) @@ -177,8 +179,10 @@ def with_cors() -> Tuple[int, str, str]: def handler(event, context): return app.resolve(event, context) + # WHEN calling the event handler result = handler(LOAD_GW_EVENT, None) + # THEN the headers should include cors headers assert "headers" in result headers = result["headers"] assert headers["Content-Type"] == TEXT_HTML @@ -188,11 +192,12 @@ def handler(event, context): def test_compress(): + # GIVEN a function that has compress=True + # AND an event with a "Accept-Encoding" that include gzip + app = ApiGatewayResolver() mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} expected_value = '{"test": "value"}' - app = ApiGatewayResolver() - @app.get("/my/request", compress=True) def with_compression() -> Tuple[int, str, str]: return 200, APPLICATION_JSON, expected_value @@ -200,8 +205,10 @@ def with_compression() -> Tuple[int, str, str]: def handler(event, context): return app.resolve(event, context) + # WHEN calling the event handler result = handler(mock_event, None) + # THEN then gzip the response and base64 encode as a string assert result["isBase64Encoded"] is True body = result["body"] assert isinstance(body, str) @@ -212,15 +219,18 @@ def handler(event, context): def test_base64_encode(): + # GIVEN a function that returns bytes app = ApiGatewayResolver() + mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} @app.get("/my/path", compress=True) def read_image() -> Tuple[int, str, bytes]: return 200, "image/png", read_media("idempotent_sequence_exception.png") - mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + # WHEN calling the event handler result = app(mock_event, None) + # THEN return the body and a base64 encoded string assert result["isBase64Encoded"] is True body = result["body"] assert isinstance(body, str) From f74307f0eeb15218e930f862abaeba9b810ebded Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 20 Apr 2021 22:59:05 -0700 Subject: [PATCH 26/36] refactor(event-handler): Final housekeeping --- .../event_handler/api_gateway.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1018f4267e1..a6cb17c0b41 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -54,7 +54,7 @@ def compress(self): gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) self.body = gzip.compress(self.body) + gzip.flush() - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: if isinstance(self.body, bytes): self.base64_encoded = True self.body = base64.b64encode(self.body).decode() @@ -91,15 +91,16 @@ def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_con def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): - self._routes.append(Route(method, self._build_rule_pattern(rule), func, cors, compress, cache_control)) + self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) return func return register_resolver def resolve(self, event, context) -> Dict[str, Any]: - self.current_event = self._as_data_class(event) + self.current_event = self._to_data_class(event) self.lambda_context = context - route, args = self._find_route(self.current_event.http_method, self.current_event.path) + + route, args = self._find_route(self.current_event.http_method.upper(), self.current_event.path) result = route.func(**args) if isinstance(result, Response): @@ -110,7 +111,7 @@ def resolve(self, event, context) -> Dict[str, Any]: content_type="application/json", body=json.dumps(result, separators=(",", ":"), cls=Encoder), ) - else: + else: # Tuple[int, str, Union[bytes, str]] response = Response(*result) if route.cors: @@ -123,11 +124,11 @@ def resolve(self, event, context) -> Dict[str, Any]: return response.to_dict() @staticmethod - def _build_rule_pattern(rule: str): + def _compile_regex(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) return re.compile("^{}$".format(rule_regex)) - def _as_data_class(self, event: Dict) -> BaseProxyEvent: + def _to_data_class(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.http_api_v1: return APIGatewayProxyEvent(event) if self._proxy_type == ProxyEventType.http_api_v2: @@ -135,7 +136,6 @@ def _as_data_class(self, event: Dict) -> BaseProxyEvent: return ALBEvent(event) def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]: - method = method.upper() for route in self._routes: if method != route.method: continue From 4c20ceb5c0f649ff6550c4bdc306abba7878b4ed Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Thu, 22 Apr 2021 08:47:32 -0700 Subject: [PATCH 27/36] tests(event-handler): Fix import --- tests/functional/parser/test_apigw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/parser/test_apigw.py b/tests/functional/parser/test_apigw.py index 333654f3f89..fc679d5dc37 100644 --- a/tests/functional/parser/test_apigw.py +++ b/tests/functional/parser/test_apigw.py @@ -2,7 +2,7 @@ from aws_lambda_powertools.utilities.parser.models import APIGatewayProxyEventModel from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyApiGatewayBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyApiGatewayBusiness, envelope=envelopes.ApiGatewayEnvelope) From 785877ca133077213342bca717a5a630f2fa3a54 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 25 Apr 2021 17:15:43 -0700 Subject: [PATCH 28/36] refactor: precise handling of headers --- aws_lambda_powertools/event_handler/api_gateway.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a6cb17c0b41..57bd80e20a1 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -35,9 +35,8 @@ def __init__(self, status_code: int, content_type: str, body: Union[str, bytes], self.status_code = status_code self.body = body self.base64_encoded = False - self.headers: Dict = headers if headers is not None else {} - if "Content-Type" not in self.headers: - self.headers["Content-Type"] = content_type + self.headers: Dict = headers or {} + self.headers.setdefault("Content-Type", content_type) def add_cors(self, method: str): self.headers["Access-Control-Allow-Origin"] = "*" From c5709bca975152b8c9b117b7014b413a538943aa Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 25 Apr 2021 18:02:33 -0700 Subject: [PATCH 29/36] refactor: add to_response to simplify logic --- .../event_handler/api_gateway.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 57bd80e20a1..71682ecfe54 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -98,20 +98,8 @@ def register_resolver(func: Callable): def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._to_data_class(event) self.lambda_context = context - route, args = self._find_route(self.current_event.http_method.upper(), self.current_event.path) - result = route.func(**args) - - if isinstance(result, Response): - response = result - elif isinstance(result, dict): - response = Response( - status_code=200, - content_type="application/json", - body=json.dumps(result, separators=(",", ":"), cls=Encoder), - ) - else: # Tuple[int, str, Union[bytes, str]] - response = Response(*result) + response = self.to_response(route.func(**args)) if route.cors: response.add_cors(route.method) @@ -122,6 +110,19 @@ def resolve(self, event, context) -> Dict[str, Any]: return response.to_dict() + @staticmethod + def to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: + if isinstance(result, Response): + return result + elif isinstance(result, dict): + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(result, separators=(",", ":"), cls=Encoder), + ) + else: # Tuple[int, str, Union[bytes, str]] + return Response(*result) + @staticmethod def _compile_regex(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) From 318508f19d2348d971ca536fc3f59eedbadf55c5 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 26 Apr 2021 23:30:47 -0700 Subject: [PATCH 30/36] feat(event-handler): Add a more complete implementation of cors NOTE: Some of this is based on the cors behavior of Chalice, except where we actually return the preflight response --- .../event_handler/api_gateway.py | 120 ++++++++++++++---- .../event_handler/test_api_gateway.py | 87 ++++++++++++- 2 files changed, 182 insertions(+), 25 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 71682ecfe54..1135ddef8ed 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -3,7 +3,7 @@ import re import zlib from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -18,30 +18,74 @@ class ProxyEventType(Enum): api_gateway = http_api_v1 +class CORSConfig(object): + _REQUIRED_HEADERS = ["Content-Type", "X-Amz-Date", "Authorization", "X-Api-Key", "X-Amz-Security-Token"] + + def __init__( + self, + allow_origin: str = "*", + allow_headers: List[str] = None, + expose_headers: List[str] = None, + max_age: int = None, + allow_credentials: bool = True, + ): + self.allow_origin = allow_origin + self.allow_headers = set((allow_headers or []) + self._REQUIRED_HEADERS) + self.expose_headers = expose_headers or [] + self.max_age = max_age + self.allow_credentials = allow_credentials + + def to_dict(self) -> Dict[str, str]: + headers = { + "Access-Control-Allow-Origin": self.allow_origin, + "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), + } + if self.expose_headers: + headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers) + if self.max_age is not None: + headers["Access-Control-Max-Age"] = str(self.max_age) + if self.allow_credentials is True: + headers["Access-Control-Allow-Credentials"] = "true" + return headers + + class Route: def __init__( - self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] + self, + method: str, + rule: Any, + func: Callable, + cors: Union[bool, CORSConfig], + compress: bool, + cache_control: Optional[str], ): self.method = method.upper() self.rule = rule self.func = func - self.cors = cors + self.cors: Optional[CORSConfig] + if cors is True: + self.cors = CORSConfig() + elif isinstance(cors, CORSConfig): + self.cors = cors + else: + self.cors = None self.compress = compress self.cache_control = cache_control class Response: - def __init__(self, status_code: int, content_type: str, body: Union[str, bytes], headers: Dict = None): + def __init__( + self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None + ): self.status_code = status_code self.body = body self.base64_encoded = False self.headers: Dict = headers or {} - self.headers.setdefault("Content-Type", content_type) + if content_type: + self.headers.setdefault("Content-Type", content_type) - def add_cors(self, method: str): - self.headers["Access-Control-Allow-Origin"] = "*" - self.headers["Access-Control-Allow-Methods"] = method - self.headers["Access-Control-Allow-Credentials"] = "true" + def add_cors(self, cors: CORSConfig): + self.headers.update(cors.to_dict()) def add_cache_control(self, cache_control: str): self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache" @@ -54,15 +98,14 @@ def compress(self): self.body = gzip.compress(self.body) + gzip.flush() def to_dict(self) -> Dict[str, Any]: + result = {"statusCode": self.status_code, "headers": self.headers} if isinstance(self.body, bytes): self.base64_encoded = True self.body = base64.b64encode(self.body).decode() - return { - "statusCode": self.status_code, - "headers": self.headers, - "body": self.body, - "isBase64Encoded": self.base64_encoded, - } + if self.body: + result["isBase64Encoded"] = self.base64_encoded + result["body"] = self.body + return result class ApiGatewayResolver: @@ -72,25 +115,43 @@ class ApiGatewayResolver: def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): self._proxy_type = proxy_type self._routes: List[Route] = [] + self._cors: Optional[CORSConfig] = None + self._cors_methods: Set[str] = set() - def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def get(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): return self.route(rule, "GET", cors, compress, cache_control) - def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def post(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): return self.route(rule, "POST", cors, compress, cache_control) - def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def put(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): return self.route(rule, "PUT", cors, compress, cache_control) - def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def delete( + self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None + ): return self.route(rule, "DELETE", cors, compress, cache_control) - def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def patch( + self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None + ): return self.route(rule, "PATCH", cors, compress, cache_control) - def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def route( + self, + rule: str, + method: str, + cors: Union[bool, CORSConfig] = False, + compress: bool = False, + cache_control: str = None, + ): def register_resolver(func: Callable): - self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) + route = Route(method, self._compile_regex(rule), func, cors, compress, cache_control) + self._routes.append(route) + if route.cors: + if self._cors is None: + self._cors = route.cors + self._cors_methods.add(route.method) return func return register_resolver @@ -102,7 +163,7 @@ def resolve(self, event, context) -> Dict[str, Any]: response = self.to_response(route.func(**args)) if route.cors: - response.add_cors(route.method) + response.add_cors(route.cors) if route.cache_control: response.add_cache_control(route.cache_control) if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): @@ -135,6 +196,12 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) + @staticmethod + def _preflight(allowed_methods: Set): + allowed_methods.add("OPTIONS") + headers = {"Access-Control-Allow-Methods": ",".join(sorted(allowed_methods))} + return Response(204, None, None, headers) + def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]: for route in self._routes: if method != route.method: @@ -143,6 +210,13 @@ def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]: if match: return route, match.groupdict() + if method == "OPTIONS" and self._cors is not None: + # Most be the preflight options call + return ( + Route("OPTIONS", None, self._preflight, self._cors, False, None), + {"allowed_methods": self._cors_methods}, + ) + raise ValueError(f"No route found for '{method}.{path}'") def __call__(self, event, context) -> Any: diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f2fa84c61f5..7d4fb91478d 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -7,7 +7,7 @@ import pytest -from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, ProxyEventType, Response +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 from tests.functional.utils import load_event @@ -187,8 +187,10 @@ def handler(event, context): headers = result["headers"] assert headers["Content-Type"] == TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" - assert headers["Access-Control-Allow-Methods"] == "GET" assert headers["Access-Control-Allow-Credentials"] == "true" + # AND "Access-Control-Allow-Methods" is only included in the preflight cors headers + assert "Access-Control-Allow-Methods" not in headers + assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) def test_compress(): @@ -338,3 +340,84 @@ def rest_func() -> Response: assert result["headers"]["Content-Type"] == "header-content-type-wins" assert result["headers"]["custom"] == "value" assert result["body"] == "Not found" + + +def test_preflight_cors(): + # GIVEN + app = ApiGatewayResolver() + preflight_event = {"path": "/cors", "httpMethod": "OPTIONS"} + + @app.get("/cors", cors=True) + def get_with_cors(): + ... + + @app.post("/cors", cors=True) + def post_with_cors(): + ... + + @app.delete("/cors") + def delete_no_cors(): + ... + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + # AND the httpMethod is OPTIONS + result = handler(preflight_event, None) + + # THEN return the preflight response + # AND No Content it returned + assert result["statusCode"] == 204 + assert "body" not in result + assert "isBase64Encoded" not in result + # AND no Content-Type is set + headers = result["headers"] + assert "headers" in result + assert "Content-Type" not in headers + # AND set the access control headers + assert headers["Access-Control-Allow-Origin"] == "*" + assert headers["Access-Control-Allow-Methods"] == "GET,OPTIONS,POST" + assert headers["Access-Control-Allow-Credentials"] == "true" + + +def test_custom_cors_config(): + # GIVEN a custom cors configuration + app = ApiGatewayResolver() + event = {"path": "/cors", "httpMethod": "GET"} + allow_header = ["foo2"] + cors_config = CORSConfig( + allow_origin="https://foo1", + expose_headers=["foo1"], + allow_headers=allow_header, + max_age=100, + allow_credentials=False, + ) + + @app.get("/cors", cors=cors_config) + def get_with_cors(): + return {} + + # NOTE: Currently only the first configuration is used for the OPTIONS preflight + @app.get("/another-one", cors=True) + def another_one(): + return {} + + # WHEN calling the event handler + result = app(event, None) + + # THEN return the custom cors headers + assert "headers" in result + headers = result["headers"] + assert headers["Content-Type"] == APPLICATION_JSON + assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin + expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS))) + assert headers["Access-Control-Allow-Headers"] == expected_allows_headers + assert headers["Access-Control-Expose-Headers"] == ",".join(cors_config.expose_headers) + assert headers["Access-Control-Max-Age"] == str(cors_config.max_age) + assert "Access-Control-Allow-Credentials" not in headers + + # AND custom cors was set on the app + assert isinstance(app._cors, CORSConfig) + assert app._cors is cors_config + assert app._cors_methods == {"GET"} From f0e4f11950a401f43a0e291e96ec28c1efb2610c Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 27 Apr 2021 11:03:26 -0700 Subject: [PATCH 31/36] fix(event-handler): Default to false --- aws_lambda_powertools/event_handler/api_gateway.py | 2 +- tests/functional/event_handler/test_api_gateway.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1135ddef8ed..d093ba5c02e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -27,7 +27,7 @@ def __init__( allow_headers: List[str] = None, expose_headers: List[str] = None, max_age: int = None, - allow_credentials: bool = True, + allow_credentials: bool = False, ): self.allow_origin = allow_origin self.allow_headers = set((allow_headers or []) + self._REQUIRED_HEADERS) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 7d4fb91478d..851791a9407 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -187,7 +187,7 @@ def handler(event, context): headers = result["headers"] assert headers["Content-Type"] == TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" - assert headers["Access-Control-Allow-Credentials"] == "true" + assert "Access-Control-Allow-Credentials" not in headers # AND "Access-Control-Allow-Methods" is only included in the preflight cors headers assert "Access-Control-Allow-Methods" not in headers assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) @@ -378,7 +378,7 @@ def handler(event, context): # AND set the access control headers assert headers["Access-Control-Allow-Origin"] == "*" assert headers["Access-Control-Allow-Methods"] == "GET,OPTIONS,POST" - assert headers["Access-Control-Allow-Credentials"] == "true" + assert "Access-Control-Allow-Credentials" not in headers def test_custom_cors_config(): From ee52aee881924b906d9a515b9fc03b1ae028c566 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 27 Apr 2021 13:18:44 -0700 Subject: [PATCH 32/36] refactor: make some of the code-review changes --- .../event_handler/api_gateway.py | 101 ++++++----------- .../event_handler/test_api_gateway.py | 104 ++++++++---------- 2 files changed, 82 insertions(+), 123 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index d093ba5c02e..f769fbfddc3 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -3,7 +3,7 @@ import re import zlib from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -51,24 +51,12 @@ def to_dict(self) -> Dict[str, str]: class Route: def __init__( - self, - method: str, - rule: Any, - func: Callable, - cors: Union[bool, CORSConfig], - compress: bool, - cache_control: Optional[str], + self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] ): self.method = method.upper() self.rule = rule self.func = func - self.cors: Optional[CORSConfig] - if cors is True: - self.cors = CORSConfig() - elif isinstance(cors, CORSConfig): - self.cors = cors - else: - self.cors = None + self.cors = cors self.compress = compress self.cache_control = cache_control @@ -98,60 +86,44 @@ def compress(self): self.body = gzip.compress(self.body) + gzip.flush() def to_dict(self) -> Dict[str, Any]: - result = {"statusCode": self.status_code, "headers": self.headers} if isinstance(self.body, bytes): self.base64_encoded = True self.body = base64.b64encode(self.body).decode() - if self.body: - result["isBase64Encoded"] = self.base64_encoded - result["body"] = self.body - return result + return { + "statusCode": self.status_code, + "headers": self.headers, + "body": self.body, + "isBase64Encoded": self.base64_encoded, + } class ApiGatewayResolver: current_event: BaseProxyEvent lambda_context: LambdaContext - def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1): + def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): self._proxy_type = proxy_type self._routes: List[Route] = [] - self._cors: Optional[CORSConfig] = None - self._cors_methods: Set[str] = set() + self.cors = cors - def get(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): + def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "GET", cors, compress, cache_control) - def post(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): + def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "POST", cors, compress, cache_control) - def put(self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None): + def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "PUT", cors, compress, cache_control) - def delete( - self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None - ): + def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "DELETE", cors, compress, cache_control) - def patch( - self, rule: str, cors: Union[bool, CORSConfig] = False, compress: bool = False, cache_control: str = None - ): + def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "PATCH", cors, compress, cache_control) - def route( - self, - rule: str, - method: str, - cors: Union[bool, CORSConfig] = False, - compress: bool = False, - cache_control: str = None, - ): + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): - route = Route(method, self._compile_regex(rule), func, cors, compress, cache_control) - self._routes.append(route) - if route.cors: - if self._cors is None: - self._cors = route.cors - self._cors_methods.add(route.method) + self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) return func return register_resolver @@ -159,11 +131,12 @@ def register_resolver(func: Callable): def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._to_data_class(event) self.lambda_context = context - route, args = self._find_route(self.current_event.http_method.upper(), self.current_event.path) - response = self.to_response(route.func(**args)) + route, response = self._execute_route(self.current_event.http_method.upper(), self.current_event.path) + if route is None: + return response.to_dict() if route.cors: - response.add_cors(route.cors) + response.add_cors(self.cors or CORSConfig()) if route.cache_control: response.add_cache_control(route.cache_control) if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): @@ -196,28 +169,24 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) - @staticmethod - def _preflight(allowed_methods: Set): - allowed_methods.add("OPTIONS") - headers = {"Access-Control-Allow-Methods": ",".join(sorted(allowed_methods))} - return Response(204, None, None, headers) - - def _find_route(self, method: str, path: str) -> Tuple[Route, Dict]: + def _execute_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: for route in self._routes: if method != route.method: continue match: Optional[re.Match] = route.rule.match(path) if match: - return route, match.groupdict() - - if method == "OPTIONS" and self._cors is not None: - # Most be the preflight options call - return ( - Route("OPTIONS", None, self._preflight, self._cors, False, None), - {"allowed_methods": self._cors_methods}, - ) - - raise ValueError(f"No route found for '{method}.{path}'") + return route, self.to_response(route.func(**match.groupdict())) + + headers = {} + if self.cors: + headers.update(self.cors.to_dict()) + + return None, Response( + status_code=404, + content_type="application/json", + body=json.dumps({"message": f"No route found for '{method}.{path}'"}), + headers=headers, + ) def __call__(self, event, context) -> Any: return self.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 851791a9407..fc3565eacb0 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -5,8 +5,6 @@ from pathlib import Path from typing import Dict, Tuple -import pytest - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -144,7 +142,7 @@ def patch_func(): raise RuntimeError() def handler(event, context): - app.resolve(event, context) + return app.resolve(event, context) # Also check check the route configurations routes = app._routes @@ -162,9 +160,11 @@ def handler(event, context): assert route.method == "PATCH" # WHEN calling the handler - # THEN raise a ValueError - with pytest.raises(ValueError): - handler(LOAD_GW_EVENT, None) + # THEN return a 404 + result = handler(LOAD_GW_EVENT, None) + assert result["statusCode"] == 404 + # AND cors headers are not returned + assert "Access-Control-Allow-Origin" not in result["headers"] def test_cors(): @@ -188,8 +188,6 @@ def handler(event, context): assert headers["Content-Type"] == TEXT_HTML assert headers["Access-Control-Allow-Origin"] == "*" assert "Access-Control-Allow-Credentials" not in headers - # AND "Access-Control-Allow-Methods" is only included in the preflight cors headers - assert "Access-Control-Allow-Methods" not in headers assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) @@ -342,64 +340,24 @@ def rest_func() -> Response: assert result["body"] == "Not found" -def test_preflight_cors(): - # GIVEN - app = ApiGatewayResolver() - preflight_event = {"path": "/cors", "httpMethod": "OPTIONS"} - - @app.get("/cors", cors=True) - def get_with_cors(): - ... - - @app.post("/cors", cors=True) - def post_with_cors(): - ... - - @app.delete("/cors") - def delete_no_cors(): - ... - - def handler(event, context): - return app.resolve(event, context) - - # WHEN calling the event handler - # AND the httpMethod is OPTIONS - result = handler(preflight_event, None) - - # THEN return the preflight response - # AND No Content it returned - assert result["statusCode"] == 204 - assert "body" not in result - assert "isBase64Encoded" not in result - # AND no Content-Type is set - headers = result["headers"] - assert "headers" in result - assert "Content-Type" not in headers - # AND set the access control headers - assert headers["Access-Control-Allow-Origin"] == "*" - assert headers["Access-Control-Allow-Methods"] == "GET,OPTIONS,POST" - assert "Access-Control-Allow-Credentials" not in headers - - def test_custom_cors_config(): # GIVEN a custom cors configuration - app = ApiGatewayResolver() - event = {"path": "/cors", "httpMethod": "GET"} allow_header = ["foo2"] cors_config = CORSConfig( allow_origin="https://foo1", expose_headers=["foo1"], allow_headers=allow_header, max_age=100, - allow_credentials=False, + allow_credentials=True, ) + app = ApiGatewayResolver(cors=cors_config) + event = {"path": "/cors", "httpMethod": "GET"} - @app.get("/cors", cors=cors_config) + @app.get("/cors", cors=True) def get_with_cors(): return {} - # NOTE: Currently only the first configuration is used for the OPTIONS preflight - @app.get("/another-one", cors=True) + @app.get("/another-one") def another_one(): return {} @@ -415,9 +373,41 @@ def another_one(): assert headers["Access-Control-Allow-Headers"] == expected_allows_headers assert headers["Access-Control-Expose-Headers"] == ",".join(cors_config.expose_headers) assert headers["Access-Control-Max-Age"] == str(cors_config.max_age) - assert "Access-Control-Allow-Credentials" not in headers + assert "Access-Control-Allow-Credentials" in headers + assert headers["Access-Control-Allow-Credentials"] == "true" # AND custom cors was set on the app - assert isinstance(app._cors, CORSConfig) - assert app._cors is cors_config - assert app._cors_methods == {"GET"} + assert isinstance(app.cors, CORSConfig) + assert app.cors is cors_config + # AND routes without cors don't include "Access-Control" headers + event = {"path": "/another-one", "httpMethod": "GET"} + result = app(event, None) + headers = result["headers"] + assert "Access-Control-Allow-Origin" not in headers + + +def test_no_content_response(): + # GIVEN a response with no content-type or body + response = Response(status_code=204, content_type=None, body=None, headers=None) + + # WHEN calling to_dict + result = response.to_dict() + + # THEN return an None body and no Content-Type header + assert result["body"] is None + assert result["statusCode"] == 204 + assert "Content-Type" not in result["headers"] + + +def test_no_matches_with_cors(): + # GIVEN an event that does not match any of the given routes + # AND cors enabled + app = ApiGatewayResolver(cors=CORSConfig()) + + # WHEN calling the handler + result = app({"path": "/another-one", "httpMethod": "GET"}, None) + + # THEN return a 404 + # AND cors headers are returned + assert result["statusCode"] == 404 + assert "Access-Control-Allow-Origin" in result["headers"] From be12f3e5336b05330e7fe41fa70900baa2f1a93c Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 28 Apr 2021 04:37:50 -0700 Subject: [PATCH 33/36] feat(event-handler): Add auto generated preflight option --- .../event_handler/api_gateway.py | 50 ++++++++++---- .../event_handler/test_api_gateway.py | 65 ++++++++++++++++++- 2 files changed, 102 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f769fbfddc3..958fa5f1625 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -3,7 +3,7 @@ import re import zlib from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 @@ -19,7 +19,9 @@ class ProxyEventType(Enum): class CORSConfig(object): - _REQUIRED_HEADERS = ["Content-Type", "X-Amz-Date", "Authorization", "X-Api-Key", "X-Amz-Security-Token"] + """CORS Config""" + + _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] def __init__( self, @@ -29,8 +31,25 @@ def __init__( max_age: int = None, allow_credentials: bool = False, ): + """ + Parameters + ---------- + allow_origin: str + The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should + only be used during development. + allow_headers: str + The list of additional allowed headers. This list is added to list of + built in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`, + `X-Api-Key`, `X-Amz-Security-Token`. + expose_headers: str + A list of values to return for the Access-Control-Expose-Headers + max_age: int + The value for the `Access-Control-Max-Age` + allow_credentials: bool + A boolean value that sets the value of `Access-Control-Allow-Credentials` + """ self.allow_origin = allow_origin - self.allow_headers = set((allow_headers or []) + self._REQUIRED_HEADERS) + self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) self.expose_headers = expose_headers or [] self.max_age = max_age self.allow_credentials = allow_credentials @@ -104,7 +123,8 @@ class ApiGatewayResolver: def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): self._proxy_type = proxy_type self._routes: List[Route] = [] - self.cors = cors + self._cors = cors + self._cors_methods: Set[str] = {"OPTIONS"} def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): return self.route(rule, "GET", cors, compress, cache_control) @@ -124,6 +144,8 @@ def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_con def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): def register_resolver(func: Callable): self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) + if cors: + self._cors_methods.add(method.upper()) return func return register_resolver @@ -131,12 +153,12 @@ def register_resolver(func: Callable): def resolve(self, event, context) -> Dict[str, Any]: self.current_event = self._to_data_class(event) self.lambda_context = context - route, response = self._execute_route(self.current_event.http_method.upper(), self.current_event.path) - if route is None: + route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path) + if route is None: # No matching route was found return response.to_dict() if route.cors: - response.add_cors(self.cors or CORSConfig()) + response.add_cors(self._cors or CORSConfig()) if route.cache_control: response.add_cache_control(route.cache_control) if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): @@ -169,17 +191,20 @@ def _to_data_class(self, event: Dict) -> BaseProxyEvent: return APIGatewayProxyEventV2(event) return ALBEvent(event) - def _execute_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: + def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: for route in self._routes: if method != route.method: continue match: Optional[re.Match] = route.rule.match(path) if match: - return route, self.to_response(route.func(**match.groupdict())) + return self._call_route(match, route) headers = {} - if self.cors: - headers.update(self.cors.to_dict()) + if self._cors: + headers.update(self._cors.to_dict()) + if method == "OPTIONS": # Preflight + headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) + return None, Response(status_code=204, content_type=None, body=None, headers=headers) return None, Response( status_code=404, @@ -188,5 +213,8 @@ def _execute_route(self, method: str, path: str) -> Tuple[Optional[Route], Respo headers=headers, ) + def _call_route(self, match: re.Match, route: Route) -> Tuple[Route, Response]: + return route, self.to_response(route.func(**match.groupdict())) + def __call__(self, event, context) -> Any: return self.resolve(event, context) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index fc3565eacb0..df13b047d0d 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -377,8 +377,8 @@ def another_one(): assert headers["Access-Control-Allow-Credentials"] == "true" # AND custom cors was set on the app - assert isinstance(app.cors, CORSConfig) - assert app.cors is cors_config + assert isinstance(app._cors, CORSConfig) + assert app._cors is cors_config # AND routes without cors don't include "Access-Control" headers event = {"path": "/another-one", "httpMethod": "GET"} result = app(event, None) @@ -411,3 +411,64 @@ def test_no_matches_with_cors(): # AND cors headers are returned assert result["statusCode"] == 404 assert "Access-Control-Allow-Origin" in result["headers"] + + +def test_preflight(): + # GIVEN an event for an OPTIONS call that does not match any of the given routes + # AND cors is enabled + app = ApiGatewayResolver(cors=CORSConfig()) + + @app.get("/foo", cors=True) + def foo_cors(): + ... + + @app.route(method="delete", rule="/foo", cors=True) + def foo_delete_cors(): + ... + + @app.post("/foo") + def post_no_cors(): + ... + + # WHEN calling the handler + result = app({"path": "/foo", "httpMethod": "OPTIONS"}, None) + + # THEN return no content + # AND include Access-Control-Allow-Methods of the cors methods used + assert result["statusCode"] == 204 + assert result["body"] is None + headers = result["headers"] + assert "Content-Type" not in headers + assert "Access-Control-Allow-Origin" in result["headers"] + assert headers["Access-Control-Allow-Methods"] == "DELETE,GET,OPTIONS" + + +def test_custom_preflight_response(): + # GIVEN cors is enabled + # AND we have a custom preflight method + # AND the request matches this custom preflight route + app = ApiGatewayResolver(cors=CORSConfig()) + + @app.route(method="OPTIONS", rule="/some-call", cors=True) + def custom_preflight(): + return Response( + status_code=200, + content_type=TEXT_HTML, + body="Foo", + headers={"Access-Control-Allow-Methods": "CUSTOM"}, + ) + + @app.route(method="CUSTOM", rule="/some-call", cors=True) + def custom_method(): + ... + + # WHEN calling the handler + result = app({"path": "/some-call", "httpMethod": "OPTIONS"}, None) + + # THEN return the custom preflight response + assert result["statusCode"] == 200 + assert result["body"] == "Foo" + headers = result["headers"] + assert headers["Content-Type"] == TEXT_HTML + assert "Access-Control-Allow-Origin" in result["headers"] + assert headers["Access-Control-Allow-Methods"] == "CUSTOM" From 2eeee5ce3cd16b331ffe1ce79e06f7c6e9508fa7 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 28 Apr 2021 04:45:52 -0700 Subject: [PATCH 34/36] chore: bump ci From fbccaa199ead1cb9393e907426da7f1fdd48c052 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 28 Apr 2021 04:52:37 -0700 Subject: [PATCH 35/36] fix(event-handler): make python 3.6 compatible --- aws_lambda_powertools/event_handler/api_gateway.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 958fa5f1625..3d9494606b1 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -197,7 +197,7 @@ def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response continue match: Optional[re.Match] = route.rule.match(path) if match: - return self._call_route(match, route) + return self._call_route(route, match.groupdict()) headers = {} if self._cors: @@ -213,8 +213,8 @@ def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response headers=headers, ) - def _call_route(self, match: re.Match, route: Route) -> Tuple[Route, Response]: - return route, self.to_response(route.func(**match.groupdict())) + def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]: + return route, self.to_response(route.func(**args)) def __call__(self, event, context) -> Any: return self.resolve(event, context) From 6ec444cd0cb996b0ed05ae1ce7370ef438518879 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 28 Apr 2021 05:06:26 -0700 Subject: [PATCH 36/36] refactor: make more method as _ --- .../event_handler/api_gateway.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 3d9494606b1..fc744055e6c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -166,19 +166,6 @@ def resolve(self, event, context) -> Dict[str, Any]: return response.to_dict() - @staticmethod - def to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: - if isinstance(result, Response): - return result - elif isinstance(result, dict): - return Response( - status_code=200, - content_type="application/json", - body=json.dumps(result, separators=(",", ":"), cls=Encoder), - ) - else: # Tuple[int, str, Union[bytes, str]] - return Response(*result) - @staticmethod def _compile_regex(rule: str): rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) @@ -214,7 +201,20 @@ def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response ) def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]: - return route, self.to_response(route.func(**args)) + return route, self._to_response(route.func(**args)) + + @staticmethod + def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: + if isinstance(result, Response): + return result + elif isinstance(result, dict): + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(result, separators=(",", ":"), cls=Encoder), + ) + else: # Tuple[int, str, Union[bytes, str]] + return Response(*result) def __call__(self, event, context) -> Any: return self.resolve(event, context)