Skip to content

Commit 5668cba

Browse files
author
Michal Ploski
committed
Refactor code to use composition instead of inheritence
1 parent bc45703 commit 5668cba

File tree

2 files changed

+190
-91
lines changed

2 files changed

+190
-91
lines changed

Diff for: aws_lambda_powertools/event_handler/appsync.py

+81-85
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from abc import ABC, abstractmethod
12
import logging
23
from itertools import groupby
3-
from typing import Any, Callable, List, Optional, Type, Union
4+
from typing import Any, Callable, Dict, List, Optional, Type, Union
45

56
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
67
from aws_lambda_powertools.utilities.typing import LambdaContext
@@ -10,23 +11,44 @@
1011

1112
class RouterContext:
1213
def __init__(self):
13-
super().__init__()
14-
self.context = {}
14+
self._context = {}
1515

16-
def append_context(self, **additional_context):
16+
@property
17+
def context(self) -> Dict[str, Any]:
18+
return self._context
19+
20+
@context.setter
21+
def context(self, additional_context: Dict[str, Any]) -> None:
1722
"""Append key=value data as routing context"""
18-
self.context.update(**additional_context)
23+
self._context.update(**additional_context)
1924

20-
def clear_context(self):
25+
@context.deleter
26+
def context(self):
2127
"""Resets routing context"""
22-
self.context.clear()
28+
self._context.clear()
29+
30+
31+
class IResolverRegistry(ABC):
32+
@abstractmethod
33+
def resolver(self, type_name: str = "*", field_name: Optional[str] = None) -> Callable:
34+
...
2335

36+
@abstractmethod
37+
def find_resolver(self, type_name: str, field_name: str) -> Callable:
38+
...
2439

25-
class ResolverRegistry:
40+
41+
class ResolverRegistry(IResolverRegistry):
2642
def __init__(self):
27-
super().__init__()
28-
self._resolvers: dict = {}
29-
self._batch_resolvers: dict = {}
43+
self._resolvers: Dict[str, Dict[str, Any]] = {}
44+
45+
@property
46+
def resolvers(self) -> Dict[str, Dict[str, Any]]:
47+
return self._resolvers
48+
49+
@resolvers.setter
50+
def resolvers(self, resolvers: dict) -> None:
51+
self._resolvers.update(resolvers)
3052

