Skip to content

Commit 524d054

Browse files
author
Michal Ploski
committed
Refactor appsync resolver
1 parent d0fe867 commit 524d054

File tree

4 files changed

+110
-58
lines changed

4 files changed

+110
-58
lines changed

Diff for: aws_lambda_powertools/event_handler/appsync.py

+95-43
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
import logging
22
from itertools import groupby
3-
from typing import Any, Callable, List, Optional, Type, TypeVar, Union
3+
from typing import Any, Callable, List, Optional, Type, Union
44

55
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
66
from aws_lambda_powertools.utilities.typing import LambdaContext
77

88
logger = logging.getLogger(__name__)
99

10-
AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent)
1110

11+
class RouterContext:
12+
def __init__(self):
13+
super().__init__()
14+
self.context = {}
1215

13-
class BaseRouter:
14-
current_event: Union[AppSyncResolverEventT, List[AppSyncResolverEventT]] # type: ignore[valid-type]
15-
lambda_context: LambdaContext
16-
context: dict
16+
def append_context(self, **additional_context):
17+
"""Append key=value data as routing context"""
18+
self.context.update(**additional_context)
1719

20+
def clear_context(self):
21+
"""Resets routing context"""
22+
self.context.clear()
23+
24+
25+
class ResolverRegistry:
1826
def __init__(self):
27+
super().__init__()
1928
self._resolvers: dict = {}
29+
self._batch_resolvers: dict = {}
2030

2131
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
2232
"""Registers the resolver for field_name
@@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
2939
Field name
3040
"""
3141

32-
def register_resolver(func):
42+
def register(func):
3343
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
3444
self._resolvers[f"{type_name}.{field_name}"] = {"func": func}
3545
return func
3646

37-
return register_resolver
47+
return register
3848

39-
def append_context(self, **additional_context):
40-
"""Append key=value data as routing context"""
41-
self.context.update(**additional_context)
49+
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
50+
"""Registers the resolver for field_name
4251
43-
def clear_context(self):
44-
"""Resets routing context"""
45-
self.context.clear()
52+
Parameters
53+
----------
54+
type_name : str
55+
Type name
56+
field_name : str
57+
Field name
58+
"""
4659

60+
def register(func):
61+
logger.debug(f"Adding batch resolver `{func.__name__}` for field `{type_name}.{field_name}`")
62+
self._batch_resolvers[f"{type_name}.{field_name}"] = {"func": func}
63+
return func
4764

48-
class AppSyncResolver(BaseRouter):
65+
return register
66+
67+
68+
class AppSyncResolver(ResolverRegistry, RouterContext):
4969
"""
5070
AppSync resolver decorator
5171
@@ -78,16 +98,20 @@ def common_field() -> str:
7898

7999
def __init__(self):
80100
super().__init__()
81-
self.context = {} # early init as customers might add context before event resolution
101+
self.current_batch_event: List[AppSyncResolverEvent] = []
102+
self.current_event: Optional[AppSyncResolverEvent] = None
82103

83104
def resolve(
84-
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
105+
self,
106+
event: Union[dict, List[dict]],
107+
context: LambdaContext,
108+
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
85109
) -> Any:
86110
"""Resolve field_name
87111
88112
Parameters
89113
----------
90-
event : dict
114+
event : dict | List[dict]
91115
Lambda event
92116
context : LambdaContext
93117
Lambda context
@@ -152,33 +176,38 @@ def lambda_handler(event, context):
152176
ValueError
153177
If we could not find a field resolver
154178
"""
155-
# Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage
156-
157-
BaseRouter.lambda_context = context
158-
159-
# If event is a list it means that AppSync sent batch request
160-
if isinstance(event, list):
161-
event_groups = [
162-
{"field_name": field_name, "events": list(events)}
163-
for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"])
164-
]
165-
if len(event_groups) > 1:
166-
ValueError("batch with different field names. It shouldn't happen!")
167-
168-
appconfig_events = [data_model(event) for event in event_groups[0]["events"]]
169-
BaseRouter.current_event = appconfig_events
170-
resolver = self._get_resolver(appconfig_events[0].type_name, event_groups[0]["field_name"])
171-
response = resolver()
172-
else:
173-
appconfig_event = data_model(event)
174-
BaseRouter.current_event = appconfig_event
175-
resolver = self._get_resolver(appconfig_event.type_name, appconfig_event.field_name)
176-
response = resolver(**appconfig_event.arguments)
177179

180+
self.lambda_context = context
181+
182+
response = (
183+
self._call_batch_resolver(event, data_model)
184+
if isinstance(event, list)
185+
else self._call_resolver(event, data_model)
186+
)
178187
self.clear_context()
179188

180189
return response
181190

191+
def _call_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
192+
self.current_event = data_model(event)
193+
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
194+
return resolver(**self.current_event.arguments)
195+
196+
def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolverEvent]) -> list[Any]:
197+
event_groups = [
198+
{"field_name": field_name, "events": list(events)}
199+
for field_name, events in groupby(event, key=lambda x: x["info"]["fieldName"])
200+
]
201+
if len(event_groups) > 1:
202+
ValueError("batch with different field names. It shouldn't happen!")
203+
204+
self.current_batch_event = [data_model(event) for event in event_groups[0]["events"]]
205+
resolver = self._get_batch_resolver(
206+
self.current_batch_event[0].type_name, self.current_batch_event[0].field_name
207+
)
208+
209+
return [resolver(event=appconfig_event) for appconfig_event in self.current_batch_event]
210+
182211
def _get_resolver(self, type_name: str, field_name: str) -> Callable:
183212
"""Get resolver for field_name
184213
@@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
200229
raise ValueError(f"No resolver found for '{full_name}'")
201230
return resolver["func"]
202231

