|
1 | 1 | import logging
|
2 |
| -from typing import Any, Callable, Type |
| 2 | +from typing import Any, Callable, Optional, Type, TypeVar |
3 | 3 |
|
4 | 4 | from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
|
5 | 5 | from aws_lambda_powertools.utilities.typing import LambdaContext
|
6 | 6 |
|
7 | 7 | logger = logging.getLogger(__name__)
|
8 | 8 |
|
| 9 | +AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) |
| 10 | + |
9 | 11 |
|
10 | 12 | class AppSyncResolver:
|
11 | 13 | """
|
@@ -38,13 +40,13 @@ def common_field() -> str:
|
38 | 40 | return str(uuid.uuid4())
|
39 | 41 | """
|
40 | 42 |
|
41 |
| - current_event: AppSyncResolverEvent |
| 43 | + current_event: AppSyncResolverEventT # type: ignore[valid-type] |
42 | 44 | lambda_context: LambdaContext
|
43 | 45 |
|
44 | 46 | def __init__(self):
|
45 | 47 | self._resolvers: dict = {}
|
46 | 48 |
|
47 |
| - def resolver(self, type_name: str = "*", field_name: str = None): |
| 49 | + def resolver(self, type_name: str = "*", field_name: Optional[str] = None): |
48 | 50 | """Registers the resolver for field_name
|
49 | 51 |
|
50 | 52 | Parameters
|
@@ -112,6 +114,8 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
|
112 | 114 | raise ValueError(f"No resolver found for '{full_name}'")
|
113 | 115 | return resolver["func"]
|
114 | 116 |
|
115 |
| - def __call__(self, event, context) -> Any: |
| 117 | + def __call__( |
| 118 | + self, event: dict, context: LambdaContext, model: Type[AppSyncResolverEvent] = AppSyncResolverEvent |
| 119 | + ) -> Any: |
116 | 120 | """Implicit lambda handler which internally calls `resolve`"""
|
117 |
| - return self.resolve(event, context) |
| 121 | + return self.resolve(event, context, model) |
0 commit comments