Skip to content

Commit cb605c3

Browse files
rjmcginnessCloud Composer Team
authored and
Cloud Composer Team
committed
Fix xcom arg.py .zip bug (#26636)
(cherry picked from commit f219bfbe22e662a8747af19d688bbe843e1a953d) GitOrigin-RevId: 45a461b37dcd9c8b97952ab535a7c057f1e944bd
1 parent f555065 commit cb605c3

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

airflow/models/xcom_arg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from airflow.utils.context import Context
3232
from airflow.utils.edgemodifier import EdgeModifier
3333
from airflow.utils.session import NEW_SESSION, provide_session
34-
from airflow.utils.types import NOTSET
34+
from airflow.utils.types import NOTSET, ArgNotSet
3535

3636
if TYPE_CHECKING:
3737
from airflow.models.dag import DAG
@@ -322,7 +322,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
322322
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
323323
task_id = self.operator.task_id
324324
result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session)
325-
if result is not NOTSET:
325+
if not isinstance(result, ArgNotSet):
326326
return result
327327
if self.key == XCOM_RETURN_KEY:
328328
return None
@@ -437,7 +437,7 @@ def __getitem__(self, index: Any) -> Any:
437437

438438
def __len__(self) -> int:
439439
lengths = (len(v) for v in self.values)
440-
if self.fillvalue is NOTSET:
440+
if isinstance(self.fillvalue, ArgNotSet):
441441
return min(lengths)
442442
return max(lengths)
443443

@@ -460,13 +460,13 @@ def __repr__(self) -> str:
460460
args_iter = iter(self.args)
461461
first = repr(next(args_iter))
462462
rest = ", ".join(repr(arg) for arg in args_iter)
463-
if self.fillvalue is NOTSET:
463+
if isinstance(self.fillvalue, ArgNotSet):
464464
return f"{first}.zip({rest})"
465465
return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
466466

467467
def _serialize(self) -> dict[str, Any]:
468468
args = [serialize_xcom_arg(arg) for arg in self.args]
469-
if self.fillvalue is NOTSET:
469+
if isinstance(self.fillvalue, ArgNotSet):
470470
return {"args": args}
471471
return {"args": args, "fillvalue": self.fillvalue}
472472

@@ -486,7 +486,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
486486
ready_lengths = [length for length in all_lengths if length is not None]
487487
if len(ready_lengths) != len(self.args):
488488
return None # If any of the referenced XComs is not ready, we are not ready either.
489-
if self.fillvalue is NOTSET:
489+
if isinstance(self.fillvalue, ArgNotSet):
490490
return min(ready_lengths)
491491
return max(ready_lengths)
492492

0 commit comments

Comments
 (0)