|
1 | 1 | import logging
|
| 2 | +from abc import ABC |
2 | 3 | from typing import Any, Callable, Optional, Type, TypeVar
|
3 | 4 |
|
4 | 5 | from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
|
|
9 | 10 | AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)
|
10 | 11 |
|
11 | 12 |
|
12 |
| -class AppSyncResolver: |
| 13 | +class BaseRouter(ABC): |
| 14 | + current_event: AppSyncResolverEventT # type: ignore[valid-type] |
| 15 | + lambda_context: LambdaContext |
| 16 | + |
| 17 | + def __init__(self): |
| 18 | + self._resolvers: dict = {} |
| 19 | + |
| 20 | + def resolver(self, type_name: str = "*", field_name: Optional[str] = None): |
| 21 | + """Registers the resolver for field_name |
| 22 | +
|
| 23 | + Parameters |
| 24 | + ---------- |
| 25 | + type_name : str |
| 26 | + Type name |
| 27 | + field_name : str |
| 28 | + Field name |
| 29 | + """ |
| 30 | + |
| 31 | + def register_resolver(func): |
| 32 | + logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") |
| 33 | + self._resolvers[f"{type_name}.{field_name}"] = {"func": func} |
| 34 | + return func |
| 35 | + |
| 36 | + return register_resolver |
| 37 | + |
| 38 | + |
| 39 | +class AppSyncResolver(BaseRouter): |
13 | 40 | """
|
14 | 41 | AppSync resolver decorator
|
15 | 42 |
|
@@ -40,29 +67,8 @@ def common_field() -> str:
|
40 | 67 | return str(uuid.uuid4())
|
41 | 68 | """
|
42 | 69 |
|
43 |
| - current_event: AppSyncResolverEventT # type: ignore[valid-type] |
44 |
| - lambda_context: LambdaContext |
45 |
| - |
46 | 70 | def __init__(self):
|
47 |
| - self._resolvers: dict = {} |
48 |
| - |
49 |
| - def resolver(self, type_name: str = "*", field_name: Optional[str] = None): |
50 |
| - """Registers the resolver for field_name |
51 |
| -
|
52 |
| - Parameters |
53 |
| - ---------- |
54 |
| - type_name : str |
55 |
| - Type name |
56 |
| - field_name : str |
57 |
| - Field name |
58 |
| - """ |
59 |
| - |
60 |
| - def register_resolver(func): |
61 |
| - logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") |
62 |
| - self._resolvers[f"{type_name}.{field_name}"] = {"func": func} |
63 |
| - return func |
64 |
| - |
65 |
| - return register_resolver |
| 71 | + super().__init__() |
66 | 72 |
|
67 | 73 | def resolve(
|
68 | 74 | self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
|
@@ -136,10 +142,10 @@ def lambda_handler(event, context):
|
136 | 142 | ValueError
|
137 | 143 | If we could not find a field resolver
|
138 | 144 | """
|
139 |
| - self.current_event = data_model(event) |
140 |
| - self.lambda_context = context |
141 |
| - resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) |
142 |
| - return resolver(**self.current_event.arguments) |
| 145 | + BaseRouter.current_event = data_model(event) |
| 146 | + BaseRouter.lambda_context = context |
| 147 | + resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) |
| 148 | + return resolver(**BaseRouter.current_event.arguments) |
143 | 149 |
|
144 | 150 | def _get_resolver(self, type_name: str, field_name: str) -> Callable:
|
145 | 151 | """Get resolver for field_name
|
@@ -167,3 +173,18 @@ def __call__(
|
167 | 173 | ) -> Any:
|
168 | 174 | """Implicit lambda handler which internally calls `resolve`"""
|
169 | 175 | return self.resolve(event, context, data_model)
|
| 176 | + |
| 177 | + def include_router(self, router: "Router") -> None: |
| 178 | + """Adds all resolvers defined in a router |
| 179 | +
|
| 180 | + Parameters |
| 181 | + ---------- |
| 182 | + router : Router |
| 183 | + A router containing a dict of field resolvers |
| 184 | + """ |
| 185 | + self._resolvers.update(router._resolvers) |
| 186 | + |
| 187 | + |
| 188 | +class Router(BaseRouter): |
| 189 | + def __init__(self): |
| 190 | + super().__init__() |
0 commit comments