1
1
import logging
2
2
from itertools import groupby
3
- from typing import Any , Callable , List , Optional , Type , TypeVar , Union
3
+ from typing import Any , Callable , List , Optional , Type , Union
4
4
5
5
from aws_lambda_powertools .utilities .data_classes import AppSyncResolverEvent
6
6
from aws_lambda_powertools .utilities .typing import LambdaContext
7
7
8
8
logger = logging .getLogger (__name__ )
9
9
10
- AppSyncResolverEventT = TypeVar ("AppSyncResolverEventT" , bound = AppSyncResolverEvent )
11
10
11
+ class RouterContext :
12
+ def __init__ (self ):
13
+ super ().__init__ ()
14
+ self .context = {}
12
15
13
- class BaseRouter :
14
- current_event : Union [AppSyncResolverEventT , List [AppSyncResolverEventT ]] # type: ignore[valid-type]
15
- lambda_context : LambdaContext
16
- context : dict
16
+ def append_context (self , ** additional_context ):
17
+ """Append key=value data as routing context"""
18
+ self .context .update (** additional_context )
17
19
20
+ def clear_context (self ):
21
+ """Resets routing context"""
22
+ self .context .clear ()
23
+
24
+
25
+ class ResolverRegistry :
18
26
def __init__ (self ):
27
+ super ().__init__ ()
19
28
self ._resolvers : dict = {}
29
+ self ._batch_resolvers : dict = {}
20
30
21
31
def resolver (self , type_name : str = "*" , field_name : Optional [str ] = None ):
22
32
"""Registers the resolver for field_name
@@ -29,23 +39,33 @@ def resolver(self, type_name: str = "*", field_name: Optional[str] = None):
29
39
Field name
30
40
"""
31
41
32
- def register_resolver (func ):
42
+ def register (func ):
33
43
logger .debug (f"Adding resolver `{ func .__name__ } ` for field `{ type_name } .{ field_name } `" )
34
44
self ._resolvers [f"{ type_name } .{ field_name } " ] = {"func" : func }
35
45
return func
36
46
37
- return register_resolver
47
+ return register
38
48
39
- def append_context (self , ** additional_context ):
40
- """Append key=value data as routing context"""
41
- self .context .update (** additional_context )
49
+ def batch_resolver (self , type_name : str = "*" , field_name : Optional [str ] = None ):
50
+ """Registers the resolver for field_name
42
51
43
- def clear_context (self ):
44
- """Resets routing context"""
45
- self .context .clear ()
52
+ Parameters
53
+ ----------
54
+ type_name : str
55
+ Type name
56
+ field_name : str
57
+ Field name
58
+ """
46
59
60
+ def register (func ):
61
+ logger .debug (f"Adding batch resolver `{ func .__name__ } ` for field `{ type_name } .{ field_name } `" )
62
+ self ._batch_resolvers [f"{ type_name } .{ field_name } " ] = {"func" : func }
63
+ return func
47
64
48
- class AppSyncResolver (BaseRouter ):
65
+ return register
66
+
67
+
68
+ class AppSyncResolver (ResolverRegistry , RouterContext ):
49
69
"""
50
70
AppSync resolver decorator
51
71
@@ -78,16 +98,20 @@ def common_field() -> str:
78
98
79
99
def __init__ (self ):
80
100
super ().__init__ ()
81
- self .context = {} # early init as customers might add context before event resolution
101
+ self .current_batch_event : List [AppSyncResolverEvent ] = []
102
+ self .current_event : Optional [AppSyncResolverEvent ] = None
82
103
83
104
def resolve (
84
- self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent
105
+ self ,
106
+ event : Union [dict , List [dict ]],
107
+ context : LambdaContext ,
108
+ data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent ,
85
109
) -> Any :
86
110
"""Resolve field_name
87
111
88
112
Parameters
89
113
----------
90
- event : dict
114
+ event : dict | List[dict]
91
115
Lambda event
92
116
context : LambdaContext
93
117
Lambda context
@@ -152,33 +176,38 @@ def lambda_handler(event, context):
152
176
ValueError
153
177
If we could not find a field resolver
154
178
"""
155
- # Maintenance: revisit generics/overload to fix [attr-defined] in mypy usage
156
-
157
- BaseRouter .lambda_context = context
158
-
159
- # If event is a list it means that AppSync sent batch request
160
- if isinstance (event , list ):
161
- event_groups = [
162
- {"field_name" : field_name , "events" : list (events )}
163
- for field_name , events in groupby (event , key = lambda x : x ["info" ]["fieldName" ])
164
- ]
165
- if len (event_groups ) > 1 :
166
- ValueError ("batch with different field names. It shouldn't happen!" )
167
-
168
- appconfig_events = [data_model (event ) for event in event_groups [0 ]["events" ]]
169
- BaseRouter .current_event = appconfig_events
170
- resolver = self ._get_resolver (appconfig_events [0 ].type_name , event_groups [0 ]["field_name" ])
171
- response = resolver ()
172
- else :
173
- appconfig_event = data_model (event )
174
- BaseRouter .current_event = appconfig_event
175
- resolver = self ._get_resolver (appconfig_event .type_name , appconfig_event .field_name )
176
- response = resolver (** appconfig_event .arguments )
177
179
180
+ self .lambda_context = context
181
+
182
+ response = (
183
+ self ._call_batch_resolver (event , data_model )
184
+ if isinstance (event , list )
185
+ else self ._call_resolver (event , data_model )
186
+ )
178
187
self .clear_context ()
179
188
180
189
return response
181
190
191
+ def _call_resolver (self , event : dict , data_model : Type [AppSyncResolverEvent ]) -> Any :
192
+ self .current_event = data_model (event )
193
+ resolver = self ._get_resolver (self .current_event .type_name , self .current_event .field_name )
194
+ return resolver (** self .current_event .arguments )
195
+
196
+ def _call_batch_resolver (self , event : List [dict ], data_model : Type [AppSyncResolverEvent ]) -> list [Any ]:
197
+ event_groups = [
198
+ {"field_name" : field_name , "events" : list (events )}
199
+ for field_name , events in groupby (event , key = lambda x : x ["info" ]["fieldName" ])
200
+ ]
201
+ if len (event_groups ) > 1 :
202
+ ValueError ("batch with different field names. It shouldn't happen!" )
203
+
204
+ self .current_batch_event = [data_model (event ) for event in event_groups [0 ]["events" ]]
205
+ resolver = self ._get_batch_resolver (
206
+ self .current_batch_event [0 ].type_name , self .current_batch_event [0 ].field_name
207
+ )
208
+
209
+ return [resolver (event = appconfig_event ) for appconfig_event in self .current_batch_event ]
210
+
182
211
def _get_resolver (self , type_name : str , field_name : str ) -> Callable :
183
212
"""Get resolver for field_name
184
213
@@ -200,8 +229,32 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable:
200
229
raise ValueError (f"No resolver found for '{ full_name } '" )
201
230
return resolver ["func" ]
202
231
232
+ def _get_batch_resolver (self , type_name : str , field_name : str ) -> Callable :
233
+ """Get resolver for field_name
234
+
235
+ Parameters
236
+ ----------
237
+ type_name : str
238
+ Type name
239
+ field_name : str
240
+ Field name
241
+
242
+ Returns
243
+ -------
244
+ Callable
245
+ callable function and configuration
246
+ """
247
+ full_name = f"{ type_name } .{ field_name } "
248
+ resolver = self ._batch_resolvers .get (full_name , self ._batch_resolvers .get (f"*.{ field_name } " ))
249
+ if not resolver :
250
+ raise ValueError (f"No batch resolver found for '{ full_name } '" )
251
+ return resolver ["func" ]
252
+
203
253
def __call__ (
204
- self , event : dict , context : LambdaContext , data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent
254
+ self ,
255
+ event : Union [dict , List [dict ]],
256
+ context : LambdaContext ,
257
+ data_model : Type [AppSyncResolverEvent ] = AppSyncResolverEvent ,
205
258
) -> Any :
206
259
"""Implicit lambda handler which internally calls `resolve`"""
207
260
return self .resolve (event , context , data_model )
@@ -222,7 +275,6 @@ def include_router(self, router: "Router") -> None:
222
275
self ._resolvers .update (router ._resolvers )
223
276
224
277
225
- class Router (BaseRouter ):
278
+ class Router (RouterContext , ResolverRegistry ):
226
279
def __init__ (self ):
227
280
super ().__init__ ()
228
- self .context = {} # early init as customers might add context before event resolution
0 commit comments