Skip to content

Commit 9625d37

Browse files
walmslesheitorlessaleandrodamascena
authored
feat(event_handler): add Middleware support for REST Event Handler (#2917)
Co-authored-by: Heitor Lessa <[email protected]> Co-authored-by: Heitor Lessa <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent 21fa25d commit 9625d37

File tree

44 files changed

+3743
-50
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3743
-50
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+381-29
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from aws_lambda_powertools.event_handler.middlewares.base import BaseMiddlewareHandler, NextMiddleware
2+
3+
__all__ = ["BaseMiddlewareHandler", "NextMiddleware"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Generic
3+
4+
from typing_extensions import Protocol
5+
6+
from aws_lambda_powertools.event_handler.api_gateway import Response
7+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
8+
9+
10+
class NextMiddleware(Protocol):
11+
def __call__(self, app: EventHandlerInstance) -> Response:
12+
"""Protocol for callback regardless of next_middleware(app), get_response(app) etc"""
13+
...
14+
15+
def __name__(self) -> str: # noqa A003
16+
"""Protocol for name of the Middleware"""
17+
...
18+
19+
20+
class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC):
21+
"""Base implementation for Middlewares to run code before and after in a chain.
22+
23+
24+
This is the middleware handler function where middleware logic is implemented.
25+
The next middleware handler is represented by `next_middleware`, returning a Response object.
26+
27+
Examples
28+
--------
29+
30+
**Correlation ID Middleware**
31+
32+
```python
33+
import requests
34+
35+
from aws_lambda_powertools import Logger
36+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
37+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
38+
39+
app = APIGatewayRestResolver()
40+
logger = Logger()
41+
42+
43+
class CorrelationIdMiddleware(BaseMiddlewareHandler):
44+
def __init__(self, header: str):
45+
super().__init__()
46+
self.header = header
47+
48+
def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
49+
# BEFORE logic
50+
request_id = app.current_event.request_context.request_id
51+
correlation_id = app.current_event.get_header_value(
52+
name=self.header,
53+
default_value=request_id,
54+
)
55+
56+
# Call next middleware or route handler ('/todos')
57+
response = next_middleware(app)
58+
59+
# AFTER logic
60+
response.headers[self.header] = correlation_id
61+
62+
return response
63+
64+
65+
@app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")])
66+
def get_todos():
67+
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
68+
todos.raise_for_status()
69+
70+
# for brevity, we'll limit to the first 10 only
71+
return {"todos": todos.json()[:10]}
72+
73+
74+
@logger.inject_lambda_context
75+
def lambda_handler(event, context):
76+
return app.resolve(event, context)
77+
78+
```
79+
80+
"""
81+
82+
@abstractmethod
83+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
84+
"""
85+
The Middleware Handler
86+
87+
Parameters
88+
----------
89+
app: EventHandlerInstance
90+
An instance of an Event Handler that implements ApiGatewayResolver
91+
next_middleware: NextMiddleware
92+
The next middleware handler in the chain
93+
94+
Returns
95+
-------
96+
Response
97+
The response from the next middleware handler in the chain
98+
99+
"""
100+
raise NotImplementedError()
101+
102+
@property
103+
def __name__(self) -> str: # noqa A003
104+
return str(self.__class__.__name__)
105+
106+
def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
107+
"""
108+
The Middleware handler function.
109+
110+
Parameters
111+
----------
112+
app: ApiGatewayResolver
113+
An instance of an Event Handler that implements ApiGatewayResolver
114+
next_middleware: NextMiddleware
115+
The next middleware handler in the chain
116+
117+
Returns
118+
-------
119+
Response
120+
The response from the next middleware handler in the chain
121+
"""
122+
return self.handler(app, next_middleware)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import logging
2+
from typing import Dict, Optional
3+
4+
from aws_lambda_powertools.event_handler.api_gateway import Response
5+
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, InternalServerError
6+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
7+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
8+
from aws_lambda_powertools.utilities.validation import validate
9+
from aws_lambda_powertools.utilities.validation.exceptions import InvalidSchemaFormatError, SchemaValidationError
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class SchemaValidationMiddleware(BaseMiddlewareHandler):
15+
"""Middleware to validate API request and response against JSON Schema using the [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/).
16+
17+
Examples
18+
--------
19+
**Validating incoming event**
20+
21+
```python
22+
import requests
23+
24+
from aws_lambda_powertools import Logger
25+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
26+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
27+
from aws_lambda_powertools.event_handler.middlewares.schema_validation import SchemaValidationMiddleware
28+
29+
app = APIGatewayRestResolver()
30+
logger = Logger()
31+
json_schema_validation = SchemaValidationMiddleware(inbound_schema=INCOMING_JSON_SCHEMA)
32+
33+
34+
@app.get("/todos", middlewares=[json_schema_validation])
35+
def get_todos():
36+
todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
37+
todos.raise_for_status()
38+
39+
# for brevity, we'll limit to the first 10 only
40+
return {"todos": todos.json()[:10]}
41+
42+
43+
@logger.inject_lambda_context
44+
def lambda_handler(event, context):
45+
return app.resolve(event, context)
46+
```
47+
"""
48+
49+
def __init__(
50+
self,
51+
inbound_schema: Dict,
52+
inbound_formats: Optional[Dict] = None,
53+
outbound_schema: Optional[Dict] = None,
54+
outbound_formats: Optional[Dict] = None,
55+
):
56+
"""See [Validation utility](https://docs.powertools.aws.dev/lambda/python/latest/utilities/validation/) docs for examples on all parameters.
57+
58+
Parameters
59+
----------
60+
inbound_schema : Dict
61+
JSON Schema to validate incoming event
62+
inbound_formats : Optional[Dict], optional
63+
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
64+
JSON Schema to validate outbound event, by default None
65+
outbound_formats : Optional[Dict], optional
66+
Custom formats containing a key (e.g. int64) and a value expressed as regex or callback returning bool, by default None
67+
""" # noqa: E501
68+
super().__init__()
69+
self.inbound_schema = inbound_schema
70+
self.inbound_formats = inbound_formats
71+
self.outbound_schema = outbound_schema
72+
self.outbound_formats = outbound_formats
73+
74+
def bad_response(self, error: SchemaValidationError) -> Response:
75+
message: str = f"Bad Response: {error.message}"
76+
logger.debug(message)
77+
raise BadRequestError(message)
78+
79+
def bad_request(self, error: SchemaValidationError) -> Response:
80+
message: str = f"Bad Request: {error.message}"
81+
logger.debug(message)
82+
raise BadRequestError(message)
83+
84+
def bad_config(self, error: InvalidSchemaFormatError) -> Response:
85+
logger.debug(f"Invalid Schema Format: {error}")
86+
raise InternalServerError("Internal Server Error")
87+
88+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
89+
"""Validates incoming JSON payload (body) against JSON Schema provided.
90+
91+
Parameters
92+
----------
93+
app : EventHandlerInstance
94+
An instance of an Event Handler
95+
next_middleware : NextMiddleware
96+
Callable to get response from the next middleware or route handler in the chain
97+
98+
Returns
99+
-------
100+
Response
101+
It can return three types of response objects
102+
103+
- Original response: Propagates HTTP response returned from the next middleware if validation succeeds
104+
- HTTP 400: Payload or response failed JSON Schema validation
105+
- HTTP 500: JSON Schema provided has incorrect format
106+
"""
107+
try:
108+
validate(event=app.current_event.json_body, schema=self.inbound_schema, formats=self.inbound_formats)
109+
except SchemaValidationError as error:
110+
return self.bad_request(error)
111+
except InvalidSchemaFormatError as error:
112+
return self.bad_config(error)
113+
114+
result = next_middleware(app)
115+
116+
if self.outbound_formats is not None:
117+
try:
118+
validate(event=result.body, schema=self.inbound_schema, formats=self.inbound_formats)
119+
except SchemaValidationError as error:
120+
return self.bad_response(error)
121+
except InvalidSchemaFormatError as error:
122+
return self.bad_config(error)
123+
124+
return result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from typing import TypeVar
2+
3+
from aws_lambda_powertools.event_handler import ApiGatewayResolver
4+
5+
EventHandlerInstance = TypeVar("EventHandlerInstance", bound=ApiGatewayResolver)

aws_lambda_powertools/shared/types.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
import sys
12
from typing import Any, Callable, Dict, List, TypeVar, Union
23

34
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
45
# JSON primitives only, mypy doesn't support recursive tho
56
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
7+
8+
9+
if sys.version_info >= (3, 8):
10+
from typing import Protocol
11+
else:
12+
from typing_extensions import Protocol
13+
14+
__all__ = ["Protocol"]

aws_lambda_powertools/utilities/data_classes/common.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import json
33
from collections.abc import Mapping
4-
from typing import Any, Callable, Dict, Iterator, List, Optional
4+
from typing import Any, Callable, Dict, Iterator, List, Optional, overload
55

66
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer
77
from aws_lambda_powertools.utilities.data_classes.shared_functions import (
@@ -156,7 +156,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
156156
default_value=default_value,
157157
)
158158

159-
# Maintenance: missing @overload to ensure return type is a str when default_value is set
159+
@overload
160+
def get_header_value(
161+
self,
162+
name: str,
163+
default_value: str,
164+
case_sensitive: Optional[bool] = False,
165+
) -> str:
166+
...
167+
168+
@overload
169+
def get_header_value(
170+
self,
171+
name: str,
172+
default_value: Optional[str] = None,
173+
case_sensitive: Optional[bool] = False,
174+
) -> Optional[str]:
175+
...
176+
160177
def get_header_value(
161178
self,
162179
name: str,

aws_lambda_powertools/utilities/data_classes/vpc_lattice.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Dict, Optional, overload
22

33
from aws_lambda_powertools.shared.headers_serializer import (
44
BaseHeadersSerializer,
@@ -91,6 +91,24 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None)
9191
default_value=default_value,
9292
)
9393

94+
@overload
95+
def get_header_value(
96+
self,
97+
name: str,
98+
default_value: str,
99+
case_sensitive: Optional[bool] = False,
100+
) -> str:
101+
...
102+
103+
@overload
104+
def get_header_value(
105+
self,
106+
name: str,
107+
default_value: Optional[str] = None,
108+
case_sensitive: Optional[bool] = False,
109+
) -> Optional[str]:
110+
...
111+
94112
def get_header_value(
95113
self,
96114
name: str,

0 commit comments

Comments
 (0)