Skip to content

Commit 01eb5a7

Browse files
committed
Merge branch 'develop' of https://github.com/awslabs/aws-lambda-powertools-python into feat/batch-new-processor
* 'develop' of https://github.com/awslabs/aws-lambda-powertools-python: fix(parser): kinesis sequence number is str, not int (aws-powertools#907) feat(apigateway): add exception_handler support (aws-powertools#898) fix(event-sources): Pass authorizer data to APIGatewayEventAuthorizer (aws-powertools#897) chore(deps): bump fastjsonschema from 2.15.1 to 2.15.2 (aws-powertools#891)
2 parents b0f170e + 99227ce commit 01eb5a7

File tree

9 files changed

+175
-26
lines changed

9 files changed

+175
-26
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+49-14
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from enum import Enum
1111
from functools import partial
1212
from http import HTTPStatus
13-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
1414

1515
from aws_lambda_powertools.event_handler import content_types
16-
from aws_lambda_powertools.event_handler.exceptions import ServiceError
16+
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
1717
from aws_lambda_powertools.shared import constants
1818
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
1919
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -27,7 +27,6 @@
2727
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
2828
# API GW/ALB decode non-safe URI chars; we must support them too
2929
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
30-
3130
_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
3231

3332

@@ -435,6 +434,7 @@ def __init__(
435434
self._proxy_type = proxy_type
436435
self._routes: List[Route] = []
437436
self._route_keys: List[str] = []
437+
self._exception_handlers: Dict[Type, Callable] = {}
438438
self._cors = cors
439439
self._cors_enabled: bool = cors is not None
440440
self._cors_methods: Set[str] = {"OPTIONS"}
@@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
596596
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
597597
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))
598598

599+
handler = self._lookup_exception_handler(NotFoundError)
600+
if handler:
601+
return ResponseBuilder(handler(NotFoundError()))
602+
599603
return ResponseBuilder(
600604
Response(
601605
status_code=HTTPStatus.NOT_FOUND.value,
@@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
609613
"""Actually call the matching route with any provided keyword arguments."""
610614
try:
611615
return ResponseBuilder(self._to_response(route.func(**args)), route)
612-
except ServiceError as e:
613-
return ResponseBuilder(
614-
Response(
615-
status_code=e.status_code,
616-
content_type=content_types.APPLICATION_JSON,
617-
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
618-
),
619-
route,
620-
)
621-
except Exception:
616+
except Exception as exc:
617+
response_builder = self._call_exception_handler(exc, route)
618+
if response_builder:
619+
return response_builder
620+
622621
if self._debug:
623622
# If the user has turned on debug mode,
624623
# we'll let the original exception propagate so
@@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
628627
status_code=500,
629628
content_type=content_types.TEXT_PLAIN,
630629
body="".join(traceback.format_exc()),
631-
)
630+
),
631+
route,
632632
)
633+
633634
raise
634635