3153
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
3254
"""Registers the resolver for field_name
@@ -46,26 +68,15 @@ def register(func):
4668

4769
return register
4870

49-
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
50-
"""Registers the resolver for field_name
51-
52-
Parameters
53-
----------
54-
type_name : str
55-
Type name
56-
field_name : str
57-
Field name
58-
"""
59-
60-
def register(func):
61-
logger.debug(f"Adding batch resolver `{func.__name__}` for field `{type_name}.{field_name}`")
62-
self._batch_resolvers[f"{type_name}.{field_name}"] = {"func": func}
63-
return func
64-
65-
return register
71+
def find_resolver(self, type_name: str, field_name: str) -> Callable:
72+
full_name = f"{type_name}.{field_name}"
73+
resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}"))
74+
if not resolver:
75+
raise ValueError(f"No resolver found for '{full_name}'")
76+
return resolver["func"]
6677

6778

68-
class AppSyncResolver(ResolverRegistry, RouterContext):
79+
class AppSyncResolver:
6980
"""
7081
AppSync resolver decorator
7182
@@ -97,17 +108,20 @@ def common_field() -> str:
97108
"""
98109

99110
def __init__(self):
100-
super().__init__()
111+
self._resolver_registry: IResolverRegistry = ResolverRegistry()
112+
self._batch_resolver_registry: IResolverRegistry = ResolverRegistry()
113+
self._router_context: RouterContext = RouterContext()
101114
self.current_batch_event: List[AppSyncResolverEvent] = []
102115
self.current_event: Optional[AppSyncResolverEvent] = None
116+
self.lambda_context: Optional[LambdaContext] = None
103117

104118
def resolve(
105119
self,
106-
event: Union[dict, List[dict]],
120+
event: Union[Dict[str, Any], List[Dict[str, Any]]],
107121
context: LambdaContext,
108122
data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent,
109123
) -> Any:
110-
"""Resolve field_name
124+
"""Resolve field_name in single event or in a batch event
111125
112126
Parameters
113127
----------
@@ -180,17 +194,17 @@ def lambda_handler(event, context):
180194
self.lambda_context = context
181195

182196
response = (
183-
self._call_batch_resolver(event, data_model)
197+
self._call_batch_resolver(event=event, data_model=data_model)
184198
if isinstance(event, list)
185-
else self._call_resolver(event, data_model)
199+
else self._call_single_resolver(event=event, data_model=data_model)
186200
)
187-
self.clear_context()
201+
del self._router_context.context
188202

189203
return response
190204

191-
def _call_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
205+
def _call_single_resolver(self, event: dict, data_model: Type[AppSyncResolverEvent]) -> Any:
192206
self.current_event = data_model(event)
193-
resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name)
207+
resolver = self._resolver_registry.find_resolver(self.current_event.type_name, self.current_event.field_name)
194208
return resolver(**self.current_event.arguments)
195209

196210
def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolverEvent]) -> List[Any]:
@@ -202,54 +216,12 @@ def _call_batch_resolver(self, event: List[dict], data_model: Type[AppSyncResolv
202216
ValueError("batch with different field names. It shouldn't happen!")
203217

204218
self.current_batch_event = [data_model(event) for event in event_groups[0]["events"]]
205-
resolver = self._get_batch_resolver(
219+
resolver = self._batch_resolver_registry.find_resolver(
206220
self.current_batch_event[0].type_name, self.current_batch_event[0].field_name
207221
)
208222

209223
return [resolver(event=appconfig_event) for appconfig_event in self.current_batch_event]
210224

211-
def _get_resolver(self, type_name: str, field_name: str) -> Callable:
212-
"""Get resolver for field_name
213-
214-
Parameters
215-
----------
216-
type_name : str
217-
Type name
218-
field_name : str
219-
Field name
220-
221-
Returns
222-
-------
223-
Callable
224-
callable function and configuration
225-
"""
226-
full_name = f"{type_name}.{field_name}"
227-
resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}"))
228-
if not resolver:
229-
raise ValueError(f"No resolver found for '{full_name}'")
230-
return resolver["func"]
231-
232-
def _get_batch_resolver(self, type_name: str, field_name: str) -> Callable:
233-
"""Get resolver for field_name
234-
235-
Parameters
236-
----------
237-
type_name : str
238-
Type name
239-
field_name : str
240-
Field name
241-
242-
Returns
243-
-------
244-
Callable
245-
callable function and configuration
246-
"""
247-
full_name = f"{type_name}.{field_name}"
248-
resolver = self._batch_resolvers.get(full_name, self._batch_resolvers.get(f"*.{field_name}"))
249-
if not resolver:
250-
raise ValueError(f"No batch resolver found for '{full_name}'")
251-
return resolver["func"]
252-
253225
def __call__(
254226
self,
255227
event: Union[dict, List[dict]],
@@ -267,14 +239,38 @@ def include_router(self, router: "Router") -> None:
267239
router : Router
268240
A router containing a dict of field resolvers
269241
"""
242+
270243
# Merge app and router context
271-
self.context.update(**router.context)
244+
self._router_context.context = router._router_context.context
272245
# use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx)
273-
router.context = self.context
246+
router._router_context._context = self._router_context.context
274247

275-
self._resolvers.update(router._resolvers)
248+
self._resolver_registry.resolvers = router._resolver_registry.resolvers
249+
self._batch_resolver_registry.resolvers = router._batch_resolver_registry.resolvers
276250

251+
# Interfaces
252+
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
253+
return self._resolver_registry.resolver(field_name=field_name, type_name=type_name)
277254

278-
class Router(RouterContext, ResolverRegistry):
255+
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
256+
return self._batch_resolver_registry.resolver(field_name=field_name, type_name=type_name)
257+
258+
def append_context(self, **additional_context) -> None:
259+
self._router_context.context = additional_context
260+
261+
262+
class Router:
279263
def __init__(self):
280-
super().__init__()
264+
self._resolver_registry = ResolverRegistry()
265+
self._batch_resolver_registry = ResolverRegistry()
266+
self._router_context = RouterContext()
267+
268+
# Interfaces
269+
def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
270+
return self._resolver_registry.resolver(field_name=field_name, type_name=type_name)
271+
272+
def batch_resolver(self, type_name: str = "*", field_name: Optional[str] = None):
273+
return self._batch_resolver_registry.resolver(field_name=field_name, type_name=type_name)
274+
275+
def append_context(self, **additional_context) -> None:
276+
self._router_context.context = additional_context

Diff for: tests/functional/event_handler/test_appsync.py

+109-6
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,119 @@ def get_locations2(name: str):
245245
assert result2 == "get_locations2#value"
246246

247247

