-
Notifications
You must be signed in to change notification settings - Fork 421
/
Copy pathbase.py
121 lines (85 loc) · 3.64 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Protocol
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.api_gateway import Response
class NextMiddleware(Protocol):
    """Structural type for the callable a middleware invokes to continue the chain.

    A conforming object can be called with the Event Handler instance, returns a
    ``Response``, and exposes a ``__name__`` string.
    """

    def __call__(self, app: EventHandlerInstance) -> Response:
        """Protocol for callback regardless of next_middleware(app), get_response(app) etc"""
        ...

    # Declared as a read-only property rather than a method so the protocol
    # matches what conforming objects actually provide: plain functions expose
    # ``__name__`` as a ``str`` attribute, and ``BaseMiddlewareHandler.__name__``
    # is itself a property returning ``str``.
    @property
    def __name__(self) -> str:  # noqa A003
        """Protocol for name of the Middleware"""
        ...
class BaseMiddlewareHandler(Generic[EventHandlerInstance], ABC):
    """Abstract base class for writing a middleware as a class.

    Subclasses put their logic in `handler`. Calling the middleware instance
    forwards to `handler`, which receives the Event Handler instance and
    `next_middleware`, the callable that continues the chain and returns a
    Response object. Code placed before that call runs ahead of the rest of the
    chain; code placed after it runs once a Response has come back.

    Example
    --------

    **Correlation ID Middleware**

    ```python
    import requests

    from aws_lambda_powertools import Logger
    from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
    from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware

    app = APIGatewayRestResolver()
    logger = Logger()


    class CorrelationIdMiddleware(BaseMiddlewareHandler):
        def __init__(self, header: str):
            super().__init__()
            self.header = header

        def handler(self, app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response:
            # BEFORE logic
            request_id = app.current_event.request_context.request_id
            correlation_id = app.current_event.headers.get(self.header, request_id)

            # Call next middleware or route handler ('/todos')
            response = next_middleware(app)

            # AFTER logic
            response.headers[self.header] = correlation_id

            return response


    @app.get("/todos", middlewares=[CorrelationIdMiddleware(header="x-correlation-id")])
    def get_todos():
        todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos")
        todos.raise_for_status()

        # for brevity, we'll limit to the first 10 only
        return {"todos": todos.json()[:10]}


    @logger.inject_lambda_context
    def lambda_handler(event, context):
        return app.resolve(event, context)
    ```
    """

    @property
    def __name__(self) -> str:  # noqa A003
        """Name of the concrete middleware class, as a string."""
        middleware_cls = self.__class__
        return str(middleware_cls.__name__)

    def __call__(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
        """
        Invoke the middleware by delegating to `handler`.

        Parameters
        ----------
        app: EventHandlerInstance
            The Event Handler instance passed through to `handler`
        next_middleware: NextMiddleware
            The next middleware handler in the chain, passed through to `handler`

        Returns
        -------
        Response
            Whatever `handler` returns
        """
        response = self.handler(app, next_middleware)
        return response

    @abstractmethod
    def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
        """
        Middleware logic; concrete subclasses must override this method.

        Parameters
        ----------
        app: EventHandlerInstance
            An instance of an Event Handler that implements ApiGatewayResolver
        next_middleware: NextMiddleware
            The next middleware handler in the chain

        Returns
        -------
        Response
            The response from the next middleware handler in the chain

        Raises
        ------
        NotImplementedError
            When this base implementation itself is executed
        """
        raise NotImplementedError()