1
1
import functools
2
2
import inspect
3
3
import logging
4
+ import sys
4
5
import os
5
- from typing import Callable , Optional
6
+ from typing import Any , Callable , Dict , Optional , Union , cast , overload
7
+
8
+ if sys .version_info >= (3 , 8 ):
9
+ from typing import Protocol
10
+ else :
11
+ from typing_extensions import Protocol
6
12
7
13
from ..shared import constants
8
14
from ..shared .functions import resolve_truthy_env_var_choice
9
15
from ..tracing import Tracer
16
+ from ..utilities .typing import LambdaContext
10
17
from .exceptions import MiddlewareInvalidArgumentError
11
18
12
19
logger = logging .getLogger (__name__ )
13
20
21
+ # context: Any to avoid forcing users to type it as context: LambdaContext
22
+ _Handler = Callable [[Any , LambdaContext ], Any ]
23
+ _RawHandlerDecorator = Callable [[_Handler ], _Handler ]
24
+
25
+
26
+ class _FactoryDecorator (Protocol ):
27
+ # it'd be better for this to be using ParamSpec (available from 3.10)
28
+ def __call__ (
29
+ self , handler : _Handler , event : Dict [str , Any ], context : LambdaContext , ** kwargs : Any
30
+ ) -> _RawHandlerDecorator :
31
+ ...
32
+
33
+
34
+ class _HandlerDecorator (Protocol ):
35
+ @overload
36
+ def __call__ (self , decorator : _Handler ) -> _Handler :
37
+ ...
38
+
39
+ @overload
40
+ def __call__ (self , decorator : None = None , ** kwargs : Any ) -> _RawHandlerDecorator :
41
+ ...
42
+
43
+ def __call__ (self , decorator : Optional [_Handler ] = None , ** kwargs : Any ) -> Union [_Handler , _RawHandlerDecorator ]:
44
+ ...
45
+
46
+
47
+ @overload
48
+ def lambda_handler_decorator (decorator : _FactoryDecorator ) -> _HandlerDecorator :
49
+ ...
50
+
51
+
52
+ @overload
53
+ def lambda_handler_decorator (
54
+ decorator : None = None , trace_execution : Optional [bool ] = None
55
+ ) -> Callable [[_FactoryDecorator ], _HandlerDecorator ]:
56
+ ...
57
+
14
58
15
- def lambda_handler_decorator (decorator : Optional [Callable ] = None , trace_execution : Optional [bool ] = None ):
59
+ def lambda_handler_decorator (
60
+ decorator : Optional [_FactoryDecorator ] = None , trace_execution : Optional [bool ] = None
61
+ ) -> Union [_HandlerDecorator , Callable [[_FactoryDecorator ], _HandlerDecorator ]]:
16
62
"""Decorator factory for decorating Lambda handlers.
17
63
18
64
You can use lambda_handler_decorator to create your own middlewares,
@@ -103,19 +149,25 @@ def lambda_handler(event, context):
103
149
"""
104
150
105
151
if decorator is None :
106
- return functools .partial (lambda_handler_decorator , trace_execution = trace_execution )
152
+ return cast (
153
+ Callable [[_FactoryDecorator ], _HandlerDecorator ],
154
+ functools .partial (lambda_handler_decorator , trace_execution = trace_execution ),
155
+ )
107
156
108
157
trace_execution = resolve_truthy_env_var_choice (
109
158
env = os .getenv (constants .MIDDLEWARE_FACTORY_TRACE_ENV , "false" ), choice = trace_execution
110
159
)
111
160
112
161
@functools .wraps (decorator )
113
- def final_decorator (func : Optional [Callable ] = None , ** kwargs ):
162
+ def final_decorator (
163
+ func : Optional [_RawHandlerDecorator ] = None , ** kwargs : Any
164
+ ) -> Union [_Handler , _RawHandlerDecorator ]:
114
165
# If called with kwargs return new func with kwargs
115
166
if func is None :
116
167
return functools .partial (final_decorator , ** kwargs )
117
168
118
169
if not inspect .isfunction (func ):
170
+ assert decorator is not None
119
171
# @custom_middleware(True) vs @custom_middleware(log_event=True)
120
172
raise MiddlewareInvalidArgumentError (
121
173
f"Only keyword arguments is supported for middlewares: { decorator .__qualname__ } received { func } " # type: ignore # noqa: E501
@@ -138,4 +190,4 @@ def wrapper(event, context):
138
190
139
191
return wrapper
140
192
141
- return final_decorator
193
+ return cast ( _HandlerDecorator , final_decorator )
0 commit comments