636+
def not_found(self, func: Callable):
637+
return self.exception_handler(NotFoundError)(func)
638+
639+
def exception_handler(self, exc_class: Type[Exception]):
640+
def register_exception_handler(func: Callable):
641+
self._exception_handlers[exc_class] = func
642+
643+
return register_exception_handler
644+
645+
def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]:
646+
# Use "Method Resolution Order" to allow for matching against a base class
647+
# of an exception
648+
for cls in exp_type.__mro__:
649+
if cls in self._exception_handlers:
650+
return self._exception_handlers[cls]
651+
return None
652+
653+
def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
654+
handler = self._lookup_exception_handler(type(exp))
655+
if handler:
656+
return ResponseBuilder(handler(exp), route)
657+
658+
if isinstance(exp, ServiceError):
659+
return ResponseBuilder(
660+
Response(
661+
status_code=exp.status_code,
662+
content_type=content_types.APPLICATION_JSON,
663+
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
664+
),
665+
route,
666+
)
667+
668+
return None
669+
635670
def _to_response(self, result: Union[Dict, Response]) -> Response:
636671
"""Convert the route's result to a Response
637672

Diff for: aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,22 @@
1111
class APIGatewayEventAuthorizer(DictWrapper):
1212
@property
1313
def claims(self) -> Optional[Dict[str, Any]]:
14-
return self["requestContext"]["authorizer"].get("claims")
14+
return self.get("claims")
1515

1616
@property
1717
def scopes(self) -> Optional[List[str]]:
18-
return self["requestContext"]["authorizer"].get("scopes")
18+
return self.get("scopes")
19+
20+
@property
21+
def principal_id(self) -> Optional[str]:
22+
"""The principal user identification associated with the token sent by the client and returned from an
23+
API Gateway Lambda authorizer (formerly known as a custom authorizer)"""
24+
return self.get("principalId")
25+
26+
@property
27+
def integration_latency(self) -> Optional[int]:
28+
"""The authorizer latency in ms."""
29+
return self.get("integrationLatency")
1930

2031

2132
class APIGatewayEventRequestContext(BaseRequestContext):
@@ -56,7 +67,7 @@ def route_key(self) -> Optional[str]:
5667

5768
@property
5869
def authorizer(self) -> APIGatewayEventAuthorizer:
59-
return APIGatewayEventAuthorizer(self._data)
70+
return APIGatewayEventAuthorizer(self._data["requestContext"]["authorizer"])
6071

6172

6273
class APIGatewayProxyEvent(BaseProxyEvent):

Diff for: aws_lambda_powertools/utilities/data_classes/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def __eq__(self, other: Any) -> bool:
1818

1919
return self._data == other._data
2020

21-
def get(self, key: str) -> Optional[Any]:
22-
return self._data.get(key)
21+
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
22+
return self._data.get(key, default)
2323

2424
@property
2525
def raw_event(self) -> Dict[str, Any]:

Diff for: aws_lambda_powertools/utilities/parser/models/kinesis.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import List, Union
55

66
from pydantic import BaseModel, validator
7-
from pydantic.types import PositiveInt
87

98
from aws_lambda_powertools.utilities.parser.types import Literal, Model
109

@@ -14,7 +13,7 @@
1413
class KinesisDataStreamRecordPayload(BaseModel):
1514
kinesisSchemaVersion: str
1615
partitionKey: str
17-
sequenceNumber: PositiveInt
16+
sequenceNumber: str
1817
data: Union[bytes, Model] # base64 encoded str is parsed into bytes
1918
approximateArrivalTimestamp: float
2019

Diff for: poetry.lock

+7-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: tests/events/apiGatewayProxyEventPrincipalId.json

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"resource": "/trip",
3+
"path": "/trip",
4+
"httpMethod": "POST",
5+
"requestContext": {
6+
"requestId": "34972478-2843-4ced-a657-253108738274",
7+
"authorizer": {
8+
"user_id": "fake_username",
9+
"principalId": "fake",
10+
"integrationLatency": 451
11+
}
12+
}
13+
}

Diff for: tests/functional/event_handler/test_api_gateway.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def patch_func():
163163
def handler(event, context):
164164
return app.resolve(event, context)
165165

166-
# Also check check the route configurations
166+
# Also check the route configurations
167167
routes = app._routes
168168
assert len(routes) == 5
169169
for route in routes:
@@ -1076,3 +1076,76 @@ def foo():
10761076

10771077
assert result["statusCode"] == 200
10781078
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1079+
1080+
1081+
def test_exception_handler():
1082+
# GIVEN a resolver with an exception handler defined for ValueError
1083+
app = ApiGatewayResolver()
1084+
1085+
@app.exception_handler(ValueError)
1086+
def handle_value_error(ex: ValueError):
1087+
print(f"request path is '{app.current_event.path}'")
1088+
return Response(
1089+
status_code=418,
1090+
content_type=content_types.TEXT_HTML,
1091+
body=str(ex),
1092+
)
1093+
1094+
@app.get("/my/path")
1095+
def get_lambda() -> Response:
1096+
raise ValueError("Foo!")
1097+
1098+
# WHEN calling the event handler
1099+
# AND a ValueError is raised
1100+
result = app(LOAD_GW_EVENT, {})
1101+
1102+
# THEN call the exception_handler
1103+
assert result["statusCode"] == 418
1104+
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
1105+
assert result["body"] == "Foo!"
1106+
1107+
1108+
def test_exception_handler_service_error():
1109+
# GIVEN
1110+
app = ApiGatewayResolver()
1111+
1112+
@app.exception_handler(ServiceError)
1113+
def service_error(ex: ServiceError):
1114+
print(ex.msg)
1115+
return Response(
1116+
status_code=ex.status_code,
1117+
content_type=content_types.APPLICATION_JSON,
1118+
body="CUSTOM ERROR FORMAT",
1119+
)
1120+
1121+
@app.get("/my/path")
1122+
def get_lambda() -> Response:
1123+
raise InternalServerError("Something sensitive")
1124+
1125+
# WHEN calling the event handler
1126+
# AND a ServiceError is raised
1127+
result = app(LOAD_GW_EVENT, {})
1128+
1129+
# THEN call the exception_handler
1130+
assert result["statusCode"] == 500
1131+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
1132+
assert result["body"] == "CUSTOM ERROR FORMAT"
1133+
1134+
1135+
def test_exception_handler_not_found():
1136+
# GIVEN a resolver with an exception handler defined for a 404 not found
1137+
app = ApiGatewayResolver()
1138+
1139+
@app.not_found
1140+
def handle_not_found(exc: NotFoundError) -> Response:
1141+
assert isinstance(exc, NotFoundError)
1142+
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")
1143+
1144+
# WHEN calling the event handler
1145+
# AND not route is found
1146+
result = app(LOAD_GW_EVENT, {})
1147+
1148+
# THEN call the exception_handler
1149+
assert result["statusCode"] == 404
1150+
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
1151+
assert result["body"] == "I am a teapot!"

Diff for: tests/functional/parser/test_kinesis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def handle_kinesis_no_envelope(event: KinesisDataStreamModel, _: LambdaContext):
3535
assert kinesis.approximateArrivalTimestamp == 1545084650.987
3636
assert kinesis.kinesisSchemaVersion == "1.0"
3737
assert kinesis.partitionKey == "1"
38-
assert kinesis.sequenceNumber == 49590338271490256608559692538361571095921575989136588898
38+
assert kinesis.sequenceNumber == "49590338271490256608559692538361571095921575989136588898"
3939
assert kinesis.data == b"Hello, this is a test."
4040

4141

Diff for: tests/functional/test_data_classes.py

+14
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,20 @@ def test_api_gateway_proxy_event():
897897
assert request_context.identity.client_cert.subject_dn == "www.example.com"
898898

899899

900+
def test_api_gateway_proxy_event_with_principal_id():
901+
event = APIGatewayProxyEvent(load_event("apiGatewayProxyEventPrincipalId.json"))
902+
903+
request_context = event.request_context
904+
authorizer = request_context.authorizer
905+
assert authorizer.claims is None
906+
assert authorizer.scopes is None
907+
assert authorizer["principalId"] == "fake"
908+
assert authorizer.get("principalId") == "fake"
909+
assert authorizer.principal_id == "fake"
910+
assert authorizer.integration_latency == 451
911+
assert authorizer.get("integrationStatus", "failed") == "failed"
912+
913+
900914
def test_api_gateway_proxy_v2_event():
901915
event = APIGatewayProxyEventV2(load_event("apiGatewayProxyV2Event.json"))
902916

0 commit comments

Comments
 (0)