248+
def test_resolver_include_batch_resolver():
249+
# GIVEN
250+
app = AppSyncResolver()
251+
router = Router()
252+
253+
@router.batch_resolver(type_name="Query", field_name="listLocations")
254+
def get_locations(event: AppSyncResolverEvent) -> str:
255+
return "get_locations#" + event.arguments["name"]
256+
257+
@app.batch_resolver(field_name="listLocations2")
258+
def get_locations2(event: AppSyncResolverEvent) -> str:
259+
return "get_locations2#" + event.arguments["name"]
260+
261+
app.include_router(router)
262+
263+
# WHEN
264+
mock_event1 = [
265+
{
266+
"typeName": "Query",
267+
"info": {
268+
"fieldName": "listLocations",
269+
"parentTypeName": "Query",
270+
},
271+
"fieldName": "listLocations",
272+
"arguments": {"name": "value"},
273+
"source": {
274+
"id": "1",
275+
},
276+
}
277+
]
278+
mock_event2 = [
279+
{
280+
"typeName": "Query",
281+
"info": {
282+
"fieldName": "listLocations2",
283+
"parentTypeName": "Post",
284+
},
285+
"fieldName": "listLocations2",
286+
"arguments": {"name": "value"},
287+
"source": {
288+
"id": "2",
289+
},
290+
}
291+
]
292+
result1 = app.resolve(mock_event1, LambdaContext())
293+
result2 = app.resolve(mock_event2, LambdaContext())
294+
295+
# THEN
296+
assert result1 == ["get_locations#value"]
297+
assert result2 == ["get_locations2#value"]
298+
299+
300+
def test_resolver_include_mixed_resolver():
301+
# GIVEN
302+
app = AppSyncResolver()
303+
router = Router()
304+
305+
@router.batch_resolver(type_name="Query", field_name="listLocations")
306+
def get_locations(event: AppSyncResolverEvent) -> str:
307+
return "get_locations#" + event.arguments["name"]
308+
309+
@app.resolver(field_name="listLocations2")
310+
def get_locations2(name: str) -> str:
311+
return "get_locations2#" + name
312+
313+
app.include_router(router)
314+
315+
# WHEN
316+
mock_event1 = [
317+
{
318+
"typeName": "Query",
319+
"info": {
320+
"fieldName": "listLocations",
321+
"parentTypeName": "Query",
322+
},
323+
"fieldName": "listLocations",
324+
"arguments": {"name": "value"},
325+
"source": {
326+
"id": "1",
327+
},
328+
}
329+
]
330+
mock_event2 = {
331+
"typeName": "Query",
332+
"info": {
333+
"fieldName": "listLocations2",
334+
"parentTypeName": "Post",
335+
},
336+
"fieldName": "listLocations2",
337+
"arguments": {"name": "value"},
338+
"source": {
339+
"id": "2",
340+
},
341+
}
342+
343+
result1 = app.resolve(mock_event1, LambdaContext())
344+
result2 = app.resolve(mock_event2, LambdaContext())
345+
346+
# THEN
347+
assert result1 == ["get_locations#value"]
348+
assert result2 == "get_locations2#value"
349+
350+
248351
def test_append_context():
249352
app = AppSyncResolver()
250353
app.append_context(is_admin=True)
251-
assert app.context.get("is_admin") is True
354+
assert app._router_context.context.get("is_admin") is True
252355

253356

254357
def test_router_append_context():
255358
router = Router()
256359
router.append_context(is_admin=True)
257-
assert router.context.get("is_admin") is True
360+
assert router._router_context.context.get("is_admin") is True
258361

259362

260363
def test_route_context_is_cleared_after_resolve():
@@ -271,7 +374,7 @@ def get_locations(name: str):
271374
app.resolve(event, {})
272375

273376
# THEN context should be empty
274-
assert app.context == {}
377+
assert app._router_context.context == {}
275378

276379

277380
def test_router_has_access_to_app_context():
@@ -282,7 +385,7 @@ def test_router_has_access_to_app_context():
282385

283386
@router.resolver(type_name="Query", field_name="listLocations")
284387
def get_locations(name: str):
285-
if router.context["is_admin"]:
388+
if router._router_context.context.get("is_admin"):
286389
return f"get_locations#{name}"
287390

288391
app.include_router(router)
@@ -293,7 +396,7 @@ def get_locations(name: str):
293396

294397
# THEN
295398
assert ret == "get_locations#value"
296-
assert router.context == {}
399+
assert router._router_context.context == {}
297400

298401

299402
def test_include_router_merges_context():
@@ -307,4 +410,4 @@ def test_include_router_merges_context():
307410

308411
app.include_router(router)
309412

310-
assert app.context == router.context
413+
assert app._router_context.context == router._router_context.context

0 commit comments

Comments
 (0)