Skip to content

Commit 73c8e7d

Browse files
Migrated operators tests to pytest (#29377)
* Migrated operators tests to pytest * migrated remaining operators tests to pytest * resolved static checks * refactored * refactored
1 parent 7bd87e7 commit 73c8e7d

File tree

5 files changed

+135
-140
lines changed

5 files changed

+135
-140
lines changed

tests/operators/test_bash.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from unittest import mock
2626

2727
import pytest
28-
from parameterized import parameterized
2928

3029
from airflow.exceptions import AirflowException, AirflowSkipException, AirflowTaskTimeout
3130
from airflow.models.dag import DAG
@@ -40,11 +39,12 @@
4039

4140

4241
class TestBashOperator:
43-
@parameterized.expand(
42+
@pytest.mark.parametrize(
43+
"append_env,user_defined_env,expected_airflow_home",
4444
[
4545
(False, None, "MY_PATH_TO_AIRFLOW_HOME"),
4646
(True, {"AIRFLOW_HOME": "OVERRIDDEN_AIRFLOW_HOME"}, "OVERRIDDEN_AIRFLOW_HOME"),
47-
]
47+
],
4848
)
4949
def test_echo_env_variables(self, append_env, user_defined_env, expected_airflow_home):
5050
"""
@@ -98,13 +98,14 @@ def test_echo_env_variables(self, append_env, user_defined_env, expected_airflow
9898
output = "".join(file.readlines())
9999
assert expected == output
100100

101-
@parameterized.expand(
101+
@pytest.mark.parametrize(
102+
"val,expected",
102103
[
103104
("test-val", "test-val"),
104105
("test-val\ntest-val\n", ""),
105106
("test-val\ntest-val", "test-val"),
106107
("", ""),
107-
]
108+
],
108109
)
109110
def test_return_value(self, val, expected):
110111
op = BashOperator(task_id="abc", bash_command=f'set -e; echo "{val}";')
@@ -168,13 +169,14 @@ def test_valid_cwd(self):
168169
with open(f"{test_cwd_folder}/outputs.txt") as tmp_file:
169170
assert tmp_file.read().splitlines()[0] == "xxxx"
170171

171-
@parameterized.expand(
172+
@pytest.mark.parametrize(
173+
"extra_kwargs,actual_exit_code,expected_exc",
172174
[
173175
(None, 99, AirflowSkipException),
174176
({"skip_exit_code": 100}, 100, AirflowSkipException),
175177
({"skip_exit_code": 100}, 101, AirflowException),
176178
({"skip_exit_code": None}, 99, AirflowException),
177-
]
179+
],
178180
)
179181
def test_skip(self, extra_kwargs, actual_exit_code, expected_exc):
180182
kwargs = dict(task_id="abc", bash_command=f'set -e; echo "hello world"; exit {actual_exit_code};')

tests/operators/test_datetime.py

Lines changed: 109 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import annotations
1919

2020
import datetime
21-
import unittest
2221

2322
import pytest
2423
import time_machine
@@ -35,24 +34,22 @@
3534
INTERVAL = datetime.timedelta(hours=12)
3635

3736

38-
class TestBranchDateTimeOperator(unittest.TestCase):
37+
class TestBranchDateTimeOperator:
3938
@classmethod
40-
def setUpClass(cls):
41-
super().setUpClass()
39+
def setup_class(cls):
4240

4341
with create_session() as session:
4442
session.query(DagRun).delete()
4543
session.query(TI).delete()
4644

47-
cls.targets = [
48-
(datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
49-
(datetime.time(10, 0, 0), datetime.time(11, 0, 0)),
50-
(datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.time(11, 0, 0)),
51-
(datetime.time(10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
52-
(datetime.time(11, 0, 0), datetime.time(10, 0, 0)),
53-
]
45+
targets = [
46+
(datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
47+
(datetime.time(10, 0, 0), datetime.time(11, 0, 0)),
48+
(datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.time(11, 0, 0)),
49+
(datetime.time(10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
50+
]
5451

55-
def setUp(self):
52+
def setup_method(self):
5653
self.dag = DAG(
5754
"branch_datetime_operator_test",
5855
default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
@@ -79,8 +76,7 @@ def setUp(self):
7976
run_id="manual__", start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.RUNNING
8077
)
8178

82-
def tearDown(self):
83-
super().tearDown()
79+
def teardown_method(self):
8480

8581
with create_session() as session:
8682
session.query(DagRun).delete()
@@ -95,15 +91,13 @@ def _assert_task_ids_match_states(self, task_ids_to_states):
9591
except KeyError:
9692
raise ValueError(f"Invalid task id {ti.task_id} found!")
9793
else:
98-
self.assertEqual(
99-
ti.state,
100-
expected_state,
101-
f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
102-
)
94+
assert (ti.state) == (
95+
expected_state
96+
), f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}"
10397

10498
def test_no_target_time(self):
10599
"""Check if BranchDateTimeOperator raises exception on missing target"""
106-
with self.assertRaises(AirflowException):
100+
with pytest.raises(AirflowException):
107101
BranchDateTimeOperator(
108102
task_id="datetime_branch",
109103
follow_task_ids_if_true="branch_1",
@@ -113,121 +107,125 @@ def test_no_target_time(self):
113107
dag=self.dag,
114108
)
115109

110+
@pytest.mark.parametrize(
111+
"target_lower,target_upper",
112+
targets,
113+
)
116114
@time_machine.travel("2020-07-07 10:54:05")
117-
def test_branch_datetime_operator_falls_within_range(self):
115+
def test_branch_datetime_operator_falls_within_range(self, target_lower, target_upper):
118116
"""Check BranchDateTimeOperator branch operation"""
119-
for target_lower, target_upper in self.targets:
120-
with self.subTest(target_lower=target_lower, target_upper=target_upper):
121-
self.branch_op.target_lower = target_lower
122-
self.branch_op.target_upper = target_upper
123-
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
124-
125-
self._assert_task_ids_match_states(
126-
{
127-
"datetime_branch": State.SUCCESS,
128-
"branch_1": State.NONE,
129-
"branch_2": State.SKIPPED,
130-
}
131-
)
117+
self.branch_op.target_lower = target_lower
118+
self.branch_op.target_upper = target_upper
119+
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
120+
121+
self._assert_task_ids_match_states(
122+
{
123+
"datetime_branch": State.SUCCESS,
124+
"branch_1": State.NONE,
125+
"branch_2": State.SKIPPED,
126+
}
127+
)
132128

133-
def test_branch_datetime_operator_falls_outside_range(self):
129+
@pytest.mark.parametrize(
130+
"target_lower,target_upper",
131+
targets,
132+
)
133+
def test_branch_datetime_operator_falls_outside_range(self, target_lower, target_upper):
134134
"""Check BranchDateTimeOperator branch operation"""
135135
dates = [
136136
datetime.datetime(2020, 7, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
137137
datetime.datetime(2020, 6, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
138138
]
139139

140-
for target_lower, target_upper in self.targets:
141-
with self.subTest(target_lower=target_lower, target_upper=target_upper):
142-
self.branch_op.target_lower = target_lower
143-
self.branch_op.target_upper = target_upper
144-
145-
for date in dates:
146-
with time_machine.travel(date):
147-
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
148-
149-
self._assert_task_ids_match_states(
150-
{
151-
"datetime_branch": State.SUCCESS,
152-
"branch_1": State.SKIPPED,
153-
"branch_2": State.NONE,
154-
}
155-
)
156-
157-
@time_machine.travel("2020-07-07 10:54:05")
158-
def test_branch_datetime_operator_upper_comparison_within_range(self):
159-
"""Check BranchDateTimeOperator branch operation"""
160-
for _, target_upper in self.targets:
161-
with self.subTest(target_upper=target_upper):
162-
self.branch_op.target_upper = target_upper
163-
self.branch_op.target_lower = None
140+
self.branch_op.target_lower = target_lower
141+
self.branch_op.target_upper = target_upper
164142

143+
for date in dates:
144+
with time_machine.travel(date):
165145
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
166146

167147
self._assert_task_ids_match_states(
168148
{
169149
"datetime_branch": State.SUCCESS,
170-
"branch_1": State.NONE,
171-
"branch_2": State.SKIPPED,
150+
"branch_1": State.SKIPPED,
151+
"branch_2": State.NONE,
172152
}
173153
)
174154

155+
@pytest.mark.parametrize("target_upper", [target_upper for (_, target_upper) in targets])
175156
@time_machine.travel("2020-07-07 10:54:05")
176-
def test_branch_datetime_operator_lower_comparison_within_range(self):
157+
def test_branch_datetime_operator_upper_comparison_within_range(self, target_upper):
177158
"""Check BranchDateTimeOperator branch operation"""
178-
for target_lower, _ in self.targets:
179-
with self.subTest(target_lower=target_lower):
180-
self.branch_op.target_lower = target_lower
181-
self.branch_op.target_upper = None
159+
self.branch_op.target_upper = target_upper
160+
self.branch_op.target_lower = None
182161

183-
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
162+
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
184163

185-
self._assert_task_ids_match_states(
186-
{
187-
"datetime_branch": State.SUCCESS,
188-
"branch_1": State.NONE,
189-
"branch_2": State.SKIPPED,
190-
}
191-
)
164+
self._assert_task_ids_match_states(
165+
{
166+
"datetime_branch": State.SUCCESS,
167+
"branch_1": State.NONE,
168+
"branch_2": State.SKIPPED,
169+
}
170+
)
171+
172+
@pytest.mark.parametrize("target_lower", [target_lower for (target_lower, _) in targets])
173+
@time_machine.travel("2020-07-07 10:54:05")
174+
def test_branch_datetime_operator_lower_comparison_within_range(self, target_lower):
175+
"""Check BranchDateTimeOperator branch operation"""
176+
self.branch_op.target_lower = target_lower
177+
self.branch_op.target_upper = None
192178

179+
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
180+
181+
self._assert_task_ids_match_states(
182+
{
183+
"datetime_branch": State.SUCCESS,
184+
"branch_1": State.NONE,
185+
"branch_2": State.SKIPPED,
186+
}
187+
)
188+
189+
@pytest.mark.parametrize("target_upper", [target_upper for (_, target_upper) in targets])
193190
@time_machine.travel("2020-07-07 12:00:00")
194-
def test_branch_datetime_operator_upper_comparison_outside_range(self):
191+
def test_branch_datetime_operator_upper_comparison_outside_range(self, target_upper):
195192
"""Check BranchDateTimeOperator branch operation"""
196-
for _, target_upper in self.targets:
197-
with self.subTest(target_upper=target_upper):
198-
self.branch_op.target_upper = target_upper
199-
self.branch_op.target_lower = None
193+
self.branch_op.target_upper = target_upper
194+
self.branch_op.target_lower = None
200195

201-
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
196+
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
202197

203-
self._assert_task_ids_match_states(
204-
{
205-
"datetime_branch": State.SUCCESS,
206-
"branch_1": State.SKIPPED,
207-
"branch_2": State.NONE,
208-
}
209-
)
198+
self._assert_task_ids_match_states(
199+
{
200+
"datetime_branch": State.SUCCESS,
201+
"branch_1": State.SKIPPED,
202+
"branch_2": State.NONE,
203+
}
204+
)
210205

206+
@pytest.mark.parametrize("target_lower", [target_lower for (target_lower, _) in targets])
211207
@time_machine.travel("2020-07-07 09:00:00")
212-
def test_branch_datetime_operator_lower_comparison_outside_range(self):
208+
def test_branch_datetime_operator_lower_comparison_outside_range(self, target_lower):
213209
"""Check BranchDateTimeOperator branch operation"""
214-
for target_lower, _ in self.targets:
215-
with self.subTest(target_lower=target_lower):
216-
self.branch_op.target_lower = target_lower
217-
self.branch_op.target_upper = None
210+
self.branch_op.target_lower = target_lower
211+
self.branch_op.target_upper = None
218212

219-
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
213+
self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
220214

221-
self._assert_task_ids_match_states(
222-
{
223-
"datetime_branch": State.SUCCESS,
224-
"branch_1": State.SKIPPED,
225-
"branch_2": State.NONE,
226-
}
227-
)
215+
self._assert_task_ids_match_states(
216+
{
217+
"datetime_branch": State.SUCCESS,
218+
"branch_1": State.SKIPPED,
219+
"branch_2": State.NONE,
220+
}
221+
)
228222

223+
@pytest.mark.parametrize(
224+
"target_lower,target_upper",
225+
targets,
226+
)
229227
@time_machine.travel("2020-12-01 09:00:00")
230-
def test_branch_datetime_operator_use_task_logical_date(self):
228+
def test_branch_datetime_operator_use_task_logical_date(self, target_lower, target_upper):
231229
"""Check if BranchDateTimeOperator uses task execution date"""
232230
in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
233231
self.branch_op.use_task_logical_date = True
@@ -238,19 +236,17 @@ def test_branch_datetime_operator_use_task_logical_date(self):
238236
state=State.RUNNING,
239237
)
240238

241-
for target_lower, target_upper in self.targets:
242-
with self.subTest(target_lower=target_lower, target_upper=target_upper):
243-
self.branch_op.target_lower = target_lower
244-
self.branch_op.target_upper = target_upper
245-
self.branch_op.run(start_date=in_between_date, end_date=in_between_date)
239+
self.branch_op.target_lower = target_lower
240+
self.branch_op.target_upper = target_upper
241+
self.branch_op.run(start_date=in_between_date, end_date=in_between_date)
246242

247-
self._assert_task_ids_match_states(
248-
{
249-
"datetime_branch": State.SUCCESS,
250-
"branch_1": State.NONE,
251-
"branch_2": State.SKIPPED,
252-
}
253-
)
243+
self._assert_task_ids_match_states(
244+
{
245+
"datetime_branch": State.SUCCESS,
246+
"branch_1": State.NONE,
247+
"branch_2": State.SKIPPED,
248+
}
249+
)
254250

255251
def test_deprecation_warning(self):
256252
warning_message = (

tests/operators/test_email.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import annotations
1919

2020
import datetime
21-
import unittest
2221
from unittest import mock
2322

2423
from airflow.models.dag import DAG
@@ -34,15 +33,13 @@
3433
send_email_test = mock.Mock()
3534

3635

37-
class TestEmailOperator(unittest.TestCase):
38-
def setUp(self):
39-
super().setUp()
36+
class TestEmailOperator:
37+
def setup_class(self):
4038
self.dag = DAG(
4139
"test_dag",
4240
default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
4341
schedule=INTERVAL,
4442
)
45-
self.addCleanup(self.dag.clear)
4643

4744
def _run_as_operator(self, **kwargs):
4845
task = EmailOperator(
@@ -56,6 +53,7 @@ def _run_as_operator(self, **kwargs):
5653
**kwargs,
5754
)
5855
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
56+
self.dag.clear()
5957

6058
def test_execute(self):
6159
with conf_vars({("email", "email_backend"): "tests.operators.test_email.send_email_test"}):

0 commit comments

Comments
 (0)