Skip to content

Commit 33f80fd

Browse files
author
Michael Brewer
authored
feat(api-gateway): add common service errors (#506)
1 parent 2473480 commit 33f80fd

File tree

5 files changed

+175
-18
lines changed

5 files changed

+175
-18
lines changed

Diff for: aws_lambda_powertools/event_handler/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Event handler decorators for common Lambda events
33
"""
44

5+
from .api_gateway import ApiGatewayResolver
56
from .appsync import AppSyncResolver
67

7-
__all__ = ["AppSyncResolver"]
8+
__all__ = ["AppSyncResolver", "ApiGatewayResolver"]

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import re
55
import zlib
66
from enum import Enum
7+
from http import HTTPStatus
78
from typing import Any, Callable, Dict, List, Optional, Set, Union
89

10+
from aws_lambda_powertools.event_handler import content_types
11+
from aws_lambda_powertools.event_handler.exceptions import ServiceError
912
from aws_lambda_powertools.shared.json_encoder import Encoder
1013
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
1114
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
@@ -466,19 +469,28 @@ def _not_found(self, method: str) -> ResponseBuilder:
466469

467470
return ResponseBuilder(
468471
Response(
469-
status_code=404,
470-
content_type="application/json",
472+
status_code=HTTPStatus.NOT_FOUND.value,
473+
content_type=content_types.APPLICATION_JSON,
471474
headers=headers,
472-
body=json.dumps({"message": "Not found"}),
475+
body=self._json_dump({"statusCode": HTTPStatus.NOT_FOUND.value, "message": "Not found"}),
473476
)
474477
)
475478

476479
def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
477480
"""Actually call the matching route with any provided keyword arguments."""
478-
return ResponseBuilder(self._to_response(route.func(**args)), route)
481+
try:
482+
return ResponseBuilder(self._to_response(route.func(**args)), route)
483+
except ServiceError as e:
484+
return ResponseBuilder(
485+
Response(
486+
status_code=e.status_code,
487+
content_type=content_types.APPLICATION_JSON,
488+
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
489+
),
490+
route,
491+
)
479492

480-
@staticmethod
481-
def _to_response(result: Union[Dict, Response]) -> Response:
493+
def _to_response(self, result: Union[Dict, Response]) -> Response:
482494
"""Convert the route's result to a Response
483495
484496
2 main result types are supported:
@@ -493,6 +505,11 @@ def _to_response(result: Union[Dict, Response]) -> Response:
493505
logger.debug("Simple response detected, serializing return before constructing final response")
494506
return Response(
495507
status_code=200,
496-
content_type="application/json",
497-
body=json.dumps(result, separators=(",", ":"), cls=Encoder),
508+
content_type=content_types.APPLICATION_JSON,
509+
body=self._json_dump(result),
498510
)
511+
512+
@staticmethod
513+
def _json_dump(obj: Any) -> str:
514+
"""Does a concise json serialization"""
515+
return json.dumps(obj, separators=(",", ":"), cls=Encoder)

Diff for: aws_lambda_powertools/event_handler/content_types.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
APPLICATION_JSON = "application/json"
2+
PLAIN_TEXT = "plain/text"

Diff for: aws_lambda_powertools/event_handler/exceptions.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from http import HTTPStatus
2+
3+
4+
class ServiceError(Exception):
5+
"""Service Error"""
6+
7+
def __init__(self, status_code: int, msg: str):
8+
"""
9+
Parameters
10+
----------
11+
status_code: int
12+
Http status code
13+
msg: str
14+
Error message
15+
"""
16+
self.status_code = status_code
17+
self.msg = msg
18+
19+
20+
class BadRequestError(ServiceError):
21+
"""Bad Request Error"""
22+
23+
def __init__(self, msg: str):
24+
super().__init__(HTTPStatus.BAD_REQUEST, msg)
25+
26+
27+
class UnauthorizedError(ServiceError):
28+
"""Unauthorized Error"""
29+
30+
def __init__(self, msg: str):
31+
super().__init__(HTTPStatus.UNAUTHORIZED, msg)
32+
33+
34+
class NotFoundError(ServiceError):
35+
"""Not Found Error"""
36+
37+
def __init__(self, msg: str = "Not found"):
38+
super().__init__(HTTPStatus.NOT_FOUND, msg)
39+
40+
41+
class InternalServerError(ServiceError):
42+
"""Internal Server Error"""
43+
44+
def __init__(self, message: str):
45+
super().__init__(HTTPStatus.INTERNAL_SERVER_ERROR, message)

Diff for: tests/functional/event_handler/test_api_gateway.py

+101-9
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,21 @@
55
from pathlib import Path
66
from typing import Dict
77

8+
from aws_lambda_powertools.event_handler import content_types
89
from aws_lambda_powertools.event_handler.api_gateway import (
910
ApiGatewayResolver,
1011
CORSConfig,
1112
ProxyEventType,
1213
Response,
1314
ResponseBuilder,
1415
)
16+
from aws_lambda_powertools.event_handler.exceptions import (
17+
BadRequestError,
18+
InternalServerError,
19+
NotFoundError,
20+
ServiceError,
21+
UnauthorizedError,
22+
)
1523
from aws_lambda_powertools.shared.json_encoder import Encoder
1624
from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2
1725
from tests.functional.utils import load_event
@@ -24,7 +32,6 @@ def read_media(file_name: str) -> bytes:
2432

2533
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
2634
TEXT_HTML = "text/html"
27-
APPLICATION_JSON = "application/json"
2835

2936

3037
def test_alb_event():
@@ -55,15 +62,15 @@ def test_api_gateway_v1():
5562
def get_lambda() -> Response:
5663
assert isinstance(app.current_event, APIGatewayProxyEvent)
5764
assert app.lambda_context == {}
58-
return Response(200, APPLICATION_JSON, json.dumps({"foo": "value"}))
65+
return Response(200, content_types.APPLICATION_JSON, json.dumps({"foo": "value"}))
5966

6067
# WHEN calling the event handler
6168
result = app(LOAD_GW_EVENT, {})
6269

6370
# THEN process event correctly
6471
# AND set the current_event type as APIGatewayProxyEvent
6572
assert result["statusCode"] == 200
66-
assert result["headers"]["Content-Type"] == APPLICATION_JSON
73+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
6774

6875

6976
def test_api_gateway():
@@ -93,15 +100,15 @@ def test_api_gateway_v2():
93100
def my_path() -> Response:
94101
assert isinstance(app.current_event, APIGatewayProxyEventV2)
95102
post_data = app.current_event.json_body
96-
return Response(200, "plain/text", post_data["username"])
103+
return Response(200, content_types.PLAIN_TEXT, post_data["username"])
97104

98105
# WHEN calling the event handler
99106
result = app(load_event("apiGatewayProxyV2Event.json"), {})
100107

101108
# THEN process event correctly
102109
# AND set the current_event type as APIGatewayProxyEventV2
103110
assert result["statusCode"] == 200
104-
assert result["headers"]["Content-Type"] == "plain/text"
111+
assert result["headers"]["Content-Type"] == content_types.PLAIN_TEXT
105112
assert result["body"] == "tom"
106113

107114

@@ -215,7 +222,7 @@ def test_compress():
215222

216223
@app.get("/my/request", compress=True)
217224
def with_compression() -> Response:
218-
return Response(200, APPLICATION_JSON, expected_value)
225+
return Response(200, content_types.APPLICATION_JSON, expected_value)
219226

220227
def handler(event, context):
221228
return app.resolve(event, context)
@@ -261,7 +268,7 @@ def test_compress_no_accept_encoding():
261268

262269
@app.get("/my/path", compress=True)
263270
def return_text() -> Response:
264-
return Response(200, "text/plain", expected_value)
271+
return Response(200, content_types.PLAIN_TEXT, expected_value)
265272

266273
# WHEN calling the event handler
267274
result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None)
@@ -327,7 +334,7 @@ def rest_func() -> Dict:
327334

