Skip to content

Commit 7e86bf8

Browse files
Mark all tasks as skipped when failing a dag_run manually including t… (#43572)
* Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482) (cherry picked from commit eda6a8f) * Fix tests for 2.10.x --------- Co-authored-by: Abhishek <[email protected]> (cherry picked from commit 72eef0f)
1 parent 8e79c7a commit 7e86bf8

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

airflow/api/common/mark_tasks.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
2323

24-
from sqlalchemy import or_, select
24+
from sqlalchemy import and_, or_, select
2525
from sqlalchemy.orm import lazyload
2626

2727
from airflow.models.dagrun import DagRun
@@ -500,8 +500,13 @@ def set_dag_run_state_to_failed(
500500
select(TaskInstance).filter(
501501
TaskInstance.dag_id == dag.dag_id,
502502
TaskInstance.run_id == run_id,
503-
TaskInstance.state.not_in(State.finished),
504-
TaskInstance.state.not_in(running_states),
503+
or_(
504+
TaskInstance.state.is_(None),
505+
and_(
506+
TaskInstance.state.not_in(State.finished),
507+
TaskInstance.state.not_in(running_states),
508+
),
509+
),
505510
)
506511
).all()
507512

tests/www/views/test_views_dagrun.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,77 @@ def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_ru
290290
check_content_in_response("runme_2", resp)
291291
check_content_not_in_response("runme_1", resp)
292292
assert resp.status_code == 200
293+
294+
295+
@pytest.fixture
296+
def dag_run_with_all_done_task(session):
297+
"""Creates a DAG run for example_bash_decorator with tasks in various states and an ALL_DONE task not yet run."""
298+
dag = DagBag().get_dag("example_bash_decorator")
299+
300+
# Re-sync the DAG to the DB
301+
dag.sync_to_db()
302+
303+
execution_date = timezone.datetime(2016, 1, 9)
304+
dr = dag.create_dagrun(
305+
state="running",
306+
execution_date=execution_date,
307+
data_interval=(execution_date, execution_date),
308+
run_id="test_dagrun_failed",
309+
session=session,
310+
)
311+
312+
# Create task instances in various states to test the ALL_DONE trigger rule
313+
tis = [
314+
# runme_loop tasks
315+
TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"),
316+
TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"),
317+
TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id, state="running"),
318+
# Other tasks before run_this_last
319+
TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"),
320+
TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"),
321+
TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id, state="skipped"),
322+
TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="running"),
323+
# The task with trigger_rule=ALL_DONE
324+
TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state=None),
325+
]
326+
session.bulk_save_objects(tis)
327+
session.commit()
328+
329+
return dag, dr
330+
331+
332+
def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task):
333+
"""Test marking a dag run as failed with a task having trigger_rule='all_done'"""
334+
dag, dr = dag_run_with_all_done_task
335+
336+
# Verify task instances were created
337+
task_instances = (
338+
session.query(TaskInstance)
339+
.filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
340+
.all()
341+
)
342+
assert len(task_instances) > 0
343+
344+
resp = admin_client.post(
345+
"/dagrun_failed",
346+
data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed": "true"},
347+
follow_redirects=True,
348+
)
349+
350+
assert resp.status_code == 200
351+
352+
with create_session() as session:
353+
updated_dr = (
354+
session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).first()
355+
)
356+
assert updated_dr.state == "failed"
357+
358+
task_instances = (
359+
session.query(TaskInstance)
360+
.filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
361+
.all()
362+
)
363+
364+
done_states = {"success", "failed", "skipped", "upstream_failed"}
365+
for ti in task_instances:
366+
assert ti.state in done_states

0 commit comments

Comments
 (0)