Skip to content

Commit 84c3590

Browse files
Refactoring to support aggregate events
1 parent d039211 commit 84c3590

File tree

12 files changed

+619
-333
lines changed

12 files changed

+619
-333
lines changed

aws_lambda_powertools/event_handler/appsync.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from typing import Any, Callable, Dict, List, Optional, Type, Union
55

6-
from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import ResolverNotFoundError
6+
from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError
77
from aws_lambda_powertools.event_handler.graphql_appsync.router import Router
88
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
99
from aws_lambda_powertools.utilities.typing import LambdaContext
@@ -168,7 +168,12 @@ def _call_single_resolver(self, event: dict, data_model: Type[AppSyncResolverEve
168168
raise ValueError(f"No resolver found for '{self.current_event.type_name}.{self.current_event.field_name}'")
169169
return resolver["func"](**self.current_event.arguments)
170170

171-
def _call_sync_batch_resolver(self, resolver: Callable, raise_on_error: bool = False) -> List[Any]:
171+
def _call_sync_batch_resolver(
172+
self,
173+
resolver: Callable,
174+
raise_on_error: bool = False,
175+
aggregate: bool = True,
176+
) -> List[Any]:
172177
"""
173178
Calls a synchronous batch resolver function for each event in the current batch.
174179
@@ -179,6 +184,10 @@ def _call_sync_batch_resolver(self, resolver: Callable, raise_on_error: bool = F
179184
raise_on_error: bool
180185
A flag indicating whether to raise an error when processing batches
181186
with failed items. Defaults to False, which means errors are handled without raising exceptions.
187+
aggregate: bool
188+
A flag indicating whether the batch items should be processed at once or individually.
189+
If True (default), the batch resolver will process all items in the batch as a single event.
190+
If False, the batch resolver will process each item in the batch individually.
182191
183192
Returns
184193
-------
@@ -188,6 +197,17 @@ def _call_sync_batch_resolver(self, resolver: Callable, raise_on_error: bool = F
188197

189198
logger.debug(f"Graceful error handling flag {raise_on_error=}")
190199

200+
# Checks whether the entire batch should be processed at once
201+
if aggregate:
202+
# Process the entire batch
203+
response = resolver(event=self.current_batch_event)
204+
205+
if not isinstance(response, List):
206+
raise InvalidBatchResponse("The response must be a List when using batch resolvers")
207+
208+
return response
209+
210+
# Non aggregated events, so we call this event list x times
191211
# Stop on first exception we encounter
192212
if raise_on_error:
193213
return [
@@ -206,7 +226,12 @@ def _call_sync_batch_resolver(self, resolver: Callable, raise_on_error: bool = F
206226

207227
return results
208228

209-
async def _call_async_batch_resolver(self, resolver: Callable, raise_on_error: bool = False) -> List[Any]:
229+
async def _call_async_batch_resolver(
230+
self,
231+
resolver: Callable,
232+
raise_on_error: bool = False,
233+
aggregate: bool = True,
234+
) -> List[Any]:
210235
"""
211236
Asynchronously call a batch resolver for each event in the current batch.
212237
@@ -217,6 +242,10 @@ async def _call_async_batch_resolver(self, resolver: Callable, raise_on_error: b
217242
raise_on_error: bool
218243
A flag indicating whether to raise an error when processing batches
219244
with failed items. Defaults to False, which means errors are handled without raising exceptions.
245+
aggregate: bool
246+
A flag indicating whether the batch items should be processed at once or individually.
247+
If True (default), the batch resolver will process all items in the batch as a single event.
248+
If False, the batch resolver will process each item in the batch individually.
220249
221250
Returns
222251
-------
@@ -225,7 +254,17 @@ async def _call_async_batch_resolver(self, resolver: Callable, raise_on_error: b
225254
"""
226255

227256
logger.debug(f"Graceful error handling flag {raise_on_error=}")
228-
response = []
257+
258+
response: List = []
259+
260+
# Checks whether the entire batch should be processed at once
261+
if aggregate:
262+
# Process the entire batch
263+
response.extend(await asyncio.gather(resolver(event=self.current_batch_event)))
264+
if not isinstance(response[0], List):
265+
raise InvalidBatchResponse("The response must be a List when using batch resolvers")
266+
267+
return response[0]
229268

230269
# Prime coroutines
231270
tasks = [resolver(event=e, **e.arguments) for e in self.current_batch_event]
@@ -286,14 +325,19 @@ def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolv
286325

287326
if resolver:
288327
logger.debug(f"Found sync resolver. {resolver=}, {field_name=}")
289-
return self._call_sync_batch_resolver(resolver=resolver["func"], raise_on_error=resolver["raise_on_error"])
328+
return self._call_sync_batch_resolver(
329+
resolver=resolver["func"],
330+
raise_on_error=resolver["raise_on_error"],
331+
aggregate=resolver["aggregate"],
332+
)
290333

291334
if async_resolver:
292335
logger.debug(f"Found async resolver. {resolver=}, {field_name=}")
293336
return asyncio.run(
294337
self._call_async_batch_resolver(
295338
resolver=async_resolver["func"],
296339
raise_on_error=async_resolver["raise_on_error"],
340+
aggregate=async_resolver["aggregate"],
297341
),
298342
)
299343

@@ -371,6 +415,7 @@ def batch_resolver(
371415
type_name: str = "*",
372416
field_name: Optional[str] = None,
373417
raise_on_error: bool = False,
418+
aggregate: bool = True,
374419
) -> Callable:
375420
"""Registers batch resolver function for GraphQL type and field name.
376421
@@ -385,6 +430,10 @@ def batch_resolver(
385430
GraphQL field e.g., getTodo, createTodo, by default None
386431
raise_on_error : bool, optional
387432
Whether to fail entire batch upon error, or handle errors gracefully (None), by default False
433+
aggregate: bool
434+
A flag indicating whether the batch items should be processed at once or individually.
435+
If True (default), the batch resolver will process all items in the batch as a single event.
436+
If False, the batch resolver will process each item in the batch individually.
388437
389438
Returns
390439
-------
@@ -395,16 +444,19 @@ def batch_resolver(
395444
field_name=field_name,
396445
type_name=type_name,
397446
raise_on_error=raise_on_error,
447+
aggregate=aggregate,
398448
)
399449

400450
def async_batch_resolver(
401451
self,
402452
type_name: str = "*",
403453
field_name: Optional[str] = None,
404454
raise_on_error: bool = False,
455+
aggregate: bool = True,
405456
) -> Callable:
406457
return self._async_batch_resolver_registry.register(
407458
field_name=field_name,
408459
type_name=type_name,
409460
raise_on_error=raise_on_error,
461+
aggregate=aggregate,
410462
)

aws_lambda_powertools/event_handler/graphql_appsync/_registry.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def register(
1313
type_name: str = "*",
1414
field_name: Optional[str] = None,
1515
raise_on_error: bool = False,
16+
aggregate: bool = True,
1617
) -> Callable:
1718
"""Registers the resolver for field_name
1819
@@ -25,6 +26,10 @@ def register(
2526
raise_on_error: bool
2627
A flag indicating whether to raise an error when processing batches
2728
with failed items. Defaults to False, which means errors are handled without raising exceptions.
29+
aggregate: bool
30+
A flag indicating whether the batch items should be processed at once or individually.
31+
If True (default), the batch resolver will process all items in the batch as a single event.
32+
If False, the batch resolver will process each item in the batch individually.
2833
2934
Return
3035
----------
@@ -34,7 +39,11 @@ def register(
3439

3540
def _register(func) -> Callable:
3641
logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`")
37-
self.resolvers[f"{type_name}.{field_name}"] = {"func": func, "raise_on_error": raise_on_error}
42+
self.resolvers[f"{type_name}.{field_name}"] = {
43+
"func": func,
44+
"raise_on_error": raise_on_error,
45+
"aggregate": aggregate,
46+
}
3847
return func
3948

4049
return _register

aws_lambda_powertools/event_handler/graphql_appsync/base.py

+10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def batch_resolver(
4949
type_name: str = "*",
5050
field_name: Optional[str] = None,
5151
raise_on_error: bool = False,
52+
aggregate: bool = True,
5253
) -> Callable:
5354
"""
5455
Retrieve a batch resolver function for a specific type and field.
@@ -62,6 +63,10 @@ def batch_resolver(
6263
raise_on_error: bool
6364
A flag indicating whether to raise an error when processing batches
6465
with failed items. Defaults to False, which means errors are handled without raising exceptions.
66+
aggregate: bool
67+
A flag indicating whether the batch items should be processed at once or individually.
68+
If True (default), the batch resolver will process all items in the batch as a single event.
69+
If False, the batch resolver will process each item in the batch individually.
6570
6671
Examples
6772
--------
@@ -95,6 +100,7 @@ def async_batch_resolver(
95100
type_name: str = "*",
96101
field_name: Optional[str] = None,
97102
raise_on_error: bool = False,
103+
aggregate: bool = True,
98104
) -> Callable:
99105
"""
100106
Retrieve a batch resolver function for a specific type and field and runs async.
@@ -108,6 +114,10 @@ def async_batch_resolver(
108114
raise_on_error: bool
109115
A flag indicating whether to raise an error when processing batches
110116
with failed items. Defaults to False, which means errors are handled without raising exceptions.
117+
aggregate: bool
118+
A flag indicating whether the batch items should be processed at once or individually.
119+
If True (default), the batch resolver will process all items in the batch as a single event.
120+
If False, the batch resolver will process each item in the batch individually.
111121
112122
Examples
113123
--------

aws_lambda_powertools/event_handler/graphql_appsync/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@ class ResolverNotFoundError(Exception):
22
"""
33
When a resolver is not found during a lookup.
44
"""
5+
6+
7+
class InvalidBatchResponse(Exception):
8+
"""
9+
When a batch response something different from a List
10+
"""

aws_lambda_powertools/event_handler/graphql_appsync/router.py

+4
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,27 @@ def batch_resolver(
2121
type_name: str = "*",
2222
field_name: Optional[str] = None,
2323
raise_on_error: bool = False,
24+
aggregate: bool = True,
2425
) -> Callable:
2526
return self._batch_resolver_registry.register(
2627
field_name=field_name,
2728
type_name=type_name,
2829
raise_on_error=raise_on_error,
30+
aggregate=aggregate,
2931
)
3032

3133
def async_batch_resolver(
3234
self,
3335
type_name: str = "*",
3436
field_name: Optional[str] = None,
3537
raise_on_error: bool = False,
38+
aggregate: bool = True,
3639
) -> Callable:
3740
return self._async_batch_resolver_registry.register(
3841
field_name=field_name,
3942
type_name=type_name,
4043
raise_on_error=raise_on_error,
44+
aggregate=aggregate,
4145
)
4246

4347
def append_context(self, **additional_context):

tests/e2e/event_handler_appsync/files/schema.graphql

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@ type Post {
1717
downs: Int
1818
relatedPosts: [Post]
1919
relatedPostsAsync: [Post]
20-
}
20+
relatedPostsAggregate: [Post]
21+
relatedPostsAsyncAggregate: [Post]
22+
}

tests/e2e/event_handler_appsync/handlers/appsync_resolver_handler.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class Post(BaseModel):
7676
downs: str
7777

7878

79+
# PROCESSING SINGLE RESOLVERS
7980
@app.resolver(type_name="Query", field_name="getPost")
8081
def get_post(post_id: str = "") -> dict:
8182
post = Post(**posts[post_id]).dict()
@@ -87,15 +88,27 @@ def all_posts() -> List[dict]:
8788
return list(posts.values())
8889

8990

90-
@app.batch_resolver(type_name="Post", field_name="relatedPosts")
91+
# PROCESSING BATCH WITHOUT AGGREGATION
92+
@app.batch_resolver(type_name="Post", field_name="relatedPosts", aggregate=False)
9193
def related_posts(event: AppSyncResolverEvent) -> Optional[list]:
9294
return posts_related[event.source["post_id"]] if event.source else None
9395

9496

95-
@app.async_batch_resolver(type_name="Post", field_name="relatedPostsAsync")
97+
@app.async_batch_resolver(type_name="Post", field_name="relatedPostsAsync", aggregate=False)
9698
async def related_posts_async(event: AppSyncResolverEvent) -> Optional[list]:
9799
return posts_related[event.source["post_id"]] if event.source else None
98100

99101

102+
# PROCESSING BATCH WITH AGGREGATION
103+
@app.batch_resolver(type_name="Post", field_name="relatedPostsAggregate")
104+
def related_posts_aggregate(event: List[AppSyncResolverEvent]) -> Optional[list]:
105+
return [posts_related[record.source.get("post_id")] for record in event]
106+
107+
108+
@app.async_batch_resolver(type_name="Post", field_name="relatedPostsAsyncAggregate")
109+
async def related_posts_async_aggregate(event: List[AppSyncResolverEvent]) -> Optional[list]:
110+
return [posts_related[record.source.get("post_id")] for record in event]
111+
112+
100113
def lambda_handler(event, context: LambdaContext) -> dict:
101114
return app.resolve(event, context)

tests/e2e/event_handler_appsync/infrastructure.py

+14
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,19 @@ def _create_appsync_endpoint(self, function: Function):
5858
max_batch_size=10,
5959
)
6060

61+
lambda_datasource.create_resolver(
62+
"QueryGetPostRelatedResolverAggregate",
63+
type_name="Post",
64+
field_name="relatedPostsAggregate",
65+
max_batch_size=10,
66+
)
67+
68+
lambda_datasource.create_resolver(
69+
"QueryGetPostRelatedAsyncResolverAggregate",
70+
type_name="Post",
71+
field_name="relatedPostsAsyncAggregate",
72+
max_batch_size=10,
73+
)
74+
6175
CfnOutput(self.stack, "GraphQLHTTPUrl", value=api.graphql_url)
6276
CfnOutput(self.stack, "GraphQLAPIKey", value=api.api_key)

tests/e2e/event_handler_appsync/test_appsync_resolvers.py

+39-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_appsync_get_post(appsync_endpoint, appsync_access_key):
7575

7676

7777
@pytest.mark.xdist_group(name="event_handler")
78-
def test_appsync_get_related_posts_batch(appsync_endpoint, appsync_access_key):
78+
def test_appsync_get_related_posts_batch_without_aggregate(appsync_endpoint, appsync_access_key):
7979
# GIVEN
8080
post_id = "2"
8181
related_posts_ids = ["3", "5"]
@@ -110,3 +110,41 @@ def test_appsync_get_related_posts_batch(appsync_endpoint, appsync_access_key):
110110
assert post["post_id"] in related_posts_ids
111111
for post in data["getPost"]["relatedPostsAsync"]:
112112
assert post["post_id"] in related_posts_ids
113+
114+
115+
@pytest.mark.xdist_group(name="event_handler")
116+
def test_appsync_get_related_posts_batch_with_aggregate(appsync_endpoint, appsync_access_key):
117+
# GIVEN
118+
post_id = "2"
119+
related_posts_ids = ["3", "5"]
120+
121+
body = {
122+
"query": f'query MyQuery {{ getPost(post_id: "{post_id}") \
123+
{{ post_id relatedPostsAggregate {{ post_id }} relatedPostsAsyncAggregate {{ post_id }} }} }}',
124+
"variables": None,
125+
"operationName": "MyQuery",
126+
}
127+
128+
# WHEN
129+
response = data_fetcher.get_http_response(
130+
Request(
131+
method="POST",
132+
url=appsync_endpoint,
133+
json=body,
134+
headers={"x-api-key": appsync_access_key, "Content-Type": "application/json"},
135+
),
136+
)
137+
138+
# THEN expect a HTTP 200 response and content return Post id with dependent Posts id's
139+
assert response.status_code == 200
140+
assert response.content is not None
141+
142+
data = json.loads(response.content.decode("ascii"))["data"]
143+
144+
assert data["getPost"]["post_id"] == post_id
145+
assert len(data["getPost"]["relatedPostsAggregate"]) == len(related_posts_ids)
146+
assert len(data["getPost"]["relatedPostsAsyncAggregate"]) == len(related_posts_ids)
147+
for post in data["getPost"]["relatedPostsAggregate"]:
148+
assert post["post_id"] in related_posts_ids
149+
for post in data["getPost"]["relatedPostsAsyncAggregate"]:
150+
assert post["post_id"] in related_posts_ids

tests/functional/event_handler/required_dependencies/appsync/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)