Skip to content

Commit eda6a8f

Browse files
Mark all tasks as skipped when failing a dag_run manually including tasks with None state (#43482)
1 parent 626f984 commit eda6a8f

File tree

2 files changed

+84
-3
lines changed

2 files changed

+84
-3
lines changed

airflow/api/common/mark_tasks.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
2323

24-
from sqlalchemy import or_, select
24+
from sqlalchemy import and_, or_, select
2525
from sqlalchemy.orm import lazyload
2626

2727
from airflow.models.dagrun import DagRun
@@ -402,8 +402,13 @@ def set_dag_run_state_to_failed(
402402
select(TaskInstance).filter(
403403
TaskInstance.dag_id == dag.dag_id,
404404
TaskInstance.run_id == run_id,
405-
TaskInstance.state.not_in(State.finished),
406-
TaskInstance.state.not_in(running_states),
405+
or_(
406+
TaskInstance.state.is_(None),
407+
and_(
408+
TaskInstance.state.not_in(State.finished),
409+
TaskInstance.state.not_in(running_states),
410+
),
411+
),
407412
)
408413
).all()
409414

tests/www/views/test_views_dagrun.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,79 @@ def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_ru
307307
check_content_in_response("runme_2", resp)
308308
check_content_not_in_response("runme_1", resp)
309309
assert resp.status_code == 200
310+
311+
312+
@pytest.fixture
313+
def dag_run_with_all_done_task(session):
314+
"""Creates a DAG run for example_bash_decorator with tasks in various states and an ALL_DONE task not yet run."""
315+
dag = DagBag().get_dag("example_bash_decorator")
316+
317+
# Re-sync the DAG to the DB
318+
dag.sync_to_db()
319+
320+
execution_date = timezone.datetime(2016, 1, 9)
321+
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
322+
dr = dag.create_dagrun(
323+
state="running",
324+
execution_date=execution_date,
325+
data_interval=(execution_date, execution_date),
326+
run_id="test_dagrun_failed",
327+
session=session,
328+
**triggered_by_kwargs,
329+
)
330+
331+
# Create task instances in various states to test the ALL_DONE trigger rule
332+
tis = [
333+
# runme_loop tasks
334+
TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"),
335+
TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"),
336+
TaskInstance(dag.get_task("runme_2"), run_id=dr.run_id, state="running"),
337+
# Other tasks before run_this_last
338+
TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"),
339+
TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"),
340+
TaskInstance(dag.get_task("also_run_this_again"), run_id=dr.run_id, state="skipped"),
341+
TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="running"),
342+
# The task with trigger_rule=ALL_DONE
343+
TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state=None),
344+
]
345+
session.bulk_save_objects(tis)
346+
session.commit()
347+
348+
return dag, dr
349+
350+
351+
def test_dagrun_failed(session, admin_client, dag_run_with_all_done_task):
352+
"""Test marking a dag run as failed with a task having trigger_rule='all_done'"""
353+
dag, dr = dag_run_with_all_done_task
354+
355+
# Verify task instances were created
356+
task_instances = (
357+
session.query(TaskInstance)
358+
.filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
359+
.all()
360+
)
361+
assert len(task_instances) > 0
362+
363+
resp = admin_client.post(
364+
"/dagrun_failed",
365+
data={"dag_id": dr.dag_id, "dag_run_id": dr.run_id, "confirmed": "true"},
366+
follow_redirects=True,
367+
)
368+
369+
assert resp.status_code == 200
370+
371+
with create_session() as session:
372+
updated_dr = (
373+
session.query(DagRun).filter(DagRun.dag_id == dr.dag_id, DagRun.run_id == dr.run_id).first()
374+
)
375+
assert updated_dr.state == "failed"
376+
377+
task_instances = (
378+
session.query(TaskInstance)
379+
.filter(TaskInstance.dag_id == dr.dag_id, TaskInstance.run_id == dr.run_id)
380+
.all()
381+
)
382+
383+
done_states = {"success", "failed", "skipped", "upstream_failed"}
384+
for ti in task_instances:
385+
assert ti.state in done_states

0 commit comments

Comments
 (0)