18
18
from __future__ import annotations
19
19
20
20
import datetime
21
- import unittest
22
21
23
22
import pytest
24
23
import time_machine
35
34
INTERVAL = datetime .timedelta (hours = 12 )
36
35
37
36
38
- class TestBranchDateTimeOperator ( unittest . TestCase ) :
37
+ class TestBranchDateTimeOperator :
39
38
@classmethod
40
- def setUpClass (cls ):
41
- super ().setUpClass ()
39
+ def setup_class (cls ):
42
40
43
41
with create_session () as session :
44
42
session .query (DagRun ).delete ()
45
43
session .query (TI ).delete ()
46
44
47
- cls .targets = [
48
- (datetime .datetime (2020 , 7 , 7 , 10 , 0 , 0 ), datetime .datetime (2020 , 7 , 7 , 11 , 0 , 0 )),
49
- (datetime .time (10 , 0 , 0 ), datetime .time (11 , 0 , 0 )),
50
- (datetime .datetime (2020 , 7 , 7 , 10 , 0 , 0 ), datetime .time (11 , 0 , 0 )),
51
- (datetime .time (10 , 0 , 0 ), datetime .datetime (2020 , 7 , 7 , 11 , 0 , 0 )),
52
- (datetime .time (11 , 0 , 0 ), datetime .time (10 , 0 , 0 )),
53
- ]
45
+ targets = [
46
+ (datetime .datetime (2020 , 7 , 7 , 10 , 0 , 0 ), datetime .datetime (2020 , 7 , 7 , 11 , 0 , 0 )),
47
+ (datetime .time (10 , 0 , 0 ), datetime .time (11 , 0 , 0 )),
48
+ (datetime .datetime (2020 , 7 , 7 , 10 , 0 , 0 ), datetime .time (11 , 0 , 0 )),
49
+ (datetime .time (10 , 0 , 0 ), datetime .datetime (2020 , 7 , 7 , 11 , 0 , 0 )),
50
+ ]
54
51
55
- def setUp (self ):
52
+ def setup_method (self ):
56
53
self .dag = DAG (
57
54
"branch_datetime_operator_test" ,
58
55
default_args = {"owner" : "airflow" , "start_date" : DEFAULT_DATE },
@@ -79,8 +76,7 @@ def setUp(self):
79
76
run_id = "manual__" , start_date = DEFAULT_DATE , execution_date = DEFAULT_DATE , state = State .RUNNING
80
77
)
81
78
82
- def tearDown (self ):
83
- super ().tearDown ()
79
+ def teardown_method (self ):
84
80
85
81
with create_session () as session :
86
82
session .query (DagRun ).delete ()
@@ -95,15 +91,13 @@ def _assert_task_ids_match_states(self, task_ids_to_states):
95
91
except KeyError :
96
92
raise ValueError (f"Invalid task id { ti .task_id } found!" )
97
93
else :
98
- self .assertEqual (
99
- ti .state ,
100
- expected_state ,
101
- f"Task { ti .task_id } has state { ti .state } instead of expected { expected_state } " ,
102
- )
94
+ assert (ti .state ) == (
95
+ expected_state
96
+ ), f"Task { ti .task_id } has state { ti .state } instead of expected { expected_state } "
103
97
104
98
def test_no_target_time (self ):
105
99
"""Check if BranchDateTimeOperator raises exception on missing target"""
106
- with self . assertRaises (AirflowException ):
100
+ with pytest . raises (AirflowException ):
107
101
BranchDateTimeOperator (
108
102
task_id = "datetime_branch" ,
109
103
follow_task_ids_if_true = "branch_1" ,
@@ -113,121 +107,125 @@ def test_no_target_time(self):
113
107
dag = self .dag ,
114
108
)
115
109
110
+ @pytest .mark .parametrize (
111
+ "target_lower,target_upper" ,
112
+ targets ,
113
+ )
116
114
@time_machine .travel ("2020-07-07 10:54:05" )
117
- def test_branch_datetime_operator_falls_within_range (self ):
115
+ def test_branch_datetime_operator_falls_within_range (self , target_lower , target_upper ):
118
116
"""Check BranchDateTimeOperator branch operation"""
119
- for target_lower , target_upper in self .targets :
120
- with self .subTest (target_lower = target_lower , target_upper = target_upper ):
121
- self .branch_op .target_lower = target_lower
122
- self .branch_op .target_upper = target_upper
123
- self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
124
-
125
- self ._assert_task_ids_match_states (
126
- {
127
- "datetime_branch" : State .SUCCESS ,
128
- "branch_1" : State .NONE ,
129
- "branch_2" : State .SKIPPED ,
130
- }
131
- )
117
+ self .branch_op .target_lower = target_lower
118
+ self .branch_op .target_upper = target_upper
119
+ self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
120
+
121
+ self ._assert_task_ids_match_states (
122
+ {
123
+ "datetime_branch" : State .SUCCESS ,
124
+ "branch_1" : State .NONE ,
125
+ "branch_2" : State .SKIPPED ,
126
+ }
127
+ )
132
128
133
- def test_branch_datetime_operator_falls_outside_range (self ):
129
+ @pytest .mark .parametrize (
130
+ "target_lower,target_upper" ,
131
+ targets ,
132
+ )
133
+ def test_branch_datetime_operator_falls_outside_range (self , target_lower , target_upper ):
134
134
"""Check BranchDateTimeOperator branch operation"""
135
135
dates = [
136
136
datetime .datetime (2020 , 7 , 7 , 12 , 0 , 0 , tzinfo = datetime .timezone .utc ),
137
137
datetime .datetime (2020 , 6 , 7 , 12 , 0 , 0 , tzinfo = datetime .timezone .utc ),
138
138
]
139
139
140
- for target_lower , target_upper in self .targets :
141
- with self .subTest (target_lower = target_lower , target_upper = target_upper ):
142
- self .branch_op .target_lower = target_lower
143
- self .branch_op .target_upper = target_upper
144
-
145
- for date in dates :
146
- with time_machine .travel (date ):
147
- self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
148
-
149
- self ._assert_task_ids_match_states (
150
- {
151
- "datetime_branch" : State .SUCCESS ,
152
- "branch_1" : State .SKIPPED ,
153
- "branch_2" : State .NONE ,
154
- }
155
- )
156
-
157
- @time_machine .travel ("2020-07-07 10:54:05" )
158
- def test_branch_datetime_operator_upper_comparison_within_range (self ):
159
- """Check BranchDateTimeOperator branch operation"""
160
- for _ , target_upper in self .targets :
161
- with self .subTest (target_upper = target_upper ):
162
- self .branch_op .target_upper = target_upper
163
- self .branch_op .target_lower = None
140
+ self .branch_op .target_lower = target_lower
141
+ self .branch_op .target_upper = target_upper
164
142
143
+ for date in dates :
144
+ with time_machine .travel (date ):
165
145
self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
166
146
167
147
self ._assert_task_ids_match_states (
168
148
{
169
149
"datetime_branch" : State .SUCCESS ,
170
- "branch_1" : State .NONE ,
171
- "branch_2" : State .SKIPPED ,
150
+ "branch_1" : State .SKIPPED ,
151
+ "branch_2" : State .NONE ,
172
152
}
173
153
)
174
154
155
+ @pytest .mark .parametrize ("target_upper" , [target_upper for (_ , target_upper ) in targets ])
175
156
@time_machine .travel ("2020-07-07 10:54:05" )
176
- def test_branch_datetime_operator_lower_comparison_within_range (self ):
157
+ def test_branch_datetime_operator_upper_comparison_within_range (self , target_upper ):
177
158
"""Check BranchDateTimeOperator branch operation"""
178
- for target_lower , _ in self .targets :
179
- with self .subTest (target_lower = target_lower ):
180
- self .branch_op .target_lower = target_lower
181
- self .branch_op .target_upper = None
159
+ self .branch_op .target_upper = target_upper
160
+ self .branch_op .target_lower = None
182
161
183
- self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
162
+ self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
184
163
185
- self ._assert_task_ids_match_states (
186
- {
187
- "datetime_branch" : State .SUCCESS ,
188
- "branch_1" : State .NONE ,
189
- "branch_2" : State .SKIPPED ,
190
- }
191
- )
164
+ self ._assert_task_ids_match_states (
165
+ {
166
+ "datetime_branch" : State .SUCCESS ,
167
+ "branch_1" : State .NONE ,
168
+ "branch_2" : State .SKIPPED ,
169
+ }
170
+ )
171
+
172
+ @pytest .mark .parametrize ("target_lower" , [target_lower for (target_lower , _ ) in targets ])
173
+ @time_machine .travel ("2020-07-07 10:54:05" )
174
+ def test_branch_datetime_operator_lower_comparison_within_range (self , target_lower ):
175
+ """Check BranchDateTimeOperator branch operation"""
176
+ self .branch_op .target_lower = target_lower
177
+ self .branch_op .target_upper = None
192
178
179
+ self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
180
+
181
+ self ._assert_task_ids_match_states (
182
+ {
183
+ "datetime_branch" : State .SUCCESS ,
184
+ "branch_1" : State .NONE ,
185
+ "branch_2" : State .SKIPPED ,
186
+ }
187
+ )
188
+
189
+ @pytest .mark .parametrize ("target_upper" , [target_upper for (_ , target_upper ) in targets ])
193
190
@time_machine .travel ("2020-07-07 12:00:00" )
194
- def test_branch_datetime_operator_upper_comparison_outside_range (self ):
191
+ def test_branch_datetime_operator_upper_comparison_outside_range (self , target_upper ):
195
192
"""Check BranchDateTimeOperator branch operation"""
196
- for _ , target_upper in self .targets :
197
- with self .subTest (target_upper = target_upper ):
198
- self .branch_op .target_upper = target_upper
199
- self .branch_op .target_lower = None
193
+ self .branch_op .target_upper = target_upper
194
+ self .branch_op .target_lower = None
200
195
201
- self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
196
+ self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
202
197
203
- self ._assert_task_ids_match_states (
204
- {
205
- "datetime_branch" : State .SUCCESS ,
206
- "branch_1" : State .SKIPPED ,
207
- "branch_2" : State .NONE ,
208
- }
209
- )
198
+ self ._assert_task_ids_match_states (
199
+ {
200
+ "datetime_branch" : State .SUCCESS ,
201
+ "branch_1" : State .SKIPPED ,
202
+ "branch_2" : State .NONE ,
203
+ }
204
+ )
210
205
206
+ @pytest .mark .parametrize ("target_lower" , [target_lower for (target_lower , _ ) in targets ])
211
207
@time_machine .travel ("2020-07-07 09:00:00" )
212
- def test_branch_datetime_operator_lower_comparison_outside_range (self ):
208
+ def test_branch_datetime_operator_lower_comparison_outside_range (self , target_lower ):
213
209
"""Check BranchDateTimeOperator branch operation"""
214
- for target_lower , _ in self .targets :
215
- with self .subTest (target_lower = target_lower ):
216
- self .branch_op .target_lower = target_lower
217
- self .branch_op .target_upper = None
210
+ self .branch_op .target_lower = target_lower
211
+ self .branch_op .target_upper = None
218
212
219
- self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
213
+ self .branch_op .run (start_date = DEFAULT_DATE , end_date = DEFAULT_DATE )
220
214
221
- self ._assert_task_ids_match_states (
222
- {
223
- "datetime_branch" : State .SUCCESS ,
224
- "branch_1" : State .SKIPPED ,
225
- "branch_2" : State .NONE ,
226
- }
227
- )
215
+ self ._assert_task_ids_match_states (
216
+ {
217
+ "datetime_branch" : State .SUCCESS ,
218
+ "branch_1" : State .SKIPPED ,
219
+ "branch_2" : State .NONE ,
220
+ }
221
+ )
228
222
223
+ @pytest .mark .parametrize (
224
+ "target_lower,target_upper" ,
225
+ targets ,
226
+ )
229
227
@time_machine .travel ("2020-12-01 09:00:00" )
230
- def test_branch_datetime_operator_use_task_logical_date (self ):
228
+ def test_branch_datetime_operator_use_task_logical_date (self , target_lower , target_upper ):
231
229
"""Check if BranchDateTimeOperator uses task execution date"""
232
230
in_between_date = timezone .datetime (2020 , 7 , 7 , 10 , 30 , 0 )
233
231
self .branch_op .use_task_logical_date = True
@@ -238,19 +236,17 @@ def test_branch_datetime_operator_use_task_logical_date(self):
238
236
state = State .RUNNING ,
239
237
)
240
238
241
- for target_lower , target_upper in self .targets :
242
- with self .subTest (target_lower = target_lower , target_upper = target_upper ):
243
- self .branch_op .target_lower = target_lower
244
- self .branch_op .target_upper = target_upper
245
- self .branch_op .run (start_date = in_between_date , end_date = in_between_date )
239
+ self .branch_op .target_lower = target_lower
240
+ self .branch_op .target_upper = target_upper
241
+ self .branch_op .run (start_date = in_between_date , end_date = in_between_date )
246
242
247
- self ._assert_task_ids_match_states (
248
- {
249
- "datetime_branch" : State .SUCCESS ,
250
- "branch_1" : State .NONE ,
251
- "branch_2" : State .SKIPPED ,
252
- }
253
- )
243
+ self ._assert_task_ids_match_states (
244
+ {
245
+ "datetime_branch" : State .SUCCESS ,
246
+ "branch_1" : State .NONE ,
247
+ "branch_2" : State .SKIPPED ,
248
+ }
249
+ )
254
250
255
251
def test_deprecation_warning (self ):
256
252
warning_message = (
0 commit comments