328335
# THEN automatically process this as a json rest api response
329336
assert result["statusCode"] == 200
330-
assert result["headers"]["Content-Type"] == APPLICATION_JSON
337+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
331338
expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder)
332339
assert result["body"] == expected_str
333340

@@ -382,7 +389,7 @@ def another_one():
382389
# THEN routes by default return the custom cors headers
383390
assert "headers" in result
384391
headers = result["headers"]
385-
assert headers["Content-Type"] == APPLICATION_JSON
392+
assert headers["Content-Type"] == content_types.APPLICATION_JSON
386393
assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin
387394
expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS)))
388395
assert headers["Access-Control-Allow-Headers"] == expected_allows_headers
@@ -429,6 +436,7 @@ def test_no_matches_with_cors():
429436
# AND cors headers are returned
430437
assert result["statusCode"] == 404
431438
assert "Access-Control-Allow-Origin" in result["headers"]
439+
assert "Not found" in result["body"]
432440

433441

434442
def test_cors_preflight():
@@ -490,3 +498,87 @@ def custom_method():
490498
assert headers["Content-Type"] == TEXT_HTML
491499
assert "Access-Control-Allow-Origin" in result["headers"]
492500
assert headers["Access-Control-Allow-Methods"] == "CUSTOM"
501+
502+
503+
def test_service_error_responses():
504+
# SCENARIO handling different kind of service errors being raised
505+
app = ApiGatewayResolver(cors=CORSConfig())
506+
507+
def json_dump(obj):
508+
return json.dumps(obj, separators=(",", ":"))
509+
510+
# GIVEN an BadRequestError
511+
@app.get(rule="/bad-request-error", cors=False)
512+
def bad_request_error():
513+
raise BadRequestError("Missing required parameter")
514+
515+
# WHEN calling the handler
516+
# AND path is /bad-request-error
517+
result = app({"path": "/bad-request-error", "httpMethod": "GET"}, None)
518+
# THEN return the bad request error response
519+
# AND status code equals 400
520+
assert result["statusCode"] == 400
521+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
522+
expected = {"statusCode": 400, "message": "Missing required parameter"}
523+
assert result["body"] == json_dump(expected)
524+
525+
# GIVEN an UnauthorizedError
526+
@app.get(rule="/unauthorized-error", cors=False)
527+
def unauthorized_error():
528+
raise UnauthorizedError("Unauthorized")
529+
530+
# WHEN calling the handler
531+
# AND path is /unauthorized-error
532+
result = app({"path": "/unauthorized-error", "httpMethod": "GET"}, None)
533+
# THEN return the unauthorized error response
534+
# AND status code equals 401
535+
assert result["statusCode"] == 401
536+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
537+
expected = {"statusCode": 401, "message": "Unauthorized"}
538+
assert result["body"] == json_dump(expected)
539+
540+
# GIVEN an NotFoundError
541+
@app.get(rule="/not-found-error", cors=False)
542+
def not_found_error():
543+
raise NotFoundError
544+
545+
# WHEN calling the handler
546+
# AND path is /not-found-error
547+
result = app({"path": "/not-found-error", "httpMethod": "GET"}, None)
548+
# THEN return the not found error response
549+
# AND status code equals 404
550+
assert result["statusCode"] == 404
551+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
552+
expected = {"statusCode": 404, "message": "Not found"}
553+
assert result["body"] == json_dump(expected)
554+
555+
# GIVEN an InternalServerError
556+
@app.get(rule="/internal-server-error", cors=False)
557+
def internal_server_error():
558+
raise InternalServerError("Internal server error")
559+
560+
# WHEN calling the handler
561+
# AND path is /internal-server-error
562+
result = app({"path": "/internal-server-error", "httpMethod": "GET"}, None)
563+
# THEN return the internal server error response
564+
# AND status code equals 500
565+
assert result["statusCode"] == 500
566+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
567+
expected = {"statusCode": 500, "message": "Internal server error"}
568+
assert result["body"] == json_dump(expected)
569+
570+
# GIVEN an ServiceError with a custom status code
571+
@app.get(rule="/service-error", cors=True)
572+
def service_error():
573+
raise ServiceError(502, "Something went wrong!")
574+
575+
# WHEN calling the handler
576+
# AND path is /service-error
577+
result = app({"path": "/service-error", "httpMethod": "GET"}, None)
578+
# THEN return the service error response
579+
# AND status code equals 502
580+
assert result["statusCode"] == 502
581+
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
582+
assert "Access-Control-Allow-Origin" in result["headers"]
583+
expected = {"statusCode": 502, "message": "Something went wrong!"}
584+
assert result["body"] == json_dump(expected)

0 commit comments

Comments
 (0)