Skip to content

Commit 937e6fe

Browse files
committed
Attempt to type lambda_handler_decorator accurately
1 parent 2a3ff9a commit 937e6fe

File tree

1 file changed

+57
-5
lines changed
  • aws_lambda_powertools/middleware_factory

1 file changed

+57
-5
lines changed

aws_lambda_powertools/middleware_factory/factory.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,64 @@
11
import functools
22
import inspect
33
import logging
4+
import sys
45
import os
5-
from typing import Callable, Optional
6+
from typing import Any, Callable, Dict, Optional, Union, cast, overload
7+
8+
if sys.version_info >= (3, 8):
9+
from typing import Protocol
10+
else:
11+
from typing_extensions import Protocol
612

713
from ..shared import constants
814
from ..shared.functions import resolve_truthy_env_var_choice
915
from ..tracing import Tracer
16+
from ..utilities.typing import LambdaContext
1017
from .exceptions import MiddlewareInvalidArgumentError
1118

1219
logger = logging.getLogger(__name__)
1320

21+
# context: Any to avoid forcing users to type it as context: LambdaContext
22+
_Handler = Callable[[Any, LambdaContext], Any]
23+
_RawHandlerDecorator = Callable[[_Handler], _Handler]
24+
25+
26+
class _FactoryDecorator(Protocol):
27+
# it'd be better for this to be using ParamSpec (available from 3.10)
28+
def __call__(
29+
self, handler: _Handler, event: Dict[str, Any], context: LambdaContext, **kwargs: Any
30+
) -> _RawHandlerDecorator:
31+
...
32+
33+
34+
class _HandlerDecorator(Protocol):
35+
@overload
36+
def __call__(self, decorator: _Handler) -> _Handler:
37+
...
38+
39+
@overload
40+
def __call__(self, decorator: None = None, **kwargs: Any) -> _RawHandlerDecorator:
41+
...
42+
43+
def __call__(self, decorator: Optional[_Handler] = None, **kwargs: Any) -> Union[_Handler, _RawHandlerDecorator]:
44+
...
45+
46+
47+
@overload
48+
def lambda_handler_decorator(decorator: _FactoryDecorator) -> _HandlerDecorator:
49+
...
50+
51+
52+
@overload
53+
def lambda_handler_decorator(
54+
decorator: None = None, trace_execution: Optional[bool] = None
55+
) -> Callable[[_FactoryDecorator], _HandlerDecorator]:
56+
...
57+
1458

15-
def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None):
59+
def lambda_handler_decorator(
60+
decorator: Optional[_FactoryDecorator] = None, trace_execution: Optional[bool] = None
61+
) -> Union[_HandlerDecorator, Callable[[_FactoryDecorator], _HandlerDecorator]]:
1662
"""Decorator factory for decorating Lambda handlers.
1763
1864
You can use lambda_handler_decorator to create your own middlewares,
@@ -103,19 +149,25 @@ def lambda_handler(event, context):
103149
"""
104150

105151
if decorator is None:
106-
return functools.partial(lambda_handler_decorator, trace_execution=trace_execution)
152+
return cast(
153+
Callable[[_FactoryDecorator], _HandlerDecorator],
154+
functools.partial(lambda_handler_decorator, trace_execution=trace_execution),
155+
)
107156

108157
trace_execution = resolve_truthy_env_var_choice(
109158
env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false"), choice=trace_execution
110159
)
111160

112161
@functools.wraps(decorator)
113-
def final_decorator(func: Optional[Callable] = None, **kwargs):
162+
def final_decorator(
163+
func: Optional[_RawHandlerDecorator] = None, **kwargs: Any
164+
) -> Union[_Handler, _RawHandlerDecorator]:
114165
# If called with kwargs return new func with kwargs
115166
if func is None:
116167
return functools.partial(final_decorator, **kwargs)
117168

118169
if not inspect.isfunction(func):
170+
assert decorator is not None
119171
# @custom_middleware(True) vs @custom_middleware(log_event=True)
120172
raise MiddlewareInvalidArgumentError(
121173
f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501
@@ -138,4 +190,4 @@ def wrapper(event, context):
138190

139191
return wrapper
140192

141-
return final_decorator
193+
return cast(_HandlerDecorator, final_decorator)

0 commit comments

Comments
 (0)