232+
def _get_batch_resolver(self, type_name: str, field_name: str) -> Callable:
233+
"""Get resolver for field_name
234+
235+
Parameters
236+
----------
237+
type_name : str
238+
Type name
239+
field_name : str
240+
Field name
241+
242+
Returns
243+
-------
244+
Callable
245+
callable function and configuration
246+
"""
247+
full_name = f"{type_name}.{field_name}"
248+
resolver = self._batch_resolvers.get(full_name, self._batch_resolvers.get(f"*.{field_name}"))
249+
if not resolver:
250+
raise ValueError(f"No batch resolver found for '{full_name}'")
251+
return resolver["func"]
252+
203253
def __call__(
204-
self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent
254+
self,
255+
event: Union[dict, List[dict]],
256+
context: LambdaContext,
257+
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
205258
) -> Any:
206259
"""Implicit lambda handler which internally calls `resolve`"""
207260
return self.resolve(event, context, data_model)
@@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None:
222275
self._resolvers.update(router._resolvers)
223276

224277

225-
class Router(BaseRouter):
278+
class Router(RouterContext, ResolverRegistry):
226279
def __init__(self):
227280
super().__init__()
228-
self.context = {} # early init as customers might add context before event resolution

Diff for: examples/event_handler_graphql/src/custom_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def api_key(self) -> str:
4242
@app.resolver(type_name="Query", field_name="listLocations")
4343
def list_locations(page: int = 0, size: int = 10) -> List[Location]:
4444
# additional properties/methods will now be available under current_event
45-
logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined]
45+
if app.current_event:
46+
logger.debug(f"Request country origin: {app.current_event.country_viewer}") # type: ignore[attr-defined]
4647
return [{"id": scalar_types_utils.make_id(), "name": "Perry, James and Carroll"}]
4748

4849

Diff for: tests/e2e/event_handler/handlers/appsync_resolver_handler.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import List
1+
from typing import List, Optional
22

33
from pydantic import BaseModel
44

55
from aws_lambda_powertools.event_handler import AppSyncResolver
6+
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
67
from aws_lambda_powertools.utilities.typing import LambdaContext
78

89
app = AppSyncResolver()
@@ -86,13 +87,9 @@ def all_posts() -> List[dict]:
8687
return list(posts.values())
8788

8889

89-
@app.resolver(type_name="Post", field_name="relatedPosts")
90-
def related_posts() -> List[dict]:
91-
posts = []
92-
for resolver_event in app.current_event:
93-
if resolver_event.source:
94-
posts.append(posts_related[resolver_event.source["post_id"]])
95-
return posts
90+
@app.batch_resolver(type_name="Post", field_name="relatedPosts")
91+
def related_posts(event: AppSyncResolverEvent) -> Optional[list]:
92+
return posts_related[event.source["post_id"]] if event.source else None
9693

9794

9895
def lambda_handler(event, context: LambdaContext) -> dict:

Diff for: tests/functional/event_handler/test_appsync.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import sys
3+
from typing import Optional
34

45
import pytest
56

@@ -199,22 +200,23 @@ def test_resolve_batch_processing():
199200
"fieldName": "listLocations",
200201
"arguments": {},
201202
"source": {
202-
"id": "3",
203+
"id": [3, 4],
203204
},
204205
},
205206
]
206207

207208
app = AppSyncResolver()
208209

209-
@app.resolver(field_name="listLocations")
210-
def create_something(): # noqa AA03 VNE003
211-
return [event.source["id"] for event in app.current_event]
210+
@app.batch_resolver(field_name="listLocations")
211+
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
212+
return event.source["id"] if event.source else None
212213

213214
# Call the implicit handler
214215
result = app.resolve(event, LambdaContext())
215-
assert result == ["1", "2", "3"]
216+
assert result == [appsync_event["source"]["id"] for appsync_event in event]
216217

217-
assert len(app.current_event) == len(event)
218+
assert app.current_batch_event and len(app.current_batch_event) == len(event)
219+
assert not app.current_event
218220

219221

220222
def test_resolver_include_resolver():

0 commit comments

Comments
 (0)