Skip to content

Commit 8b01fc5

Browse files
author
Michael Brewer
authored
feat(appsync): add Router to allow large resolver composition (#776)
1 parent bb8e3b6 commit 8b01fc5

File tree

2 files changed

+75
-27
lines changed

2 files changed

+75
-27
lines changed

aws_lambda_powertools/event_handler/appsync.py

+48-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from abc import ABC
23
from typing import Any, Callable, Optional, Type, TypeVar
34

45
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
@@ -9,7 +10,33 @@
910
AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)
1011

1112

12-
class AppSyncResolver:
13+
class BaseRouter(ABC):
14+
current_event: AppSyncResolverEventT # type: ignore[valid-type]
15+
lambda_context: LambdaContext
16+
17+
def __init__(self):
18+
self._resolvers: dict = {}
19+
20+
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
21+
"""Registers the resolver for field_name
22+
23+
Parameters
24+
----------
25+
type_name : str
26+
Type name
27+
field_name : str
28+
Field name
29+
"""
30+
31+
def register_resolver(func):
32+
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
33+
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
34+
return func
35+
36+
return register_resolver
37+
38+
39+
class AppSyncResolver(BaseRouter):
1340
"""
1441
AppSync resolver decorator
1542
@@ -40,29 +67,8 @@ def common_field() -> str:
4067
return str(uuid.uuid4())
4168
"""
4269

43-
current_event: AppSyncResolverEventT # type: ignore[valid-type]
44-
lambda_context: LambdaContext
45-
4670
def __init__(self):
47-
self._resolvers: dict = {}
48-
49-
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
50-
"""Registers the resolver for field_name
51-
52-
Parameters
53-
----------
54-
type_name : str
55-
Type name
56-
field_name : str
57-
Field name
58-
"""
59-
60-
def register_resolver(func):
61-
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
62-
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
63-
return func
64-
65-
return register_resolver
71+
super().__init__()
6672

6773
def resolve(
6874
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
@@ -136,10 +142,10 @@ def lambda_handler(event, context):
136142
ValueError
137143
If we could not find a field resolver
138144
"""
139-
self.current_event = data_model(event)
140-
self.lambda_context = context
141-
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
142-
return resolver(**self.current_event.arguments)
145+
BaseRouter.current_event = data_model(event)
146+
BaseRouter.lambda_context = context
147+
resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name)
148+
return resolver(**BaseRouter.current_event.arguments)
143149

144150
def _get_resolver(self, type_name: str, field_name: str) -> Callable:
145151
"""Get resolver for field_name
@@ -167,3 +173,18 @@ def __call__(
167173
) -> Any:
168174
"""Implicit lambda handler which internally calls `resolve`"""
169175
return self.resolve(event, context, data_model)
176+
177+
def include_router(self, router: "Router") -> None:
178+
"""Adds all resolvers defined in a router
179+
180+
Parameters
181+
----------
182+
router : Router
183+
A router containing a dict of field resolvers
184+
"""
185+
self._resolvers.update(router._resolvers)
186+
187+
188+
class Router(BaseRouter):
189+
def __init__(self):
190+
super().__init__()

tests/functional/event_handler/test_appsync.py

+27
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from aws_lambda_powertools.event_handler import AppSyncResolver
7+
from aws_lambda_powertools.event_handler.appsync import Router
78
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
89
from aws_lambda_powertools.utilities.typing import LambdaContext
910
from tests.functional.utils import load_event
@@ -161,3 +162,29 @@ def create_something(id: str): # noqa AA03 VNE003
161162
assert result == "my identifier"
162163

163164
assert app.current_event.country_viewer == "US"
165+
166+
167+
def test_resolver_include_resolver():
168+
# GIVEN
169+
app = AppSyncResolver()
170+
router = Router()
171+
172+
@router.resolver(type_name="Query", field_name="listLocations")
173+
def get_locations(name: str):
174+
return "get_locations#" + name
175+
176+
@app.resolver(field_name="listLocations2")
177+
def get_locations2(name: str):
178+
return "get_locations2#" + name
179+
180+
app.include_router(router)
181+
182+
# WHEN
183+
mock_event1 = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}}
184+
mock_event2 = {"typeName": "Query", "fieldName": "listLocations2", "arguments": {"name": "value"}}
185+
result1 = app.resolve(mock_event1, LambdaContext())
186+
result2 = app.resolve(mock_event2, LambdaContext())
187+
188+
# THEN
189+
assert result1 == "get_locations#value"
190+
assert result2 == "get_locations2#value"

0 commit comments

Comments
 (0)