Skip to content

Commit 08bbf89

Browse files
jscheffldabla
authored andcommitted
FIX: Don't raise a warning in ExecutorSafeguard when execute is called from an extended operator (#42849) (#43577)
* refactor: Don't raise a warning when execute is called from an extended operator, as this should always be allowed. * refactored: Fixed import of test_utils in test_dag_run --------- Co-authored-by: David Blain <[email protected]> (cherry picked from commit 95c46ec) Co-authored-by: David Blain <[email protected]> (cherry picked from commit 2f29c57)
1 parent 7e86bf8 commit 08bbf89

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

airflow/models/baseoperator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import warnings
3535
from datetime import datetime, timedelta
3636
from functools import total_ordering, wraps
37+
from threading import local
3738
from types import FunctionType
3839
from typing import (
3940
TYPE_CHECKING,
@@ -391,14 +392,22 @@ class ExecutorSafeguard:
391392
"""
392393

393394
test_mode = conf.getboolean("core", "unit_test_mode")
395+
_sentinel = local()
396+
_sentinel.callers = {}
394397

395398
@classmethod
396399
def decorator(cls, func):
397400
@wraps(func)
398401
def wrapper(self, *args, **kwargs):
399402
from airflow.decorators.base import DecoratedOperator
400403

401-
sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None)
404+
sentinel_key = f"{self.__class__.__name__}__sentinel"
405+
sentinel = kwargs.pop(sentinel_key, None)
406+
407+
if sentinel:
408+
cls._sentinel.callers[sentinel_key] = sentinel
409+
else:
410+
sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None)
402411

403412
if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator):
404413
message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!"

tests/models/test_baseoperatormeta.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def execute(self, context: Context) -> Any:
4040
return f"Hello {self.owner}!"
4141

4242

43+
class ExtendedHelloWorldOperator(HelloWorldOperator):
44+
def execute(self, context: Context) -> Any:
45+
return super().execute(context)
46+
47+
4348
class TestExecutorSafeguard:
4449
def setup_method(self):
4550
ExecutorSafeguard.test_mode = False
@@ -49,12 +54,29 @@ def teardown_method(self, method):
4954

5055
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
5156
@pytest.mark.db_test
52-
def test_executor_when_classic_operator_called_from_dag(self, dag_maker):
57+
@patch.object(HelloWorldOperator, "log")
58+
def test_executor_when_classic_operator_called_from_dag(self, mock_log, dag_maker):
5359
with dag_maker() as dag:
5460
HelloWorldOperator(task_id="hello_operator")
5561

5662
dag_run = dag.test()
5763
assert dag_run.state == DagRunState.SUCCESS
64+
mock_log.warning.assert_not_called()
65+
66+
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
67+
@pytest.mark.db_test
68+
@patch.object(HelloWorldOperator, "log")
69+
def test_executor_when_extended_classic_operator_called_from_dag(
70+
self,
71+
mock_log,
72+
dag_maker,
73+
):
74+
with dag_maker() as dag:
75+
ExtendedHelloWorldOperator(task_id="hello_operator")
76+
77+
dag_run = dag.test()
78+
assert dag_run.state == DagRunState.SUCCESS
79+
mock_log.warning.assert_not_called()
5880

5981
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
6082
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)