Skip to content

Commit e98b0e8

Browse files
uranusjrjedcunningham
authored andcommitted
When rendering template, unmap task in context (#26702)
(cherry picked from commit 5560a46)
1 parent 131d8be commit e98b0e8

File tree

7 files changed

+61
-27
lines changed

7 files changed

+61
-27
lines changed

airflow/models/abstractoperator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,15 @@ def render_template_fields(
373373
self,
374374
context: Context,
375375
jinja_env: jinja2.Environment | None = None,
376-
) -> BaseOperator | None:
376+
) -> None:
377377
"""Template all attributes listed in template_fields.
378378
379379
If the operator is mapped, this should return the unmapped, fully
380380
rendered, and map-expanded operator. The mapped operator should not be
381-
modified.
381+
modified. However, ``context`` will be modified in-place to reference
382+
the unmapped operator for template rendering.
382383
383-
If the operator is not mapped, this should modify the operator in-place
384-
and return either *None* (for backwards compatibility) or *self*.
384+
If the operator is not mapped, this should modify the operator in-place.
385385
"""
386386
raise NotImplementedError()
387387

airflow/models/baseoperator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ def render_template_fields(
11791179
self,
11801180
context: Context,
11811181
jinja_env: jinja2.Environment | None = None,
1182-
) -> BaseOperator | None:
1182+
) -> None:
11831183
"""Template all attributes listed in template_fields.
11841184
11851185
This mutates the attributes in-place and is irreversible.
@@ -1190,7 +1190,6 @@ def render_template_fields(
11901190
if not jinja_env:
11911191
jinja_env = self.get_template_env()
11921192
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
1193-
return self
11941193

11951194
@provide_session
11961195
def clear(

airflow/models/mappedoperator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
5959
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
6060
from airflow.typing_compat import Literal
61-
from airflow.utils.context import Context
61+
from airflow.utils.context import Context, context_update_for_unmapped
6262
from airflow.utils.helpers import is_container
6363
from airflow.utils.operator_resources import Resources
6464
from airflow.utils.state import State, TaskInstanceState
@@ -748,7 +748,7 @@ def render_template_fields(
748748
self,
749749
context: Context,
750750
jinja_env: jinja2.Environment | None = None,
751-
) -> BaseOperator | None:
751+
) -> None:
752752
if not jinja_env:
753753
jinja_env = self.get_template_env()
754754

@@ -761,6 +761,8 @@ def render_template_fields(
761761

762762
mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
763763
unmapped_task = self.unmap(mapped_kwargs)
764+
context_update_for_unmapped(context, unmapped_task)
765+
764766
self._do_render_template_fields(
765767
parent=unmapped_task,
766768
template_fields=self.template_fields,
@@ -769,4 +771,3 @@ def render_template_fields(
769771
seen_oids=seen_oids,
770772
session=session,
771773
)
772-
return unmapped_task

airflow/models/taskinstance.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,10 +2190,14 @@ def render_templates(self, context: Context | None = None) -> Operator:
21902190
"""
21912191
if not context:
21922192
context = self.get_template_context()
2193-
rendered_task = self.task.render_template_fields(context)
2194-
if rendered_task is None: # Compatibility -- custom renderer, assume unmapped.
2195-
return self.task
2196-
original_task, self.task = self.task, rendered_task
2193+
original_task = self.task
2194+
2195+
# If self.task is mapped, this call replaces self.task to point to the
2196+
# unmapped BaseOperator created by this function! This is because the
2197+
# MappedOperator is useless for template rendering, and we need to be
2198+
# able to access the unmapped task instead.
2199+
original_task.render_template_fields(context)
2200+
21972201
return original_task
21982202

21992203
def render_k8s_pod_yaml(self) -> dict | None:

airflow/utils/context.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,26 @@
2222
import copy
2323
import functools
2424
import warnings
25-
from typing import Any, Container, ItemsView, Iterator, KeysView, Mapping, MutableMapping, ValuesView
25+
from typing import (
26+
TYPE_CHECKING,
27+
Any,
28+
Container,
29+
ItemsView,
30+
Iterator,
31+
KeysView,
32+
Mapping,
33+
MutableMapping,
34+
ValuesView,
35+
)
2636

2737
import lazy_object_proxy
2838

2939
from airflow.exceptions import RemovedInAirflow3Warning
3040
from airflow.utils.types import NOTSET
3141

42+
if TYPE_CHECKING:
43+
from airflow.models.baseoperator import BaseOperator
44+
3245
# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
3346
KNOWN_CONTEXT_KEYS = {
3447
"conf",
@@ -291,3 +304,15 @@ def _create_value(k: str, v: Any) -> Any:
291304
return lazy_object_proxy.Proxy(factory)
292305

293306
return {k: _create_value(k, v) for k, v in source._context.items()}
307+
308+
309+
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
310+
"""Update context after task unmapping.
311+
312+
Since ``get_template_context()`` is called before unmapping, the context
313+
contains information about the mapped task. We need to do some in-place
314+
updates to ensure the template context reflects the unmapped task instead.
315+
316+
:meta private:
317+
"""
318+
context["task"] = context["ti"].task = task

tests/decorators/test_python.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -756,11 +756,13 @@ def fn(arg1, arg2):
756756

757757
mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
758758
mapped_ti.map_index = 0
759-
op = mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
760-
assert op
761759

762-
assert op.op_kwargs['arg1'] == "{{ ds }}"
763-
assert op.op_kwargs['arg2'] == "fn"
760+
assert mapped_ti.task.is_mapped
761+
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
762+
assert not mapped_ti.task.is_mapped
763+
764+
assert mapped_ti.task.op_kwargs['arg1'] == "{{ ds }}"
765+
assert mapped_ti.task.op_kwargs['arg2'] == "fn"
764766

765767

766768
def test_task_decorator_has_wrapped_attr():

tests/models/test_mappedoperator.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,14 @@ def __init__(self, value, arg1, **kwargs):
305305

306306
mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
307307
mapped_ti.map_index = 0
308-
op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
309-
assert isinstance(op, MyOperator)
310308

311-
assert op.value == "{{ ds }}", "Should not be templated!"
312-
assert op.arg1 == "{{ ds }}", "Should not be templated!"
313-
assert op.arg2 == "a"
309+
assert isinstance(mapped_ti.task, MappedOperator)
310+
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
311+
assert isinstance(mapped_ti.task, MyOperator)
312+
313+
assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
314+
assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
315+
assert mapped_ti.task.arg2 == "a"
314316

315317

316318
def test_mapped_render_nested_template_fields(dag_maker, session):
@@ -430,10 +432,11 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses
430432
ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
431433
ti.refresh_from_task(mapped)
432434
ti.map_index = map_index
433-
op = mapped.render_template_fields(context=ti.get_template_context(session=session))
434-
assert isinstance(op, MockOperator)
435-
assert op.arg1 == expected
436-
assert op.arg2 == "a"
435+
assert isinstance(ti.task, MappedOperator)
436+
mapped.render_template_fields(context=ti.get_template_context(session=session))
437+
assert isinstance(ti.task, MockOperator)
438+
assert ti.task.arg1 == expected
439+
assert ti.task.arg2 == "a"
437440

438441

439442
def test_xcomarg_property_of_mapped_operator(dag_maker):

0 commit comments

Comments
 (0)