17
17
"""Serialized DAG and BaseOperator"""
18
18
from __future__ import annotations
19
19
20
+ import collections .abc
20
21
import datetime
21
22
import enum
22
23
import logging
23
24
import warnings
24
25
import weakref
25
26
from dataclasses import dataclass
26
27
from inspect import Parameter , signature
27
- from typing import TYPE_CHECKING , Any , Iterable , NamedTuple , Type
28
+ from typing import TYPE_CHECKING , Any , Collection , Iterable , Mapping , NamedTuple , Type , Union
28
29
29
30
import cattr
30
31
import lazy_object_proxy
@@ -207,6 +208,26 @@ def deref(self, dag: DAG) -> XComArg:
207
208
return deserialize_xcom_arg (self .data , dag )
208
209
209
210
211
+ # These two should be kept in sync. Note that these are intentionally not using
212
+ # the type declarations in expandinput.py so we always remember to update
213
+ # serialization logic when adding new ExpandInput variants. If you add things to
214
+ # the unions, be sure to update _ExpandInputRef to match.
215
+ _ExpandInputOriginalValue = Union [
216
+ # For .expand(**kwargs).
217
+ Mapping [str , Any ],
218
+ # For expand_kwargs(arg).
219
+ XComArg ,
220
+ Collection [Union [XComArg , Mapping [str , Any ]]],
221
+ ]
222
+ _ExpandInputSerializedValue = Union [
223
+ # For .expand(**kwargs).
224
+ Mapping [str , Any ],
225
+ # For expand_kwargs(arg).
226
+ _XComRef ,
227
+ Collection [Union [_XComRef , Mapping [str , Any ]]],
228
+ ]
229
+
230
+
210
231
class _ExpandInputRef (NamedTuple ):
211
232
"""Used to store info needed to create a mapped operator's expand input.
212
233
@@ -215,13 +236,29 @@ class _ExpandInputRef(NamedTuple):
215
236
"""
216
237
217
238
key : str
218
- value : _XComRef | dict [str , Any ]
239
+ value : _ExpandInputSerializedValue
240
+
241
+ @classmethod
242
+ def validate_expand_input_value (cls , value : _ExpandInputOriginalValue ) -> None :
243
+ """Validate we've covered all ``ExpandInput.value`` types.
244
+
245
+ This function does not actually do anything, but is called during
246
+ serialization so Mypy will *statically* check we have handled all
247
+ possible ExpandInput cases.
248
+ """
219
249
220
250
def deref (self , dag : DAG ) -> ExpandInput :
251
+ """De-reference into a concrete ExpandInput object.
252
+
253
+ If you add more cases here, be sure to update _ExpandInputOriginalValue
254
+ and _ExpandInputSerializedValue to match the logic.
255
+ """
221
256
if isinstance (self .value , _XComRef ):
222
257
value : Any = self .value .deref (dag )
223
- else :
258
+ elif isinstance ( self . value , collections . abc . Mapping ) :
224
259
value = {k : v .deref (dag ) if isinstance (v , _XComRef ) else v for k , v in self .value .items ()}
260
+ else :
261
+ value = [v .deref (dag ) if isinstance (v , _XComRef ) else v for v in self .value ]
225
262
return create_expand_input (self .key , value )
226
263
227
264
@@ -663,6 +700,8 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
663
700
serialized_op = cls ._serialize_node (op , include_deps = op .deps != MappedOperator .deps_for (BaseOperator ))
664
701
# Handle expand_input and op_kwargs_expand_input.
665
702
expansion_kwargs = op ._get_specified_expand_input ()
703
+ if TYPE_CHECKING : # Let Mypy check the input type for us!
704
+ _ExpandInputRef .validate_expand_input_value (expansion_kwargs .value )
666
705
serialized_op [op ._expand_input_attr ] = {
667
706
"type" : get_map_type_key (expansion_kwargs ),
668
707
"value" : cls .serialize (expansion_kwargs .value ),
0 commit comments