Skip to content

Commit e6177a0

Browse files
uranusjr authored and ephraimbuddy committed
Handle list when serializing expand_kwargs (#26369)
(cherry picked from commit b816a6b)
1 parent 553f7c9 commit e6177a0

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

airflow/serialization/serialized_objects.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
"""Serialized DAG and BaseOperator"""
1818
from __future__ import annotations
1919

20+
import collections.abc
2021
import datetime
2122
import enum
2223
import logging
2324
import warnings
2425
import weakref
2526
from dataclasses import dataclass
2627
from inspect import Parameter, signature
27-
from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Type
28+
from typing import TYPE_CHECKING, Any, Collection, Iterable, Mapping, NamedTuple, Type, Union
2829

2930
import cattr
3031
import lazy_object_proxy
@@ -207,6 +208,26 @@ def deref(self, dag: DAG) -> XComArg:
207208
return deserialize_xcom_arg(self.data, dag)
208209

209210

211+
# These two should be kept in sync. Note that these are intentionally not using
212+
# the type declarations in expandinput.py so we always remember to update
213+
# serialization logic when adding new ExpandInput variants. If you add things to
214+
# the unions, be sure to update _ExpandInputRef to match.
215+
# Shapes accepted for an expand input's value *before* serialization.
_ExpandInputOriginalValue = Union[
216+
# For .expand(**kwargs).
217+
Mapping[str, Any],
218+
# For expand_kwargs(arg).
219+
XComArg,
220+
# Literal collection form: each item is an XComArg or a plain mapping.
Collection[Union[XComArg, Mapping[str, Any]]],
221+
]
222+
# The same shapes *after* serialization: every XComArg member above is
# mirrored here by an _XComRef member.
_ExpandInputSerializedValue = Union[
223+
# For .expand(**kwargs).
224+
Mapping[str, Any],
225+
# For expand_kwargs(arg).
226+
_XComRef,
227+
# Literal collection form: each item is an _XComRef or a plain mapping.
Collection[Union[_XComRef, Mapping[str, Any]]],
228+
]
229+
230+
210231
class _ExpandInputRef(NamedTuple):
211232
"""Used to store info needed to create a mapped operator's expand input.
212233
@@ -215,13 +236,29 @@ class _ExpandInputRef(NamedTuple):
215236
"""
216237

217238
key: str
218-
value: _XComRef | dict[str, Any]
239+
value: _ExpandInputSerializedValue
240+
241+
@classmethod
def validate_expand_input_value(cls, value: _ExpandInputOriginalValue) -> None:
    """Static-typing hook for ``ExpandInput.value``; a no-op at runtime.

    The serialization code calls this so that Mypy checks the argument
    against ``_ExpandInputOriginalValue``. If a new kind of
    ``ExpandInput.value`` is introduced without being added to that union,
    the call site fails type checking. The method body deliberately does
    nothing.

    :param value: the not-yet-serialized expand input value to type-check.
    """
    return None
219249

220250
def deref(self, dag: DAG) -> ExpandInput:
    """De-reference into a concrete ExpandInput object.

    Every ``_XComRef`` placeholder found in ``self.value`` is replaced by the
    result of calling its ``deref(dag)``; anything else is passed through
    unchanged. Three shapes of ``self.value`` are handled:

    * a single ``_XComRef``;
    * a mapping, whose values may be ``_XComRef``;
    * any other iterable, whose items may be ``_XComRef`` or mappings whose
      values may in turn be ``_XComRef``.

    If you add more cases here, be sure to update _ExpandInputOriginalValue
    and _ExpandInputSerializedValue to match the logic.

    :param dag: forwarded to each ``_XComRef.deref`` call.
    :return: the result of ``create_expand_input(self.key, value)`` where
        ``value`` is the de-referenced form of ``self.value``.
    """

    def _deref_one(item: Any) -> Any:
        # Resolve a single placeholder; leave every other object as it is.
        return item.deref(dag) if isinstance(item, _XComRef) else item

    def _deref_mapping(mapping: Mapping[str, Any]) -> dict[str, Any]:
        # Resolve placeholders appearing as values of a mapping.
        return {k: _deref_one(v) for k, v in mapping.items()}

    value: Any
    if isinstance(self.value, _XComRef):
        value = self.value.deref(dag)
    elif isinstance(self.value, collections.abc.Mapping):
        value = _deref_mapping(self.value)
    else:
        # A literal list handed to expand_kwargs() may contain dicts whose
        # values are XComArg objects; those are serialized as _XComRef nested
        # inside the dict. Checking only the top-level items would leave the
        # nested placeholders unresolved, so mappings are walked as well.
        value = [
            _deref_mapping(item) if isinstance(item, collections.abc.Mapping) else _deref_one(item)
            for item in self.value
        ]
    return create_expand_input(self.key, value)
226263

227264

@@ -663,6 +700,8 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
663700
serialized_op = cls._serialize_node(op, include_deps=op.deps != MappedOperator.deps_for(BaseOperator))
664701
# Handle expand_input and op_kwargs_expand_input.
665702
expansion_kwargs = op._get_specified_expand_input()
703+
if TYPE_CHECKING: # Let Mypy check the input type for us!
704+
_ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
666705
serialized_op[op._expand_input_attr] = {
667706
"type": get_map_type_key(expansion_kwargs),
668707
"value": cls.serialize(expansion_kwargs.value),

tests/serialization/test_dag_serialization.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,62 @@ def test_operator_expand_xcomarg_serde():
19611961

19621962

19631963
@pytest.mark.parametrize("strict", [True, False])
1964-
def test_operator_expand_kwargs_serde(strict):
1964+
def test_operator_expand_kwargs_literal_serde(strict):
    """Round-trip a mapped operator built from ``expand_kwargs`` on a literal list.

    The list mixes a plain dict with a dict holding an ``XComArg``. The test
    checks three stages:

    1. the serialized form of the mapped operator;
    2. the operator deserialized on its own, where the XCom reference is
       still an ``_XComRef`` placeholder;
    3. the operator obtained through a full DAG round trip, where the
       reference is expected to be a ``PlainXComArg`` again.

    :param strict: forwarded to ``expand_kwargs(strict=...)`` and expected to
        come back as ``_disallow_kwargs_override``.
    """
    from airflow.models.xcom_arg import PlainXComArg, XComArg
    from airflow.serialization.serialized_objects import _XComRef

    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
        task1 = BaseOperator(task_id="op1")
        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(
            [{"a": "x"}, {"a": XComArg(task1)}],
            strict=strict,
        )

    serialized = SerializedBaseOperator.serialize(mapped)
    assert serialized == {
        '_is_empty': False,
        '_is_mapped': True,
        '_task_module': 'tests.test_utils.mock_operators',
        '_task_type': 'MockOperator',
        'downstream_task_ids': [],
        'expand_input': {
            "type": "list-of-dicts",
            "value": [
                {"__type": "dict", "__var": {"a": "x"}},
                {
                    "__type": "dict",
                    "__var": {"a": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
                },
            ],
        },
        'partial_kwargs': {},
        'task_id': 'task_2',
        'template_fields': ['arg1', 'arg2'],
        'template_ext': [],
        'template_fields_renderers': {},
        'operator_extra_links': [],
        'ui_color': '#fff',
        'ui_fgcolor': '#000',
        "_disallow_kwargs_override": strict,
        '_expand_input_attr': 'expand_input',
    }

    op = SerializedBaseOperator.deserialize_operator(serialized)
    assert op.deps is MappedOperator.deps_for(BaseOperator)
    assert op._disallow_kwargs_override == strict

    # The XComArg can't be deserialized before the DAG is.
    expand_value = op.expand_input.value
    assert expand_value == [{"a": "x"}, {"a": _XComRef({"task_id": "op1", "key": XCOM_RETURN_KEY})}]

    serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

    # This comparison used to be a bare expression statement, so its result
    # was thrown away and the check could never fail. Assert it for real.
    resolved_expand_value = serialized_dag.task_dict['task_2'].expand_input.value
    assert resolved_expand_value == [{"a": "x"}, {"a": PlainXComArg(serialized_dag.task_dict['op1'])}]
2016+
2017+
2018+
@pytest.mark.parametrize("strict", [True, False])
2019+
def test_operator_expand_kwargs_xcomarg_serde(strict):
19652020
from airflow.models.xcom_arg import PlainXComArg, XComArg
19662021
from airflow.serialization.serialized_objects import _XComRef
19672022

0 commit comments

Comments
 (0)